diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
334 files changed, 17455 insertions, 15152 deletions
diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala index 5ff07314..7adb28af 100644 --- a/src/main/scala/firrtl/AddDescriptionNodes.scala +++ b/src/main/scala/firrtl/AddDescriptionNodes.scala @@ -12,7 +12,7 @@ import firrtl.options.Dependency * Usually, we would like to emit these descriptions in some way. */ sealed trait DescriptionAnnotation extends Annotation { - def target: Target + def target: Target def description: String } @@ -24,7 +24,7 @@ sealed trait DescriptionAnnotation extends Annotation { case class DocStringAnnotation(target: Target, description: String) extends DescriptionAnnotation { def update(renames: RenameMap): Seq[DocStringAnnotation] = { renames.get(target) match { - case None => Seq(this) + case None => Seq(this) case Some(seq) => seq.map(n => this.copy(target = n)) } } @@ -38,7 +38,7 @@ case class DocStringAnnotation(target: Target, description: String) extends Desc case class AttributeAnnotation(target: Target, description: String) extends DescriptionAnnotation { def update(renames: RenameMap): Seq[AttributeAnnotation] = { renames.get(target) match { - case None => Seq(this) + case None => Seq(this) case Some(seq) => seq.map(n => this.copy(target = n)) } } @@ -78,18 +78,20 @@ case class Attribute(string: StringLit) extends Description { * @param descriptions * @param stmt the encapsulated statement */ -private case class DescribedStmt(descriptions: Seq[Description], stmt: Statement) extends Statement with HasDescription { +private case class DescribedStmt(descriptions: Seq[Description], stmt: Statement) + extends Statement + with HasDescription { override def serialize: String = s"${descriptions.map(_.serialize).mkString("\n")}\n${stmt.serialize}" - def mapStmt(f: Statement => Statement): Statement = f(stmt) - def mapExpr(f: Expression => Expression): Statement = this.copy(stmt = stmt.mapExpr(f)) - def mapType(f: Type => Type): Statement = this.copy(stmt = stmt.mapType(f)) - def mapString(f: String => String): Statement = this.copy(stmt = stmt.mapString(f)) - def mapInfo(f: Info => Info): Statement = this.copy(stmt = stmt.mapInfo(f)) - def foreachStmt(f: Statement => Unit): Unit = f(stmt) - def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f) - def foreachType(f: Type => Unit): Unit = stmt.foreachType(f) - def foreachString(f: String => Unit): Unit = stmt.foreachString(f) - def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f) + def mapStmt(f: Statement => Statement): Statement = f(stmt) + def mapExpr(f: Expression => Expression): Statement = this.copy(stmt = stmt.mapExpr(f)) + def mapType(f: Type => Type): Statement = this.copy(stmt = stmt.mapType(f)) + def mapString(f: String => String): Statement = this.copy(stmt = stmt.mapString(f)) + def mapInfo(f: Info => Info): Statement = this.copy(stmt = stmt.mapInfo(f)) + def foreachStmt(f: Statement => Unit): Unit = f(stmt) + def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f) + def foreachType(f: Type => Unit): Unit = stmt.foreachType(f) + def foreachString(f: String => Unit): Unit = stmt.foreachString(f) + def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f) } /** @@ -98,21 +100,24 @@ private case class DescribedStmt(descriptions: Seq[Description], stmt: Statement * @param portDescriptions list of descriptions for the module's ports * @param mod the encapsulated module */ -private case class DescribedMod(descriptions: Seq[Description], +private case class DescribedMod( + descriptions: Seq[Description], portDescriptions: Map[String, Seq[Description]], - mod: DefModule) extends DefModule with HasDescription { + mod: DefModule) + extends DefModule + with HasDescription { val info = mod.info val name = mod.name val ports = mod.ports override def serialize: String = s"${descriptions.map(_.serialize).mkString("\n")}\n${mod.serialize}" - def mapStmt(f: Statement => Statement): DefModule = this.copy(mod = mod.mapStmt(f)) - def mapPort(f: Port => Port): DefModule = this.copy(mod = mod.mapPort(f)) - def mapString(f: String => String): DefModule = this.copy(mod = mod.mapString(f)) - def mapInfo(f: Info => Info): DefModule = this.copy(mod = mod.mapInfo(f)) - def foreachStmt(f: Statement => Unit): Unit = mod.foreachStmt(f) - def foreachPort(f: Port => Unit): Unit = mod.foreachPort(f) - def foreachString(f: String => Unit): Unit = mod.foreachString(f) - def foreachInfo(f: Info => Unit): Unit = mod.foreachInfo(f) + def mapStmt(f: Statement => Statement): DefModule = this.copy(mod = mod.mapStmt(f)) + def mapPort(f: Port => Port): DefModule = this.copy(mod = mod.mapPort(f)) + def mapString(f: String => String): DefModule = this.copy(mod = mod.mapString(f)) + def mapInfo(f: Info => Info): DefModule = this.copy(mod = mod.mapInfo(f)) + def foreachStmt(f: Statement => Unit): Unit = mod.foreachStmt(f) + def foreachPort(f: Port => Unit): Unit = mod.foreachPort(f) + def foreachString(f: String => Unit): Unit = mod.foreachString(f) + def foreachInfo(f: Info => Unit): Unit = mod.foreachInfo(f) } /** Wraps modules or statements with their respective described nodes. Descriptions come from [[DescriptionAnnotation]]. @@ -125,17 +130,19 @@ private case class DescribedMod(descriptions: Seq[Description], class AddDescriptionNodes extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.PropagatePresetAnnotations], - Dependency[firrtl.transforms.InlineCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], - Dependency(firrtl.passes.VerilogPrep) ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.PropagatePresetAnnotations], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(firrtl.passes.VerilogPrep) + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -149,18 +156,22 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration { case d: IsDeclaration => Some(d.name) case _ => None } - val descs = sname.flatMap({ case name => - compMap.get(name) + val descs = sname.flatMap({ + case name => + compMap.get(name) }) (descs, s) match { case (Some(d), DescribedStmt(prevDescs, ss)) => DescribedStmt(prevDescs ++ d, ss) - case (Some(d), ss) => DescribedStmt(d, ss) - case (None, _) => s + case (Some(d), ss) => DescribedStmt(d, ss) + case (None, _) => s } } - def onModule(modMap: Map[String, Seq[Description]], compMaps: Map[String, Map[String, Seq[Description]]]) - (mod: DefModule): DefModule = { + def onModule( + modMap: Map[String, Seq[Description]], + compMaps: Map[String, Map[String, Seq[Description]]] + )(mod: DefModule + ): DefModule = { val compMap = compMaps.getOrElse(mod.name, Map()) val newMod = mod.mapStmt(onStmt(compMap)) val portDesc = mod.ports.collect { @@ -210,14 +221,18 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration { rest ++ doc ++ attr } - def collectMaps(annos: Seq[Annotation]): (Map[String, Seq[Description]], Map[String, Map[String, Seq[Description]]]) = { + def collectMaps( + annos: Seq[Annotation] + ): (Map[String, Seq[Description]], Map[String, Map[String, Seq[Description]]]) = { val modList = annos.collect { case DocStringAnnotation(ModuleTarget(_, m), desc) => (m, DocString(StringLit.unescape(desc))) case AttributeAnnotation(ModuleTarget(_, m), desc) => (m, Attribute(StringLit.unescape(desc))) } // map field 1 (module name) -> field 2 (a list of Descriptions) - val modMap = modList.groupBy(_._1).mapValues(_.map(_._2)) + val modMap = modList + .groupBy(_._1) + .mapValues(_.map(_._2)) // and then merge like descriptions (e.g. multiple docstrings into one big docstring) .mapValues(mergeDescriptions) @@ -229,11 +244,16 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration { } // map field 1 (name) -> a map that we build - val compMap = compList.groupBy(_._1).mapValues( - // map field 2 (component name) -> field 3 (a list of Descriptions) - _.groupBy(_._2).mapValues(_.map(_._3)) - // and then merge like descriptions (e.g. multiple docstrings into one big docstring) - .mapValues(mergeDescriptions).toMap) + val compMap = compList + .groupBy(_._1) + .mapValues( + // map field 2 (component name) -> field 3 (a list of Descriptions) + _.groupBy(_._2) + .mapValues(_.map(_._3)) + // and then merge like descriptions (e.g. multiple docstrings into one big docstring) + .mapValues(mergeDescriptions) + .toMap + ) (modMap.toMap, compMap.toMap) } diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index db4853a2..ec09cace 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -13,7 +13,7 @@ import firrtl.annotations._ import firrtl.ir.Circuit import firrtl.Utils.throwInternalError import firrtl.annotations.transforms.{EliminateTargetPaths, ResolvePaths} -import firrtl.options.{DependencyAPI, Dependency, StageUtils, TransformLike} +import firrtl.options.{Dependency, DependencyAPI, StageUtils, TransformLike} import firrtl.stage.Forms /** Container of all annotations for a Firrtl compiler */ @@ -34,19 +34,22 @@ object AnnotationSeq { * Generally only a return value from [[Transform]]s */ case class CircuitState( - circuit: Circuit, - form: CircuitForm, - annotations: AnnotationSeq, - renames: Option[RenameMap]) { + circuit: Circuit, + form: CircuitForm, + annotations: AnnotationSeq, + renames: Option[RenameMap]) { /** Helper for getting just an emitted circuit */ def emittedCircuitOption: Option[EmittedCircuit] = - emittedComponents collectFirst { case x: EmittedCircuit => x } + emittedComponents.collectFirst { case x: EmittedCircuit => x } + /** Helper for getting an [[EmittedCircuit]] when it is known to exist */ def getEmittedCircuit: EmittedCircuit = emittedCircuitOption match { case Some(emittedCircuit) => emittedCircuit case None => - throw new FirrtlInternalException(s"No EmittedCircuit found! Did you delete any annotations?\n$deletedAnnotations") + throw new FirrtlInternalException( + s"No EmittedCircuit found! Did you delete any annotations?\n$deletedAnnotations" + ) } /** Helper function for extracting emitted components from annotations */ @@ -64,7 +67,7 @@ case class CircuitState( def resolvePaths(targets: Seq[CompleteTarget]): CircuitState = targets match { case Nil => this case _ => - val newCS = new EliminateTargetPaths().runTransform(this.copy(annotations = ResolvePaths(targets) +: annotations )) + val newCS = new EliminateTargetPaths().runTransform(this.copy(annotations = ResolvePaths(targets) +: annotations)) newCS.copy(form = form) } @@ -73,8 +76,8 @@ case class CircuitState( * @return */ def resolvePathsOf(annoClasses: Class[_]*): CircuitState = { - val targets = getAnnotationsOf(annoClasses:_*).flatMap(_.getTargets) - if(targets.nonEmpty) resolvePaths(targets.flatMap{_.getComplete}) else this + val targets = getAnnotationsOf(annoClasses: _*).flatMap(_.getTargets) + if (targets.nonEmpty) resolvePaths(targets.flatMap { _.getComplete }) else this } /** Returns all annotations which are of a class in annoClasses @@ -105,7 +108,8 @@ object CircuitState { */ @deprecated( "Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + "FIRRTL 1.3" +) sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] { // Note that value is used only to allow comparisons def compare(that: CircuitForm): Int = this.value - that.value @@ -125,7 +129,8 @@ sealed abstract class CircuitForm(private val value: Int) extends Ordered[Circui */ @deprecated( "Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + "FIRRTL 1.3" +) final case object ChirrtlForm extends CircuitForm(value = 3) { val outputSuffix: String = ".fir" } @@ -139,7 +144,8 @@ final case object ChirrtlForm extends CircuitForm(value = 3) { */ @deprecated( "Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + "FIRRTL 1.3" +) final case object HighForm extends CircuitForm(2) { val outputSuffix: String = ".hi.fir" } @@ -153,7 +159,8 @@ final case object HighForm extends CircuitForm(2) { */ @deprecated( "Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + "FIRRTL 1.3" +) final case object MidForm extends CircuitForm(1) { val outputSuffix: String = ".mid.fir" } @@ -166,7 +173,8 @@ final case object MidForm extends CircuitForm(1) { */ @deprecated( "Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + "FIRRTL 1.3" +) final case object LowForm extends CircuitForm(0) { val outputSuffix: String = ".lo.fir" } @@ -184,7 +192,8 @@ final case object LowForm extends CircuitForm(0) { */ @deprecated( "Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + "FIRRTL 1.3" +) final case object UnknownForm extends CircuitForm(-1) { override def compare(that: CircuitForm): Int = { sys.error("Illegal to compare UnknownForm"); 0 } @@ -212,12 +221,15 @@ private[firrtl] object Transform { logger.info(s"Form: ${after.form}") logger.trace(s"Annotations:") logger.trace { - JsonProtocol.serializeTry(remappedAnnotations).recoverWith { - case NonFatal(e) => - val msg = s"Exception thrown during Annotation serialization:\n " + - e.toString.replaceAll("\n", "\n ") - Try(msg) - }.get + JsonProtocol + .serializeTry(remappedAnnotations) + .recoverWith { + case NonFatal(e) => + val msg = s"Exception thrown during Annotation serialization:\n " + + e.toString.replaceAll("\n", "\n ") + Try(msg) + } + .get } logger.trace(s"Circuit:\n${after.circuit.serialize}") @@ -234,17 +246,18 @@ private[firrtl] object Transform { * @return the updated annotations */ def propagateAnnotations( - name: String, - logger: Logger, - inAnno: AnnotationSeq, - resAnno: AnnotationSeq, - renameOpt: Option[RenameMap]): AnnotationSeq = { + name: String, + logger: Logger, + inAnno: AnnotationSeq, + resAnno: AnnotationSeq, + renameOpt: Option[RenameMap] + ): AnnotationSeq = { val newAnnotations = { val inSet = mutable.LinkedHashSet() ++ inAnno val resSet = mutable.LinkedHashSet() ++ resAnno val deleted = (inSet -- resSet).map { case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"$xFormName+$name", delAnno) - case anno => DeletedAnnotation(name, anno) + case anno => DeletedAnnotation(name, anno) } val created = resSet -- inSet val unchanged = resSet & inSet @@ -260,7 +273,7 @@ private[firrtl] object Transform { remappedAnnos.foreach { remapped => val set = remapped2original.getOrElseUpdate(remapped, mutable.LinkedHashSet.empty[Annotation]) set += anno - if(set.size > 1) keysOfNote += remapped + if (set.size > 1) keysOfNote += remapped } remappedAnnos }.toSeq @@ -280,15 +293,11 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform def name: String = this.getClass.getName /** The [[firrtl.CircuitForm]] that this transform requires to operate on */ - @deprecated( - "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3") def inputForm: CircuitForm /** The [[firrtl.CircuitForm]] that this transform outputs */ - @deprecated( - "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3") def outputForm: CircuitForm /** Perform the transform, encode renaming with RenameMap, and can @@ -324,8 +333,9 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform Dependency[SystemVerilogEmitter] :: Nil val emitters = inputForm match { - case C => Dependency[ChirrtlEmitter] :: Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters - case H => Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters + case C => + Dependency[ChirrtlEmitter] :: Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters + case H => Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters case M => Dependency[MiddleFirrtlEmitter] :: lowEmitters case L => lowEmitters case U => Nil @@ -334,9 +344,9 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform val selfDep = Dependency.fromTransform(this) inputForm match { - case C => (fullCompilerSet ++ emitters - selfDep).toSeq - case H => (fullCompilerSet -- Forms.Deduped ++ emitters - selfDep).toSeq - case M => (fullCompilerSet -- Forms.MidForm ++ emitters - selfDep).toSeq + case C => (fullCompilerSet ++ emitters - selfDep).toSeq + case H => (fullCompilerSet -- Forms.Deduped ++ emitters - selfDep).toSeq + case M => (fullCompilerSet -- Forms.MidForm ++ emitters - selfDep).toSeq case L => (fullCompilerSet -- Forms.LowFormOptimized ++ emitters - selfDep).toSeq case U => Nil } @@ -347,9 +357,9 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform override def invalidates(a: Transform): Boolean = { (inputForm, outputForm) match { - case (U, _) | (_, U) => true // invalidate everything + case (U, _) | (_, U) => true // invalidate everything case (i, o) if i >= o => false // invalidate nothing - case (_, C) => true // invalidate everything + case (_, C) => true // invalidate everything case (_, H) => highOutputInvalidates(Dependency.fromTransform(a)) case (_, M) => midOutputInvalidates(Dependency.fromTransform(a)) case (_, L) => false // invalidate nothing @@ -386,7 +396,7 @@ abstract class SeqTransform extends Transform with SeqTransformBased { /* require(state.form <= inputForm, s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") - */ + */ val ret = runTransforms(state) CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } @@ -401,7 +411,7 @@ trait ResolvedAnnotationPaths { val annotationClasses: Traversable[Class[_]] override def prepare(state: CircuitState): CircuitState = { - state.resolvePathsOf(annotationClasses.toSeq:_*) + state.resolvePathsOf(annotationClasses.toSeq: _*) } } @@ -419,6 +429,7 @@ trait Emitter extends Transform { @deprecated("This will be removed in 1.4", "FIRRTL 1.3") object CompilerUtils extends LazyLogging { + /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit * * @param inputForm [[CircuitForm]] to lower from @@ -427,7 +438,8 @@ object CompilerUtils extends LazyLogging { */ @deprecated( "Use a TransformManager requesting which transforms you want to run. This will be removed in 1.4.", - "FIRRTL 1.3") + "FIRRTL 1.3" + ) def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] = { // If outputForm is equal-to or higher than inputForm, nothing to lower if (outputForm >= inputForm) { @@ -437,10 +449,15 @@ object CompilerUtils extends LazyLogging { case ChirrtlForm => Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm) case HighForm => - Seq(new IRToWorkingIR, new ResolveAndCheck, new firrtl.transforms.DedupModules, new HighFirrtlToMiddleFirrtl) ++ + Seq( + new IRToWorkingIR, + new ResolveAndCheck, + new firrtl.transforms.DedupModules, + new HighFirrtlToMiddleFirrtl + ) ++ getLoweringTransforms(MidForm, outputForm) - case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) - case LowForm => throwInternalError("getLoweringTransforms - LowForm") // should be caught by if above + case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm) + case LowForm => throwInternalError("getLoweringTransforms - LowForm") // should be caught by if above case UnknownForm => throwInternalError("getLoweringTransforms - UnknownForm") // should be caught by if above } } @@ -479,28 +496,32 @@ object CompilerUtils extends LazyLogging { */ @deprecated( "Use a TransformManager requesting which transforms you want to run. This will be removed in 1.4.", - "FIRRTL 1.3") + "FIRRTL 1.3" + ) def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] = { - custom - .sortWith{ - case (a, b) => (a, b) match { + custom.sortWith { + case (a, b) => + (a, b) match { case (_: Emitter, _: Emitter) => false - case (_, _: Emitter) => true - case _ => false }} - .foldLeft(lowering) { case (transforms, xform) => - val index = transforms lastIndexWhere (_.outputForm == xform.inputForm) - assert(index >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front - s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform") - val (front, back) = transforms.splitAt(index + 1) // +1 because we want to be AFTER index - front ++ List(xform) ++ getLoweringTransforms(xform.outputForm, xform.inputForm) ++ back + case (_, _: Emitter) => true + case _ => false + } } + .foldLeft(lowering) { + case (transforms, xform) => + val index = transforms.lastIndexWhere(_.outputForm == xform.inputForm) + assert( + index >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front + s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform" + ) + val (front, back) = transforms.splitAt(index + 1) // +1 because we want to be AFTER index + front ++ List(xform) ++ getLoweringTransforms(xform.outputForm, xform.inputForm) ++ back + } } } -@deprecated( - "Migrate to firrtl.stage.transforms.Compiler. This will be removed in 1.4.", - "FIRRTL 1.3") +@deprecated("Migrate to firrtl.stage.transforms.Compiler. This will be removed in 1.4.", "FIRRTL 1.3") trait Compiler extends Transform with DependencyAPIMigration { def emitter: Emitter @@ -511,15 +532,17 @@ trait Compiler extends Transform with DependencyAPIMigration { def transforms: Seq[Transform] final override def execute(state: CircuitState): CircuitState = - new stage.transforms.Compiler ( + new stage.transforms.Compiler( targets = (transforms :+ emitter).map(Dependency.fromTransform), currentState = prerequisites, knownObjects = (transforms :+ emitter).toSet ).execute(state) - require(transforms.size >= 1, - s"Compiler transforms for '${this.getClass.getName}' must have at least ONE Transform! " + - "Use IdentityTransform if you need an identity/no-op transform.") + require( + transforms.size >= 1, + s"Compiler transforms for '${this.getClass.getName}' must have at least ONE Transform! " + + "Use IdentityTransform if you need an identity/no-op transform." + ) /** Perform compilation * @@ -531,10 +554,9 @@ trait Compiler extends Transform with DependencyAPIMigration { @deprecated( "Migrate to '(new FirrtlStage).execute(args: Array[String], annotations: AnnotationSeq)'." + "This will be removed in 1.4.", - "FIRRTL 1.0") - def compile(state: CircuitState, - writer: Writer, - customTransforms: Seq[Transform] = Seq.empty): CircuitState = { + "FIRRTL 1.0" + ) + def compile(state: CircuitState, writer: Writer, customTransforms: Seq[Transform] = Seq.empty): CircuitState = { val finalState = compileAndEmit(state, customTransforms) writer.write(finalState.getEmittedCircuit.value) finalState @@ -555,9 +577,9 @@ trait Compiler extends Transform with DependencyAPIMigration { @deprecated( "Migrate to '(new FirrtlStage).execute(args: Array[String], annotations: AnnotationSeq)'." + "This will be removed in 1.4.", - "FIRRTL 1.3.3") - def compileAndEmit(state: CircuitState, - customTransforms: Seq[Transform] = Seq.empty): CircuitState = { + "FIRRTL 1.3.3" + ) + def compileAndEmit(state: CircuitState, customTransforms: Seq[Transform] = Seq.empty): CircuitState = { val emitAnno = EmitCircuitAnnotation(emitter.getClass) compile(state.copy(annotations = emitAnno +: state.annotations), emitter +: customTransforms) } @@ -574,9 +596,10 @@ trait Compiler extends Transform with DependencyAPIMigration { @deprecated( "Migrate to '(new FirrtlStage).execute(args: Array[String], annotations: AnnotationSeq)'." + "This will be removed in 1.4.", - "FIRRTL 1.3.3") + "FIRRTL 1.3.3" + ) def compile(state: CircuitState, customTransforms: Seq[Transform]): CircuitState = { - val transformManager = new stage.transforms.Compiler ( + val transformManager = new stage.transforms.Compiler( targets = (emitter +: customTransforms ++: transforms).map(Dependency.fromTransform), currentState = prerequisites, knownObjects = (transforms :+ emitter).toSet diff --git a/src/main/scala/firrtl/DependencyAPIMigration.scala b/src/main/scala/firrtl/DependencyAPIMigration.scala index 6a5ff642..dc5957f2 100644 --- a/src/main/scala/firrtl/DependencyAPIMigration.scala +++ b/src/main/scala/firrtl/DependencyAPIMigration.scala @@ -17,14 +17,10 @@ import firrtl.stage.TransformManager.TransformDependency */ trait DependencyAPIMigration { this: Transform => - @deprecated( - "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3") final override def inputForm: CircuitForm = UnknownForm - @deprecated( - "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", - "FIRRTL 1.3") + @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3") final override def outputForm: CircuitForm = UnknownForm override def prerequisites: Seq[TransformDependency] = Seq.empty diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 2050b235..28eb2d6a 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -13,7 +13,6 @@ import firrtl.stage.phases.DriverCompatibility import firrtl.options.{Dependency, Phase, PhaseManager, StageUtils, Viewer} import firrtl.options.phases.DeletedWrapper - /** * The driver provides methods to access the firrtl compiler. * Invoke the compiler with either a FirrtlExecutionOption @@ -37,6 +36,7 @@ import firrtl.options.phases.DeletedWrapper */ @deprecated("Use firrtl.stage.FirrtlStage", "1.2") object Driver { + /** Print a warning message * * @param message error message @@ -71,7 +71,7 @@ object Driver { * @return Annotations read from files */ def getAnnotations( - optionsManager: ExecutionOptionsManager with HasFirrtlOptions + optionsManager: ExecutionOptionsManager with HasFirrtlOptions ): Seq[Annotation] = { val firrtlConfig = optionsManager.firrtlOptions @@ -92,11 +92,11 @@ object Driver { // Warnings to get people to change to drop old API if (firrtlConfig.annotationFileNameOverride.nonEmpty) { val msg = "annotationFileNameOverride has been removed, file will be ignored! " + - "Use annotationFileNames" + "Use annotationFileNames" dramaticError(msg) } else if (usingImplicitAnnoFile) { val msg = "Implicit .anno file from top-name has been removed, file will be ignored!\n" + - (" "*9) + "Use explicit -faf option or annotationFileNames" + (" " * 9) + "Use explicit -faf option or annotationFileNames" dramaticError(msg) } @@ -126,7 +126,7 @@ object Driver { private def getFileExtension(filename: String): FileExtension = filename.drop(filename.lastIndexOf('.')) match { case ".pb" => ProtoBufFile - case _ => FirrtlFile // Default to FIRRTL File + case _ => FirrtlFile // Default to FIRRTL File } // Useful for handling erros in the options @@ -143,7 +143,8 @@ object Driver { val circuitSources = Map( "firrtlSource" -> firrtlConfig.firrtlSource.isDefined, "firrtlCircuit" -> firrtlConfig.firrtlCircuit.isDefined, - "inputFileNameOverride" -> firrtlConfig.inputFileNameOverride.nonEmpty) + "inputFileNameOverride" -> firrtlConfig.inputFileNameOverride.nonEmpty + ) if (circuitSources.values.count(x => x) > 1) { val msg = circuitSources.collect { case (s, true) => s }.mkString(" and ") + " are set, only 1 can be set at a time!" @@ -157,8 +158,9 @@ object Driver { } if ( optionsManager.topName.isEmpty && - firrtlConfig.inputFileNameOverride.nonEmpty && - firrtlConfig.outputFileNameOverride.isEmpty) { + firrtlConfig.inputFileNameOverride.nonEmpty && + firrtlConfig.outputFileNameOverride.isEmpty + ) { val message = "inputFileName set but neither top-name or output-file-override is set" throw new OptionsException(message) } @@ -167,10 +169,9 @@ object Driver { // TODO What does InfoMode mean to ProtoBuf? getFileExtension(inputFileName) match { case ProtoBufFile => proto.FromProto.fromFile(inputFileName) - case FirrtlFile => Parser.parseFile(inputFileName, firrtlConfig.infoMode) + case FirrtlFile => Parser.parseFile(inputFileName, firrtlConfig.infoMode) } - } - catch { + } catch { case _: FileNotFoundException => val message = s"Input file $inputFileName not found" throw new OptionsException(message) @@ -195,20 +196,23 @@ object Driver { val phases: Seq[Phase] = { import DriverCompatibility._ new PhaseManager( - List( Dependency[AddImplicitFirrtlFile], - Dependency[AddImplicitAnnotationFile], - Dependency[AddImplicitOutputFile], - Dependency[AddImplicitEmitter], - Dependency[FirrtlStage] )) - .transformOrder + List( + Dependency[AddImplicitFirrtlFile], + Dependency[AddImplicitAnnotationFile], + Dependency[AddImplicitOutputFile], + Dependency[AddImplicitEmitter], + Dependency[FirrtlStage] + ) + ).transformOrder .map(DeletedWrapper(_)) } - val annosx = try { - phases.foldLeft(annos)( (a, p) => p.transform(a) ) - } catch { - case e: firrtl.options.OptionsException => return FirrtlExecutionFailure(e.message) - } + val annosx = + try { + phases.foldLeft(annos)((a, p) => p.transform(a)) + } catch { + case e: firrtl.options.OptionsException => return FirrtlExecutionFailure(e.message) + } Viewer[FirrtlExecutionResult].view(annosx) } @@ -223,7 +227,7 @@ object Driver { def execute(args: Array[String]): FirrtlExecutionResult = { val optionsManager = new ExecutionOptionsManager("firrtl") with HasFirrtlOptions - if(optionsManager.parse(args)) { + if (optionsManager.parse(args)) { execute(optionsManager) match { case success: FirrtlExecutionSuccess => success @@ -233,8 +237,7 @@ object Driver { case result => throwInternalError(s"Error: Unknown Firrtl Execution result $result") } - } - else { + } else { FirrtlExecutionFailure("Could not parser command line options") } } diff --git a/src/main/scala/firrtl/EmissionOption.scala b/src/main/scala/firrtl/EmissionOption.scala index 91db1f53..d097e14a 100644 --- a/src/main/scala/firrtl/EmissionOption.scala +++ b/src/main/scala/firrtl/EmissionOption.scala @@ -2,8 +2,8 @@ package firrtl -/** - * Base type for emission customization options +/** + * Base type for emission customization options * NOTE: all the following traits must be mixed with SingleTargetAnnotation[T <: Named] * in order to be taken into account in the Emitter */ @@ -24,40 +24,37 @@ case object MemoryEmissionOptionDefault extends MemoryEmissionOption /** Emission customization options for registers */ trait RegisterEmissionOption extends EmissionOption { + /** when true the reset init value will be used to emit a bitstream preset */ - def useInitAsPreset : Boolean = false - + def useInitAsPreset: Boolean = false + /** when true the initial randomization is disabled for this register */ - def disableRandomization : Boolean = false + def disableRandomization: Boolean = false } /** default Emitter behavior for registers */ -case object RegisterEmissionOptionDefault extends RegisterEmissionOption - +case object RegisterEmissionOptionDefault extends RegisterEmissionOption /** Emission customization options for IO ports */ -trait PortEmissionOption extends EmissionOption +trait PortEmissionOption extends EmissionOption /** default Emitter behavior for IO ports */ -case object PortEmissionOptionDefault extends PortEmissionOption - +case object PortEmissionOptionDefault extends PortEmissionOption /** Emission customization options for wires */ -trait WireEmissionOption extends EmissionOption +trait WireEmissionOption extends EmissionOption /** default Emitter behavior for wires */ -case object WireEmissionOptionDefault extends WireEmissionOption - +case object WireEmissionOptionDefault extends WireEmissionOption /** Emission customization options for nodes */ -trait NodeEmissionOption extends EmissionOption +trait NodeEmissionOption extends EmissionOption /** default Emitter behavior for nodes */ -case object NodeEmissionOptionDefault extends NodeEmissionOption - +case object NodeEmissionOptionDefault extends NodeEmissionOption /** Emission customization options for connect */ trait ConnectEmissionOption extends EmissionOption /** default Emitter behavior for connect */ -case object ConnectEmissionOptionDefault extends ConnectEmissionOption +case object ConnectEmissionOptionDefault extends ConnectEmissionOption diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index ae9a7dad..843c76a4 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -37,28 +37,38 @@ object EmitCircuitAnnotation extends HasShellOptions { val options = Seq( new ShellOption[String]( longOption = "emit-circuit", - toAnnotationSeq = (a: String) => a match { - case "chirrtl" => Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter), - EmitCircuitAnnotation(classOf[ChirrtlEmitter])) - case "high" => Seq(RunFirrtlTransformAnnotation(new HighFirrtlEmitter), - EmitCircuitAnnotation(classOf[HighFirrtlEmitter])) - case "middle" => Seq(RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter), - EmitCircuitAnnotation(classOf[MiddleFirrtlEmitter])) - case "low" => Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), - EmitCircuitAnnotation(classOf[LowFirrtlEmitter])) - case "verilog" | "mverilog" => Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), - EmitCircuitAnnotation(classOf[VerilogEmitter])) - case "sverilog" => Seq(RunFirrtlTransformAnnotation(new SystemVerilogEmitter), - EmitCircuitAnnotation(classOf[SystemVerilogEmitter])) - case "experimental-btor2" => Seq(RunFirrtlTransformAnnotation(new Btor2Emitter), - EmitCircuitAnnotation(classOf[Btor2Emitter])) - case "experimental-smt2" => Seq(RunFirrtlTransformAnnotation(new SMTLibEmitter), - EmitCircuitAnnotation(classOf[SMTLibEmitter])) - case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") }, + toAnnotationSeq = (a: String) => + a match { + case "chirrtl" => + Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter), EmitCircuitAnnotation(classOf[ChirrtlEmitter])) + case "high" => + Seq(RunFirrtlTransformAnnotation(new HighFirrtlEmitter), EmitCircuitAnnotation(classOf[HighFirrtlEmitter])) + case "middle" => + Seq( + RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter), + EmitCircuitAnnotation(classOf[MiddleFirrtlEmitter]) + ) + case "low" => + Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitCircuitAnnotation(classOf[LowFirrtlEmitter])) + case "verilog" | "mverilog" => + Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitCircuitAnnotation(classOf[VerilogEmitter])) + case "sverilog" => + Seq( + RunFirrtlTransformAnnotation(new SystemVerilogEmitter), + EmitCircuitAnnotation(classOf[SystemVerilogEmitter]) + ) + case "experimental-btor2" => + Seq(RunFirrtlTransformAnnotation(new Btor2Emitter), EmitCircuitAnnotation(classOf[Btor2Emitter])) + case "experimental-smt2" => + Seq(RunFirrtlTransformAnnotation(new SMTLibEmitter), EmitCircuitAnnotation(classOf[SMTLibEmitter])) + case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") + }, helpText = "Run the specified circuit emitter (all modules in one file)", shortOption = Some("E"), // the experimental options are intentionally excluded from the help message - helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") ) ) + helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") + ) + ) } @@ -67,30 +77,43 @@ object EmitAllModulesAnnotation extends HasShellOptions { val options = Seq( new ShellOption[String]( longOption = "emit-modules", - toAnnotationSeq = (a: String) => a match { - case "chirrtl" => Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter), - EmitAllModulesAnnotation(classOf[ChirrtlEmitter])) - case "high" => Seq(RunFirrtlTransformAnnotation(new HighFirrtlEmitter), - EmitAllModulesAnnotation(classOf[HighFirrtlEmitter])) - case "middle" => Seq(RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter), - EmitAllModulesAnnotation(classOf[MiddleFirrtlEmitter])) - case "low" => Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), - EmitAllModulesAnnotation(classOf[LowFirrtlEmitter])) - case "verilog" | "mverilog" => Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), - EmitAllModulesAnnotation(classOf[VerilogEmitter])) - case "sverilog" => Seq(RunFirrtlTransformAnnotation(new SystemVerilogEmitter), - EmitAllModulesAnnotation(classOf[SystemVerilogEmitter])) - case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") }, + toAnnotationSeq = (a: String) => + a match { + case "chirrtl" => + Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter), EmitAllModulesAnnotation(classOf[ChirrtlEmitter])) + case "high" => + Seq( + RunFirrtlTransformAnnotation(new HighFirrtlEmitter), + EmitAllModulesAnnotation(classOf[HighFirrtlEmitter]) + ) + case "middle" => + Seq( + RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter), + EmitAllModulesAnnotation(classOf[MiddleFirrtlEmitter]) + ) + case "low" => + Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitAllModulesAnnotation(classOf[LowFirrtlEmitter])) + case "verilog" | "mverilog" => + Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitAllModulesAnnotation(classOf[VerilogEmitter])) + case "sverilog" => + Seq( + RunFirrtlTransformAnnotation(new SystemVerilogEmitter), + EmitAllModulesAnnotation(classOf[SystemVerilogEmitter]) + ) + case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") + }, helpText = "Run the specified module emitter (one file per module)", shortOption = Some("e"), - helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") ) ) + helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") + ) + ) } // ***** Annotations for results of emission ***** sealed abstract class EmittedComponent { - def name: String - def value: String + def name: String + def value: String def outputSuffix: String } sealed abstract class EmittedCircuit extends EmittedComponent @@ -147,7 +170,7 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em // Use list instead of set to maintain order val modules = mutable.ArrayBuffer.empty[DefModule] def onStmt(stmt: Statement): Unit = stmt match { - case DefInstance(_, _, name, _) => modules += map(name) + case DefInstance(_, _, name, _) => modules += map(name) case WDefInstance(_, _, name, _) => modules += map(name) case _: WDefInstanceConnector => throwInternalError(s"unrecognized statement: $stmt") case other => other.foreach(onStmt) @@ -157,24 +180,28 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em } val modMap = circuit.modules.map(m => m.name -> m).toMap // Turn each module into it's own circuit with it as the top and all instantied modules as ExtModules - circuit.modules collect { case m: Module => - val instModules = collectInstantiatedModules(m, modMap) - val extModules = instModules map { - case Module(info, name, ports, _) => ExtModule(info, name, ports, name, Seq.empty) - case ext: ExtModule => ext - } - val newCircuit = Circuit(m.info, extModules :+ m, m.name) - EmittedFirrtlModule(m.name, newCircuit.serialize, outputSuffix) + circuit.modules.collect { + case m: Module => + val instModules = collectInstantiatedModules(m, modMap) + val extModules = instModules.map { + case Module(info, name, ports, _) => ExtModule(info, name, ports, name, Seq.empty) + case ext: ExtModule => ext + } + val newCircuit = Circuit(m.info, extModules :+ m, m.name) + EmittedFirrtlModule(m.name, newCircuit.serialize, outputSuffix) } } override def execute(state: CircuitState): CircuitState = { val newAnnos = state.annotations.flatMap { case EmitCircuitAnnotation(a) if this.getClass == a => - Seq(EmittedFirrtlCircuitAnnotation( - EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, outputSuffix))) + Seq( + EmittedFirrtlCircuitAnnotation( + EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, outputSuffix) + ) + ) case EmitAllModulesAnnotation(a) if this.getClass == a => - emitAllModules(state.circuit) map (EmittedFirrtlModuleAnnotation(_)) + emitAllModules(state.circuit).map(EmittedFirrtlModuleAnnotation(_)) case _ => Seq() } state.copy(annotations = newAnnos ++ state.annotations) @@ -195,12 +222,12 @@ case class VRandom(width: BigInt) extends Expression { def nWords = (width + 31) / 32 def realWidth = nWords * 32 override def serialize: String = "RANDOM" - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = () + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = () } class VerilogEmitter extends SeqTransform with Emitter { @@ -221,14 +248,16 @@ class VerilogEmitter extends SeqTransform with Emitter { else if (e2 == we(one)) e1.e1 else DoPrim(And, Seq(e1.e1, e2.e1), Nil, UIntType(IntWidth(1))) } - def wref(n: String, t: Type) = WRef(n, t, ExpKind, UnknownFlow) + def wref(n: String, t: Type) = WRef(n, t, ExpKind, UnknownFlow) def remove_root(ex: Expression): Expression = ex match { - case ex: WSubField => ex.expr match { - case (e: WSubField) => remove_root(e) - case (_: WRef) => WRef(ex.name, ex.tpe, InstanceKind, UnknownFlow) - } + case ex: WSubField => + ex.expr match { + case (e: WSubField) => remove_root(e) + case (_: WRef) => WRef(ex.name, ex.tpe, InstanceKind, UnknownFlow) + } case _ => throwInternalError(s"shouldn't be here: remove_root($ex)") } + /** Turn Params into Verilog Strings */ def stringify(param: Param): String = param match { case IntParam(name, value) => @@ -237,11 +266,11 @@ class VerilogEmitter extends SeqTransform with Emitter { s"$value" } else { val blen = value.bitLength - if (value > 0) s"$blen'd$value" else s"-${blen+1}'sd${value.abs}" + if (value > 0) s"$blen'd$value" else s"-${blen + 1}'sd${value.abs}" } s".$name($lit)" - case DoubleParam(name, value) => s".$name($value)" - case StringParam(name, value) => s".${name}(${value.verilogEscape})" + case DoubleParam(name, value) => s".$name($value)" + case StringParam(name, value) => s".${name}(${value.verilogEscape})" case RawStringParam(name, value) => s".$name($value)" } def stringify(tpe: GroundType): String = tpe match { @@ -249,16 +278,16 @@ class VerilogEmitter extends SeqTransform with Emitter { val wx = bitWidth(tpe) - 1 if (wx > 0) s"[$wx:0]" else "" case ClockType | AsyncResetType => "" - case _ => throwInternalError(s"trying to write unsupported type in the Verilog Emitter: $tpe") + case _ => throwInternalError(s"trying to write unsupported type in the Verilog Emitter: $tpe") } def emit(x: Any)(implicit w: Writer): Unit = { emit(x, 0) } def emit(x: Any, top: Int)(implicit w: Writer): Unit = { def cast(e: Expression): Any = e.tpe match { case (t: UIntType) => e - case (t: SIntType) => Seq("$signed(",e,")") - case ClockType => e + case (t: SIntType) => Seq("$signed(", e, ")") + case ClockType => e case AnalogType(_) => e - case _ => throwInternalError(s"unrecognized cast: $e") + case _ => throwInternalError(s"unrecognized cast: $e") } x match { case (e: DoPrim) => emit(op_stream(e), top + 1) @@ -269,186 +298,190 @@ class VerilogEmitter extends SeqTransform with Emitter { if (e.tpe == AsyncResetType) { throw EmitterException("Cannot emit async reset muxes directly") } - emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1) + emit(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1) } - case (e: ValidIf) => emit(Seq(cast(e.value)),top + 1) - case (e: WRef) => w write e.serialize - case (e: WSubField) => w write LowerTypes.loweredName(e) - case (e: WSubAccess) => w write s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]" - case (e: WSubIndex) => w write e.serialize - case (e: Literal) => v_print(e) - case (e: VRandom) => w write s"{${e.nWords}{`RANDOM}}" - case (t: GroundType) => w write stringify(t) + case (e: ValidIf) => emit(Seq(cast(e.value)), top + 1) + case (e: WRef) => w.write(e.serialize) + case (e: WSubField) => w.write(LowerTypes.loweredName(e)) + case (e: WSubAccess) => w.write(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]") + case (e: WSubIndex) => w.write(e.serialize) + case (e: Literal) => v_print(e) + case (e: VRandom) => w.write(s"{${e.nWords}{`RANDOM}}") + case (t: GroundType) => w.write(stringify(t)) case (t: VectorType) => emit(t.tpe, top + 1) - w write s"[${t.size - 1}:0]" - case (s: String) => w write s - case (i: Int) => w write i.toString - case (i: Long) => w write i.toString - case (i: BigInt) => w write i.toString - case (i: Info) => i match { - case NoInfo => // Do nothing - case f: FileInfo => - val escaped = FileInfo.escapedToVerilog(f.escaped) - w.write(s" // @[$escaped]") - case m: MultiInfo => - val escaped = FileInfo.escapedToVerilog(m.flatten.map(_.escaped).mkString(" ")) - w.write(s" // @[$escaped]") - } + w.write(s"[${t.size - 1}:0]") + case (s: String) => w.write(s) + case (i: Int) => w.write(i.toString) + case (i: Long) => w.write(i.toString) + case (i: BigInt) => w.write(i.toString) + case (i: Info) => + i match { + case NoInfo => // Do nothing + case f: FileInfo => + val escaped = FileInfo.escapedToVerilog(f.escaped) + w.write(s" // @[$escaped]") + case m: MultiInfo => + val escaped = FileInfo.escapedToVerilog(m.flatten.map(_.escaped).mkString(" ")) + w.write(s" // @[$escaped]") + } case (s: Seq[Any]) => - s foreach (emit(_, top + 1)) - if (top == 0) w write "\n" + s.foreach(emit(_, top + 1)) + if (top == 0) w.write("\n") case x => throwInternalError(s"trying to emit unsupported operator: $x") } } - //;------------- PASS ----------------- - def v_print(e: Expression)(implicit w: Writer) = e match { - case UIntLiteral(value, IntWidth(width)) => - w write s"$width'h${value.toString(16)}" - case SIntLiteral(value, IntWidth(width)) => - val stringLiteral = value.toString(16) - w write (stringLiteral.head match { - case '-' if value == FixAddingNegativeLiterals.minNegValue(width) => s"$width'sh${stringLiteral.tail}" - case '-' => s"-$width'sh${stringLiteral.tail}" - case _ => s"$width'sh${stringLiteral}" - }) - case _ => throwInternalError(s"attempt to print unrecognized expression: $e") - } - - // NOTE: We emit SInts as regular Verilog unsigned wires/regs so the real type of any SInt - // reference is actually unsigned in the emitted Verilog. Thus we must cast refs as necessary - // to ensure Verilog operations are signed. - def op_stream(doprim: DoPrim): Seq[Any] = { - // Cast to SInt, don't cast multiple times - def doCast(e: Expression): Any = e match { - case DoPrim(AsSInt, Seq(arg), _,_) => doCast(arg) - case slit: SIntLiteral => slit - case other => Seq("$signed(", other, ")") - } - def castIf(e: Expression): Any = { - if (doprim.args.exists(_.tpe.isInstanceOf[SIntType])) { - e.tpe match { - case _: SIntType => doCast(e) - case _ => throwInternalError(s"Unexpected non-SInt type for $e in $doprim") - } - } else { - e - } - } - def cast(e: Expression): Any = doprim.tpe match { - case _: UIntType => e - case _: SIntType => doCast(e) - case _ => throwInternalError(s"Unexpected type for $e in $doprim") - } - def castAs(e: Expression): Any = e.tpe match { - case _: UIntType => e - case _: SIntType => doCast(e) - case _ => throwInternalError(s"Unexpected type for $e in $doprim") - } - def a0: Expression = doprim.args.head - def a1: Expression = doprim.args(1) - def c0: Int = doprim.consts.head.toInt - def c1: Int = doprim.consts(1).toInt - - def checkArgumentLegality(e: Expression): Unit = e match { - case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => - case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality) - case DoPrim(op, args, _,_) if isCast(op) => args.foreach(checkArgumentLegality) - case DoPrim(op, args, _,_) if isBitExtract(op) => args.foreach(checkArgumentLegality) - case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") - } - - def checkCatArgumentLegality(e: Expression): Unit = e match { - case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality) - case _ => checkArgumentLegality(e) - } - - def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = { - val a0Seq = a0 match { - case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1)) - case _ => Seq(cast(a0)) - } - val a1Seq = a1 match { - case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1)) - case _ => Seq(cast(a1)) - } - a0Seq ++ Seq(",") ++ a1Seq - } - - doprim.op match { - case Cat => doprim.args foreach(checkCatArgumentLegality) - case cast if isCast(cast) => // Casts are allowed to wrap any Expression - case other => doprim.args foreach checkArgumentLegality - } - doprim.op match { - case Add => Seq(castIf(a0), " + ", castIf(a1)) - case Addw => Seq(castIf(a0), " + ", castIf(a1)) - case Sub => Seq(castIf(a0), " - ", castIf(a1)) - case Subw => Seq(castIf(a0), " - ", castIf(a1)) - case Mul => Seq(castIf(a0), " * ", castIf(a1)) - case Div => Seq(castIf(a0), " / ", castIf(a1)) - case Rem => Seq(castIf(a0), " % ", castIf(a1)) - case Lt => Seq(castIf(a0), " < ", castIf(a1)) - case Leq => Seq(castIf(a0), " <= ", castIf(a1)) - case Gt => Seq(castIf(a0), " > ", castIf(a1)) - case Geq => Seq(castIf(a0), " >= ", castIf(a1)) - case Eq => Seq(castIf(a0), " == ", castIf(a1)) - case Neq => Seq(castIf(a0), " != ", castIf(a1)) - case Pad => - val w = bitWidth(a0.tpe) - val diff = c0 - w - if (w == BigInt(0) || diff <= 0) Seq(a0) - else doprim.tpe match { - // Either sign extend or zero extend. - // If width == BigInt(1), don't extract bit - case (_: SIntType) if w == BigInt(1) => Seq("{", c0, "{", a0, "}}") - case (_: SIntType) => Seq("{{", diff, "{", a0, "[", w - 1, "]}},", a0, "}") - case (_) => Seq("{{", diff, "'d0}, ", a0, "}") - } - // Because we don't support complex Expressions, all casts are ignored - // This simplifies handling of assignment of a signed expression to an unsigned LHS value - // which does not require a cast in Verilog - case AsUInt | AsSInt | AsClock | AsAsyncReset => Seq(a0) - case Dshlw => Seq(cast(a0), " << ", a1) - case Dshl => Seq(cast(a0), " << ", a1) - case Dshr => doprim.tpe match { - case (_: SIntType) => Seq(cast(a0)," >>> ", a1) - case (_) => Seq(cast(a0), " >> ", a1) - } - case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0)) - case Shr if c0 >= bitWidth(a0.tpe) => - error("Verilog emitter does not support SHIFT_RIGHT >= arg width") - case Shr if c0 == (bitWidth(a0.tpe)-1) => Seq(a0,"[", bitWidth(a0.tpe) - 1, "]") - case Shr => Seq(a0,"[", bitWidth(a0.tpe) - 1, ":", c0, "]") - case Neg => Seq("-", cast(a0)) - case Cvt => a0.tpe match { - case (_: UIntType) => Seq("{1'b0,", cast(a0), "}") - case (_: SIntType) => Seq(cast(a0)) - } - case Not => Seq("~", a0) - case And => Seq(castAs(a0), " & ", castAs(a1)) - case Or => Seq(castAs(a0), " | ", castAs(a1)) - case Xor => Seq(castAs(a0), " ^ ", castAs(a1)) - case Andr => Seq("&", cast(a0)) - case Orr => Seq("|", cast(a0)) - case Xorr => Seq("^", cast(a0)) - case Cat => "{" +: (castCatArgs(a0, a1) :+ "}") - // If selecting zeroth bit and single-bit wire, just emit the wire - case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) - case Bits if c0 == c1 => Seq(a0, "[", c0, "]") - case Bits => Seq(a0, "[", c0, ":", c1, "]") - // If selecting zeroth bit and single-bit wire, just emit the wire - case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) - case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe)-1, "]") - case Head => - val msb = bitWidth(a0.tpe) - 1 - val lsb = bitWidth(a0.tpe) - c0 - Seq(a0, "[", msb, ":", lsb, "]") - case Tail if c0 == (bitWidth(a0.tpe)-1) => Seq(a0, "[0]") - case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]") - } - } + //;------------- PASS ----------------- + def v_print(e: Expression)(implicit w: Writer) = e match { + case UIntLiteral(value, IntWidth(width)) => + w.write(s"$width'h${value.toString(16)}") + case SIntLiteral(value, IntWidth(width)) => + val stringLiteral = value.toString(16) + w.write(stringLiteral.head match { + case '-' if value == FixAddingNegativeLiterals.minNegValue(width) => s"$width'sh${stringLiteral.tail}" + case '-' => s"-$width'sh${stringLiteral.tail}" + case _ => s"$width'sh${stringLiteral}" + }) + case _ => throwInternalError(s"attempt to print unrecognized expression: $e") + } + + // NOTE: We emit SInts as regular Verilog unsigned wires/regs so the real type of any SInt + // reference is actually unsigned in the emitted Verilog. Thus we must cast refs as necessary + // to ensure Verilog operations are signed. + def op_stream(doprim: DoPrim): Seq[Any] = { + // Cast to SInt, don't cast multiple times + def doCast(e: Expression): Any = e match { + case DoPrim(AsSInt, Seq(arg), _, _) => doCast(arg) + case slit: SIntLiteral => slit + case other => Seq("$signed(", other, ")") + } + def castIf(e: Expression): Any = { + if (doprim.args.exists(_.tpe.isInstanceOf[SIntType])) { + e.tpe match { + case _: SIntType => doCast(e) + case _ => throwInternalError(s"Unexpected non-SInt type for $e in $doprim") + } + } else { + e + } + } + def cast(e: Expression): Any = doprim.tpe match { + case _: UIntType => e + case _: SIntType => doCast(e) + case _ => throwInternalError(s"Unexpected type for $e in $doprim") + } + def castAs(e: Expression): Any = e.tpe match { + case _: UIntType => e + case _: SIntType => doCast(e) + case _ => throwInternalError(s"Unexpected type for $e in $doprim") + } + def a0: Expression = doprim.args.head + def a1: Expression = doprim.args(1) + def c0: Int = doprim.consts.head.toInt + def c1: Int = doprim.consts(1).toInt + + def checkArgumentLegality(e: Expression): Unit = e match { + case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField => + case DoPrim(Not, args, _, _) => args.foreach(checkArgumentLegality) + case DoPrim(op, args, _, _) if isCast(op) => args.foreach(checkArgumentLegality) + case DoPrim(op, args, _, _) if isBitExtract(op) => args.foreach(checkArgumentLegality) + case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") + } + + def checkCatArgumentLegality(e: Expression): Unit = e match { + case DoPrim(Cat, args, _, _) => args.foreach(checkCatArgumentLegality) + case _ => checkArgumentLegality(e) + } + + def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = { + val a0Seq = a0 match { + case cat @ DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1)) + case _ => Seq(cast(a0)) + } + val a1Seq = a1 match { + case cat @ DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1)) + case _ => Seq(cast(a1)) + } + a0Seq ++ Seq(",") ++ a1Seq + } + + doprim.op match { + case Cat => doprim.args.foreach(checkCatArgumentLegality) + case cast if isCast(cast) => // Casts are allowed to wrap any Expression + case other => doprim.args.foreach(checkArgumentLegality) + } + doprim.op match { + case Add => Seq(castIf(a0), " + ", castIf(a1)) + case Addw => Seq(castIf(a0), " + ", castIf(a1)) + case Sub => Seq(castIf(a0), " - ", castIf(a1)) + case Subw => Seq(castIf(a0), " - ", castIf(a1)) + case Mul => Seq(castIf(a0), " * ", castIf(a1)) + case Div => Seq(castIf(a0), " / ", castIf(a1)) + case Rem => Seq(castIf(a0), " % ", castIf(a1)) + case Lt => Seq(castIf(a0), " < ", castIf(a1)) + case Leq => Seq(castIf(a0), " <= ", castIf(a1)) + case Gt => Seq(castIf(a0), " > ", castIf(a1)) + case Geq => Seq(castIf(a0), " >= ", castIf(a1)) + case Eq => Seq(castIf(a0), " == ", castIf(a1)) + case Neq => Seq(castIf(a0), " != ", castIf(a1)) + case Pad => + val w = bitWidth(a0.tpe) + val diff = c0 - w + if (w == BigInt(0) || diff <= 0) Seq(a0) + else + doprim.tpe match { + // Either sign extend or zero extend. + // If width == BigInt(1), don't extract bit + case (_: SIntType) if w == BigInt(1) => Seq("{", c0, "{", a0, "}}") + case (_: SIntType) => Seq("{{", diff, "{", a0, "[", w - 1, "]}},", a0, "}") + case (_) => Seq("{{", diff, "'d0}, ", a0, "}") + } + // Because we don't support complex Expressions, all casts are ignored + // This simplifies handling of assignment of a signed expression to an unsigned LHS value + // which does not require a cast in Verilog + case AsUInt | AsSInt | AsClock | AsAsyncReset => Seq(a0) + case Dshlw => Seq(cast(a0), " << ", a1) + case Dshl => Seq(cast(a0), " << ", a1) + case Dshr => + doprim.tpe match { + case (_: SIntType) => Seq(cast(a0), " >>> ", a1) + case (_) => Seq(cast(a0), " >> ", a1) + } + case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0)) + case Shr if c0 >= bitWidth(a0.tpe) => + error("Verilog emitter does not support SHIFT_RIGHT >= arg width") + case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]") + case Shr => Seq(a0, "[", bitWidth(a0.tpe) - 1, ":", c0, "]") + case Neg => Seq("-", cast(a0)) + case Cvt => + a0.tpe match { + case (_: UIntType) => Seq("{1'b0,", cast(a0), "}") + case (_: SIntType) => Seq(cast(a0)) + } + case Not => Seq("~", a0) + case And => Seq(castAs(a0), " & ", castAs(a1)) + case Or => Seq(castAs(a0), " | ", castAs(a1)) + case Xor => Seq(castAs(a0), " ^ ", castAs(a1)) + case Andr => Seq("&", cast(a0)) + case Orr => Seq("|", cast(a0)) + case Xorr => Seq("^", cast(a0)) + case Cat => "{" +: (castCatArgs(a0, a1) :+ "}") + // If selecting zeroth bit and single-bit wire, just emit the wire + case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) + case Bits if c0 == c1 => Seq(a0, "[", c0, "]") + case Bits => Seq(a0, "[", c0, ":", c1, "]") + // If selecting zeroth bit and single-bit wire, just emit the wire + case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0) + case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]") + case Head => + val msb = bitWidth(a0.tpe) - 1 + val lsb = bitWidth(a0.tpe) - c0 + Seq(a0, "[", msb, ":", lsb, "]") + case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[0]") + case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]") + } + } /** * Gets a reference to a verilog renderer. This is used by the current standard verilog emission process @@ -475,31 +508,43 @@ class VerilogEmitter extends SeqTransform with Emitter { * @param writer where rendering will be placed * @return the render reference */ - def getRenderer(descriptions: Seq[DescriptionAnnotation], - m: Module, - moduleMap: Map[String, DefModule])(implicit writer: Writer): VerilogRender = { + def getRenderer( + descriptions: Seq[DescriptionAnnotation], + m: Module, + moduleMap: Map[String, DefModule] + )( + implicit writer: Writer + ): VerilogRender = { val newMod = new AddDescriptionNodes().executeModule(m, descriptions) newMod match { - case DescribedMod(d, pds, m: Module) => new VerilogRender(d, pds, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer) + case DescribedMod(d, pds, m: Module) => + new VerilogRender(d, pds, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer) case m: Module => new VerilogRender(m, moduleMap)(writer) } } - def addFormalStatement(formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]], - clk: Expression, en: Expression, - stmt: Seq[Any], info: Info, msg: StringLit): Unit = { - throw EmitterException("Cannot emit verification statements in Verilog" + - "(2001). Use the SystemVerilog emitter instead.") + def addFormalStatement( + formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]], + clk: Expression, + en: Expression, + stmt: Seq[Any], + info: Info, + msg: StringLit + ): Unit = { + throw EmitterException( + "Cannot emit verification statements in Verilog" + + "(2001). Use the SystemVerilog emitter instead." + ) } /** * Store Emission option per Target * Guarantee only one emission option per Target */ - private[firrtl] class EmissionOptionMap[V <: EmissionOption](val df : V) { + private[firrtl] class EmissionOptionMap[V <: EmissionOption](val df: V) { private val m = collection.mutable.HashMap[ReferenceTarget, V]().withDefaultValue(df) - def +=(elem : (ReferenceTarget, V)) : EmissionOptionMap.this.type = { + def +=(elem: (ReferenceTarget, V)): EmissionOptionMap.this.type = { if (m.contains(elem._1)) throw EmitterException(s"Multiple EmissionOption for the target ${elem._1} (${m(elem._1)} ; ${elem._2})") m += (elem) @@ -511,7 +556,6 @@ class VerilogEmitter extends SeqTransform with Emitter { /** Provide API to retrieve EmissionOptions based on the provided [[AnnotationSeq]] * * @param annotations : AnnotationSeq to be searched for EmissionOptions - * */ private[firrtl] class EmissionOptions(annotations: AnnotationSeq) { // Private so that we can present an immutable API @@ -540,16 +584,34 @@ class VerilogEmitter extends SeqTransform with Emitter { def getConnectEmissionOption(target: ReferenceTarget): ConnectEmissionOption = connectEmissionOption(target) - private val emissionAnnos = annotations.collect{ - case m : SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m + private val emissionAnnos = annotations.collect { + case m: SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m } // using multiple foreach instead of a single partial function as an Annotation can gather multiple EmissionOptions for simplicity - emissionAnnos.foreach { case a :MemoryEmissionOption => memoryEmissionOption += ((a.target,a)) case _ => } - emissionAnnos.foreach { case a :RegisterEmissionOption => registerEmissionOption += ((a.target,a)) case _ => } - emissionAnnos.foreach { case a :WireEmissionOption => wireEmissionOption += ((a.target,a)) case _ => } - emissionAnnos.foreach { case a :PortEmissionOption => portEmissionOption += ((a.target,a)) case _ => } - emissionAnnos.foreach { case a :NodeEmissionOption => nodeEmissionOption += ((a.target,a)) case _ => } - emissionAnnos.foreach { case a :ConnectEmissionOption => connectEmissionOption += ((a.target,a)) case _ => } + emissionAnnos.foreach { + case a: MemoryEmissionOption => memoryEmissionOption += ((a.target, a)) + case _ => + } + emissionAnnos.foreach { + case a: RegisterEmissionOption => registerEmissionOption += ((a.target, a)) + case _ => + } + emissionAnnos.foreach { + case a: WireEmissionOption => wireEmissionOption += ((a.target, a)) + case _ => + } + emissionAnnos.foreach { + case a: PortEmissionOption => portEmissionOption += ((a.target, a)) + case _ => + } + emissionAnnos.foreach { + case a: NodeEmissionOption => nodeEmissionOption += ((a.target, a)) + case _ => + } + emissionAnnos.foreach { + case a: ConnectEmissionOption => connectEmissionOption += ((a.target, a)) + case _ => + } } /** @@ -562,14 +624,24 @@ class VerilogEmitter extends SeqTransform with Emitter { * @param moduleMap a map of modules so submodules can be discovered * @param writer where rendered information is placed. */ - class VerilogRender(description: Seq[Description], - portDescriptions: Map[String, Seq[Description]], - m: Module, - moduleMap: Map[String, DefModule], - circuitName: String, - emissionOptions: EmissionOptions)(implicit writer: Writer) { - - def this(m: Module, moduleMap: Map[String, DefModule], circuitName: String, emissionOptions: EmissionOptions)(implicit writer: Writer) = { + class VerilogRender( + description: Seq[Description], + portDescriptions: Map[String, Seq[Description]], + m: Module, + moduleMap: Map[String, DefModule], + circuitName: String, + emissionOptions: EmissionOptions + )( + implicit writer: Writer) { + + def this( + m: Module, + moduleMap: Map[String, DefModule], + circuitName: String, + emissionOptions: EmissionOptions + )( + implicit writer: Writer + ) = { this(Seq(), Map.empty, m, moduleMap, circuitName, emissionOptions)(writer) } def this(m: Module, moduleMap: Map[String, DefModule])(implicit writer: Writer) = { @@ -582,7 +654,7 @@ class VerilogEmitter extends SeqTransform with Emitter { def build_netlist(s: Statement): Unit = { s.foreach(build_netlist) s match { - case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr) + case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr) case sx: IsInvalid => error("Should have removed these!") // TODO Since only register update and memories use the netlist anymore, I think nodes are // unnecessary @@ -642,7 +714,14 @@ class VerilogEmitter extends SeqTransform with Emitter { if (bi.isValidInt) bi.toString else s"${bi.bitLength}'d$bi" // declare vector type with no preset and optionally with an ifdef guard - private def declareVectorType(b: String, n: String, tpe: Type, size: BigInt, info: Info, ifdefOpt: Option[String]): Unit = { + private def declareVectorType( + b: String, + n: String, + tpe: Type, + size: BigInt, + info: Info, + ifdefOpt: Option[String] + ): Unit = { val decl = Seq(b, " ", tpe, " ", n, " [0:", bigIntToVLit(size - 1), "];", info) if (ifdefOpt.isDefined) { ifdefDeclares(ifdefOpt.get) += decl @@ -675,7 +754,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case tx: VectorType => declareVectorType(b, n, tx.tpe, tx.size, info, ifdefOpt) case tx => - val decl = Seq(b, " ", tx, " ", n,";",info) + val decl = Seq(b, " ", tx, " ", n, ";", info) if (ifdefOpt.isDefined) { ifdefDeclares(ifdefOpt.get) += decl } else { @@ -703,8 +782,18 @@ class VerilogEmitter extends SeqTransform with Emitter { assigns += Seq("`ifndef RANDOMIZE_GARBAGE_ASSIGN") assigns += Seq("assign ", e, " = ", syn, ";", info) assigns += Seq("`else") - assigns += Seq("assign ", e, " = ", garbageCond, " ? ", rand_string(syn.tpe, "RANDOMIZE_GARBAGE_ASSIGN"), " : ", syn, - ";", info) + assigns += Seq( + "assign ", + e, + " = ", + garbageCond, + " ? ", + rand_string(syn.tpe, "RANDOMIZE_GARBAGE_ASSIGN"), + " : ", + syn, + ";", + info + ) assigns += Seq("`endif // RANDOMIZE_GARBAGE_ASSIGN") } @@ -721,12 +810,12 @@ class VerilogEmitter extends SeqTransform with Emitter { if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly") val (eninfo, tinfo, finfo) = MultiInfo.demux(info) - lazy val _if = Seq(tabs, "if (", m.cond, ") begin", eninfo) - lazy val _else = Seq(tabs, "end else begin") - lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo) - lazy val _end = Seq(tabs, "end") - lazy val _true = addUpdate(tinfo, m.tval, tabs + tab) - lazy val _false = addUpdate(finfo, m.fval, tabs + tab) + lazy val _if = Seq(tabs, "if (", m.cond, ") begin", eninfo) + lazy val _else = Seq(tabs, "end else begin") + lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo) + lazy val _end = Seq(tabs, "end") + lazy val _true = addUpdate(tinfo, m.tval, tabs + tab) + lazy val _false = addUpdate(finfo, m.fval, tabs + tab) lazy val _elseIfFalse = { val _falsex = addUpdate(finfo, m.fval, tabs) // _false, but without an additional tab Seq(tabs, "end else ", _falsex.head.tail) +: _falsex.tail @@ -743,13 +832,14 @@ class VerilogEmitter extends SeqTransform with Emitter { */ (m.tval, m.fval) match { case (t, f) if weq(t, r) && weq(f, r) => Nil - case (t, _) if weq(t, r) => _ifNot +: _false :+ _end - case (_, f) if weq(f, r) => m.cond.tpe match { - case AsyncResetType => (_if +: _true :+ _else) ++ _true :+ _end - case _ => _if +: _true :+ _end - } - case (_, _: Mux) => (_if +: _true) ++ _elseIfFalse - case _ => (_if +: _true :+ _else) ++ _false :+ _end + case (t, _) if weq(t, r) => _ifNot +: _false :+ _end + case (_, f) if weq(f, r) => + m.cond.tpe match { + case AsyncResetType => (_if +: _true :+ _else) ++ _true :+ _end + case _ => _if +: _true :+ _end + } + case (_, _: Mux) => (_if +: _true) ++ _elseIfFalse + case _ => (_if +: _true :+ _else) ++ _false :+ _end } case e => Seq(Seq(tabs, r, " <= ", e, ";", info)) } @@ -816,35 +906,52 @@ class VerilogEmitter extends SeqTransform with Emitter { val maxDataValue = (BigInt(1) << dataWidth.toInt) - 1 def checkValueRange(value: BigInt, at: String): Unit = { - if(value < 0) throw EmitterException(s"Memory ${at} cannot be initialized with negative value: $value") - if(value > maxDataValue) throw EmitterException(s"Memory ${at} cannot be initialized with value: $value. Too large (> $maxDataValue)!") + if (value < 0) throw EmitterException(s"Memory ${at} cannot be initialized with negative value: $value") + if (value > maxDataValue) + throw EmitterException(s"Memory ${at} cannot be initialized with value: $value. Too large (> $maxDataValue)!") } opt.initValue match { case MemoryArrayInit(values) => - if(values.length != s.depth) throw EmitterException( - s"Memory ${s.name} of depth ${s.depth} cannot be initialized with an array of length ${values.length}!" - ) + if (values.length != s.depth) + throw EmitterException( + s"Memory ${s.name} of depth ${s.depth} cannot be initialized with an array of length ${values.length}!" + ) val memName = LowerTypes.loweredName(wref(s.name, s.dataType)) - values.zipWithIndex.foreach { case (value, addr) => - checkValueRange(value, s"${s.name}[$addr]") - val access = s"$memName[${bigIntToVLit(addr)}]" - memoryInitials += Seq(access, " = ", bigIntToVLit(value), ";") + values.zipWithIndex.foreach { + case (value, addr) => + checkValueRange(value, s"${s.name}[$addr]") + val access = s"$memName[${bigIntToVLit(addr)}]" + memoryInitials += Seq(access, " = ", bigIntToVLit(value), ";") } case MemoryScalarInit(value) => checkValueRange(value, s.name) // note: s.dataType is the incorrect type for initvar, but it is ignored in the serialization val index = wref("initvar", s.dataType) memoryInitials += Seq("for (initvar = 0; initvar < ", bigIntToVLit(s.depth), "; initvar = initvar+1)") - memoryInitials += Seq(tab, WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow), - " = ", bigIntToVLit(value), ";") + memoryInitials += Seq( + tab, + WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow), + " = ", + bigIntToVLit(value), + ";" + ) case MemoryRandomInit => // note: s.dataType is the incorrect type for initvar, but it is ignored in the serialization val index = wref("initvar", s.dataType) val rstring = rand_string(s.dataType, "RANDOMIZE_MEM_INIT") - ifdefInitials("RANDOMIZE_MEM_INIT") += Seq("for (initvar = 0; initvar < ", bigIntToVLit(s.depth), "; initvar = initvar+1)") - ifdefInitials("RANDOMIZE_MEM_INIT") += Seq(tab, WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow), - " = ", rstring, ";") + ifdefInitials("RANDOMIZE_MEM_INIT") += Seq( + "for (initvar = 0; initvar < ", + bigIntToVLit(s.depth), + "; initvar = initvar+1)" + ) + ifdefInitials("RANDOMIZE_MEM_INIT") += Seq( + tab, + WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow), + " = ", + rstring, + ";" + ) } } @@ -888,7 +995,7 @@ class VerilogEmitter extends SeqTransform with Emitter { if (lines.size > 1) { val lineSeqs = lines.tail.map { - case "" => Seq(" *") + case "" => Seq(" *") case nonEmpty => Seq(" * ", nonEmpty) } Seq("/* ", lines.head) +: lineSeqs :+ Seq(" */") @@ -905,19 +1012,20 @@ class VerilogEmitter extends SeqTransform with Emitter { def build_ports(): Unit = { def padToMax(strs: Seq[String]): Seq[String] = { val len = if (strs.nonEmpty) strs.map(_.length).max else 0 - strs map (_.padTo(len, ' ')) + strs.map(_.padTo(len, ' ')) } // Turn directions into strings (and AnalogType into inout) - val dirs = m.ports map { case Port(_, name, dir, tpe) => - (dir, tpe) match { - case (_, AnalogType(_)) => "inout " // padded to length of output - case (Input, _) => "input " - case (Output, _) => "output" - } + val dirs = m.ports.map { + case Port(_, name, dir, tpe) => + (dir, tpe) match { + case (_, AnalogType(_)) => "inout " // padded to length of output + case (Input, _) => "input " + case (Output, _) => "output" + } } // Turn types into strings, all ports must be GroundTypes - val tpes = m.ports map { + val tpes = m.ports.map { case Port(_, _, _, tpe: GroundType) => stringify(tpe) case port: Port => error(s"Trying to emit non-GroundType Port $port") } @@ -925,9 +1033,10 @@ class VerilogEmitter extends SeqTransform with Emitter { // dirs are already padded (dirs, padToMax(tpes), m.ports).zipped.toSeq.zipWithIndex.foreach { case ((dir, tpe, Port(info, name, _, _)), i) => - portDescriptions.get(name).map { case d => - portdefs += Seq("") - portdefs ++= build_description(d) + portDescriptions.get(name).map { + case d => + portdefs += Seq("") + portdefs ++= build_description(d) } if (i != m.ports.size - 1) { @@ -956,14 +1065,14 @@ class VerilogEmitter extends SeqTransform with Emitter { } withoutDescription.foreach(build_streams) withoutDescription match { - case sx@Connect(info, loc@WRef(_, _, PortKind | WireKind | InstanceKind, _), expr) => + case sx @ Connect(info, loc @ WRef(_, _, PortKind | WireKind | InstanceKind, _), expr) => assign(loc, expr, info) case sx: DefWire => declare("wire", sx.name, sx.tpe, sx.info) case sx: DefRegister => val options = emissionOptions.getRegisterEmissionOption(moduleTarget.ref(sx.name)) val e = wref(sx.name, sx.tpe) - if (options.useInitAsPreset){ + if (options.useInitAsPreset) { declare("reg", sx.name, sx.tpe, sx.info, sx.init) regUpdate(e, sx.clock, sx.reset, e) } else { @@ -997,11 +1106,11 @@ class VerilogEmitter extends SeqTransform with Emitter { case sx: WDefInstanceConnector => val (module, params) = moduleMap(sx.module) match { case DescribedMod(_, _, ExtModule(_, _, _, extname, params)) => (extname, params) - case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty) - case ExtModule(_, _, _, extname, params) => (extname, params) - case Module(_, name, _, _) => (name, Seq.empty) + case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty) + case ExtModule(_, _, _, extname, params) => (extname, params) + case Module(_, name, _, _) => (name, Seq.empty) } - val ps = if (params.nonEmpty) params map stringify mkString("#(", ", ", ") ") else "" + val ps = if (params.nonEmpty) params.map(stringify).mkString("#(", ", ", ") ") else "" instdeclares += Seq(module, " ", ps, sx.name, " (", sx.info) for (((port, ref), i) <- sx.portCons.zipWithIndex) { val line = Seq(tab, ".", remove_root(port), "(", ref, ")") @@ -1012,14 +1121,16 @@ class VerilogEmitter extends SeqTransform with Emitter { case sx: DefMemory => val options = emissionOptions.getMemoryEmissionOption(moduleTarget.ref(sx.name)) val fullSize = sx.depth * (sx.dataType match { - case GroundType(IntWidth(width)) => width - }) + case GroundType(IntWidth(width)) => width + }) val decl = if (fullSize > (1 << 29)) "reg /* sparse */" else "reg" declareVectorType(decl, sx.name, sx.dataType, sx.depth, sx.info) initialize_mem(sx, options) if (sx.readLatency != 0 || sx.writeLatency != 1) - throw EmitterException("All memories should be transformed into " + - "blackboxes or combinational by previous passses") + throw EmitterException( + "All memories should be transformed into " + + "blackboxes or combinational by previous passses" + ) for (r <- sx.readers) { val data = memPortField(sx, r, "data") val addr = memPortField(sx, r, "addr") @@ -1031,7 +1142,7 @@ class VerilogEmitter extends SeqTransform with Emitter { //; Read port assign(addr, netlist(addr)) - // assign(en, netlist(en)) //;Connects value to m.r.en + // assign(en, netlist(en)) //;Connects value to m.r.en val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow) val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow) val depthValue = UIntLiteral(sx.depth, IntWidth(sx.depth.bitLength)) @@ -1069,8 +1180,10 @@ class VerilogEmitter extends SeqTransform with Emitter { } if (sx.readwriters.nonEmpty) - throw EmitterException("All readwrite ports should be transformed into " + - "read & write ports by previous passes") + throw EmitterException( + "All readwrite ports should be transformed into " + + "read & write ports by previous passes" + ) case _ => } } @@ -1081,10 +1194,11 @@ class VerilogEmitter extends SeqTransform with Emitter { for (x <- portdefs) emit(Seq(tab, x)) emit(Seq(");")) - ifdefDeclares.toSeq.sortWith(_._1 < _._1).foreach { case (ifdef, declares) => - emit(Seq("`ifdef " + ifdef)) - for (x <- declares) emit(Seq(tab, x)) - emit(Seq("`endif // " + ifdef)) + ifdefDeclares.toSeq.sortWith(_._1 < _._1).foreach { + case (ifdef, declares) => + emit(Seq("`ifdef " + ifdef)) + for (x <- declares) emit(Seq(tab, x)) + emit(Seq("`endif // " + ifdef)) } for (x <- declares) emit(Seq(tab, x)) for (x <- instdeclares) emit(Seq(tab, x)) @@ -1093,7 +1207,12 @@ class VerilogEmitter extends SeqTransform with Emitter { emit(Seq("`ifdef SYNTHESIS")) for (x <- attachSynAssigns) emit(Seq(tab, x)) emit(Seq("`elsif verilator")) - emit(Seq(tab, "`error \"Verilator does not support alias and thus cannot arbirarily connect bidirectional wires and ports\"")) + emit( + Seq( + tab, + "`error \"Verilator does not support alias and thus cannot arbirarily connect bidirectional wires and ports\"" + ) + ) emit(Seq("`else")) for (x <- attachAliases) emit(Seq(tab, x)) emit(Seq("`endif")) @@ -1129,7 +1248,7 @@ class VerilogEmitter extends SeqTransform with Emitter { emit(Seq("`define RANDOM $random")) emit(Seq("`endif")) // the initvar is also used to initialize memories to constants - if(memoryInitials.isEmpty) emit(Seq("`ifdef RANDOMIZE_MEM_INIT")) + if (memoryInitials.isEmpty) emit(Seq("`ifdef RANDOMIZE_MEM_INIT")) // Since simulators don't actually support memories larger than 2^31 - 1, there is no reason // to change Verilog emission in the common case. Instead, we only emit a larger initvar // where necessary @@ -1140,7 +1259,7 @@ class VerilogEmitter extends SeqTransform with Emitter { val width = maxMemSize.bitLength - 1 // minus one because [width-1:0] has a width of "width" emit(Seq(s" reg [$width:0] initvar;")) } - if(memoryInitials.isEmpty) emit(Seq("`endif")) + if (memoryInitials.isEmpty) emit(Seq("`endif")) emit(Seq("`ifndef SYNTHESIS")) // User-defined macro of code to run before an initial block emit(Seq("`ifdef FIRRTL_BEFORE_INITIAL")) @@ -1162,15 +1281,16 @@ class VerilogEmitter extends SeqTransform with Emitter { emit(Seq(" #0.002 begin end")) emit(Seq(" `endif")) emit(Seq(" `endif")) - ifdefInitials.toSeq.sortWith(_._1 < _._1).foreach { case (ifdef, initials) => - emit(Seq("`ifdef " + ifdef)) - for (x <- initials) emit(Seq(tab, x)) - emit(Seq("`endif // " + ifdef)) + ifdefInitials.toSeq.sortWith(_._1 < _._1).foreach { + case (ifdef, initials) => + emit(Seq("`ifdef " + ifdef)) + for (x <- initials) emit(Seq(tab, x)) + emit(Seq("`endif // " + ifdef)) } for (x <- initials) emit(Seq(tab, x)) for (x <- asyncInitials) emit(Seq(tab, x)) emit(Seq(" `endif // RANDOMIZE")) - for(x <- memoryInitials) emit(Seq(tab, x)) + for (x <- memoryInitials) emit(Seq(tab, x)) emit(Seq("end // initial")) // User-defined macro of code to run after an initial block emit(Seq("`ifdef FIRRTL_AFTER_INITIAL")) @@ -1258,7 +1378,7 @@ class VerilogEmitter extends SeqTransform with Emitter { val emissionOptions = new EmissionOptions(cs.annotations) val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap - cs.circuit.modules flatMap { + cs.circuit.modules.flatMap { case dm @ DescribedMod(d, pds, module: Module) => val writer = new java.io.StringWriter val renderer = new VerilogRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer) @@ -1282,8 +1402,8 @@ class MinimumVerilogEmitter extends VerilogEmitter with Emitter { override def prerequisites = firrtl.stage.Forms.AssertsRemoved ++ firrtl.stage.Forms.LowFormMinimumOptimized - override def transforms = new TransformManager(firrtl.stage.Forms.VerilogMinimumOptimized, prerequisites) - .flattenedTransformOrder + override def transforms = + new TransformManager(firrtl.stage.Forms.VerilogMinimumOptimized, prerequisites).flattenedTransformOrder } @@ -1292,9 +1412,14 @@ class SystemVerilogEmitter extends VerilogEmitter { override def prerequisites = firrtl.stage.Forms.LowFormOptimized - override def addFormalStatement(formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]], - clk: Expression, en: Expression, - stmt: Seq[Any], info: Info, msg: StringLit): Unit = { + override def addFormalStatement( + formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]], + clk: Expression, + en: Expression, + stmt: Seq[Any], + info: Info, + msg: StringLit + ): Unit = { val lines = formals.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) lines += Seq("// ", msg.serialize) lines += Seq("if (", en, ") begin") diff --git a/src/main/scala/firrtl/ExecutionOptionsManager.scala b/src/main/scala/firrtl/ExecutionOptionsManager.scala index d21ccade..50fb30a6 100644 --- a/src/main/scala/firrtl/ExecutionOptionsManager.scala +++ b/src/main/scala/firrtl/ExecutionOptionsManager.scala @@ -5,15 +5,22 @@ package firrtl import logger.LogLevel import logger.{ClassLogLevelAnnotation, LogClassNamesAnnotation, LogFileAnnotation, LogLevelAnnotation} import firrtl.annotations._ -import firrtl.Parser.{InfoMode, UseInfo, IgnoreInfo, GenInfo, AppendInfo} +import firrtl.Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo} import firrtl.ir.Circuit import firrtl.passes.memlib.{InferReadWriteAnnotation, ReplSeqMemAnnotation} import firrtl.passes.clocklist.ClockListAnnotation import firrtl.transforms.NoCircuitDedupAnnotation import scopt.OptionParser -import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, - InfoModeAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation} -import firrtl.stage.phases.DriverCompatibility.{TopNameAnnotation, EmitOneFilePerModuleAnnotation} +import firrtl.stage.{ + CompilerAnnotation, + FirrtlCircuitAnnotation, + FirrtlFileAnnotation, + FirrtlSourceAnnotation, + InfoModeAnnotation, + OutputFileAnnotation, + RunFirrtlTransformAnnotation +} +import firrtl.stage.phases.DriverCompatibility.{EmitOneFilePerModuleAnnotation, TopNameAnnotation} import firrtl.options.{InputAnnotationFileAnnotation, OutputAnnotationFileAnnotation, ProgramArgsAnnotation, StageUtils} import firrtl.transforms.{DontCheckCombLoopsAnnotation, NoDCEAnnotation} @@ -33,7 +40,7 @@ abstract class HasParser(applicationName: String) { final val parser = new OptionParser[Unit](applicationName) { var terminateOnExit = true override def terminate(exitState: Either[String, Unit]): Unit = { - if(terminateOnExit) sys.exit(0) + if (terminateOnExit) sys.exit(0) } } @@ -43,12 +50,14 @@ abstract class HasParser(applicationName: String) { def doNotExitOnHelp(): Unit = { parser.terminateOnExit = false } + /** * By default scopt calls sys.exit when --help is in options, this un-defeats doNotExitOnHelp */ def exitOnHelp(): Unit = { parser.terminateOnExit = true - }} + } +} /** * Most of the chisel toolchain components require a topName which defines a circuit or a device under test. @@ -59,20 +68,19 @@ abstract class HasParser(applicationName: String) { */ @deprecated("Use a FirrtlOptionsView, LoggerOptionsView, or construct your own view of an AnnotationSeq", "1.2") case class CommonOptions( - topName: String = "", - targetDirName: String = ".", - globalLogLevel: LogLevel.Value = LogLevel.None, - logToFile: Boolean = false, - logClassNames: Boolean = false, - classLogLevels: Map[String, LogLevel.Value] = Map.empty, - programArgs: Seq[String] = Seq.empty -) extends ComposableOptions { + topName: String = "", + targetDirName: String = ".", + globalLogLevel: LogLevel.Value = LogLevel.None, + logToFile: Boolean = false, + logClassNames: Boolean = false, + classLogLevels: Map[String, LogLevel.Value] = Map.empty, + programArgs: Seq[String] = Seq.empty) + extends ComposableOptions { def getLogFileName(optionsManager: ExecutionOptionsManager): String = { - if(topName.isEmpty) { + if (topName.isEmpty) { optionsManager.getBuildFileName("log", "firrtl") - } - else { + } else { optionsManager.getBuildFileName("log") } } @@ -80,10 +88,12 @@ case class CommonOptions( def toAnnotations: AnnotationSeq = List() ++ (if (topName.nonEmpty) Seq(TopNameAnnotation(topName)) else Seq()) ++ (if (targetDirName != ".") Some(TargetDirAnnotation(targetDirName)) else None) ++ Some(LogLevelAnnotation(globalLogLevel)) ++ - (if (logToFile) { Some(LogFileAnnotation(None)) } else { None }) ++ - (if (logClassNames) { Some(LogClassNamesAnnotation) } else { None }) ++ - classLogLevels.map{ case (c, v) => ClassLogLevelAnnotation(c, v) } ++ - programArgs.map( a => ProgramArgsAnnotation(a) ) + (if (logToFile) { Some(LogFileAnnotation(None)) } + else { None }) ++ + (if (logClassNames) { Some(LogClassNamesAnnotation) } + else { None }) ++ + classLogLevels.map { case (c, v) => ClassLogLevelAnnotation(c, v) } ++ + programArgs.map(a => ProgramArgsAnnotation(a)) } @deprecated("Specify command line arguments in an Annotation mixing in HasScoptOptions", "1.2") @@ -93,7 +103,8 @@ trait HasCommonOptions { parser.note("common options") - parser.opt[String]("top-name") + parser + .opt[String]("top-name") .abbr("tn") .valueName("<top-level-circuit-name>") .foreach { x => @@ -101,15 +112,19 @@ trait HasCommonOptions { } .text("This options defines the top level circuit, defaults to dut when possible") - parser.opt[String]("target-dir") - .abbr("td").valueName("<target-directory>") + parser + .opt[String]("target-dir") + .abbr("td") + .valueName("<target-directory>") .foreach { x => commonOptions = commonOptions.copy(targetDirName = x) } .text(s"This options defines a work directory for intermediate files, default is ${commonOptions.targetDirName}") - parser.opt[String]("log-level") - .abbr("ll").valueName("<error|warn|info|debug|trace>") + parser + .opt[String]("log-level") + .abbr("ll") + .valueName("<error|warn|info|debug|trace>") .foreach { x => val level = x.toLowerCase match { case "error" => LogLevel.Error @@ -126,16 +141,18 @@ trait HasCommonOptions { } .text(s"This options defines global log level, default is ${commonOptions.globalLogLevel}") - parser.opt[Seq[String]]("class-log-level") - .abbr("cll").valueName("<FullClassName:[error|warn|info|debug|trace]>[,...]") + parser + .opt[Seq[String]]("class-log-level") + .abbr("cll") + .valueName("<FullClassName:[error|warn|info|debug|trace]>[,...]") .foreach { x => val logAssignments = x.map { y => val className :: levelName :: _ = y.split(":").toList val level = levelName.toLowerCase match { case "error" => LogLevel.Error - case "warn" => LogLevel.Warn - case "info" => LogLevel.Info + case "warn" => LogLevel.Warn + case "info" => LogLevel.Info case "debug" => LogLevel.Debug case "trace" => LogLevel.Trace case _ => @@ -149,14 +166,16 @@ trait HasCommonOptions { } .text(s"This options defines class log level, default is ${commonOptions.classLogLevels}") - parser.opt[Unit]("log-to-file") + parser + .opt[Unit]("log-to-file") .abbr("ltf") .foreach { _ => commonOptions = commonOptions.copy(logToFile = true) } .text(s"default logs to stdout, this flags writes to topName.log or firrtl.log if no topName") - parser.opt[Unit]("log-class-names") + parser + .opt[Unit]("log-class-names") .abbr("lcn") .foreach { _ => commonOptions = commonOptions.copy(logClassNames = true) @@ -165,8 +184,12 @@ trait HasCommonOptions { parser.help("help").text("prints this usage text") - parser.arg[String]("<arg>...").unbounded().optional().action( (x, c) => - commonOptions = commonOptions.copy(programArgs = commonOptions.programArgs :+ x) ).text("optional unbounded args") + parser + .arg[String]("<arg>...") + .unbounded() + .optional() + .action((x, c) => commonOptions = commonOptions.copy(programArgs = commonOptions.programArgs :+ x)) + .text("optional unbounded args") } @@ -189,46 +212,47 @@ final case class OneFilePerModule(targetDir: String) extends OutputConfig */ @deprecated("Use a FirrtlOptionsView or construct your own view of an AnnotationSeq", "1.2") case class FirrtlExecutionOptions( - inputFileNameOverride: String = "", - outputFileNameOverride: String = "", - compilerName: String = "verilog", - infoModeName: String = "append", - inferRW: Seq[String] = Seq.empty, - firrtlSource: Option[String] = None, - customTransforms: Seq[Transform] = List.empty, - annotations: List[Annotation] = List.empty, - annotationFileNameOverride: String = "", - outputAnnotationFileName: String = "", - emitOneFilePerModule: Boolean = false, - dontCheckCombLoops: Boolean = false, - noDCE: Boolean = false, - annotationFileNames: List[String] = List.empty, - firrtlCircuit: Option[Circuit] = None -) -extends ComposableOptions { - - require(!(emitOneFilePerModule && outputFileNameOverride.nonEmpty), - "Cannot both specify the output filename and emit one file per module!!!") + inputFileNameOverride: String = "", + outputFileNameOverride: String = "", + compilerName: String = "verilog", + infoModeName: String = "append", + inferRW: Seq[String] = Seq.empty, + firrtlSource: Option[String] = None, + customTransforms: Seq[Transform] = List.empty, + annotations: List[Annotation] = List.empty, + annotationFileNameOverride: String = "", + outputAnnotationFileName: String = "", + emitOneFilePerModule: Boolean = false, + dontCheckCombLoops: Boolean = false, + noDCE: Boolean = false, + annotationFileNames: List[String] = List.empty, + firrtlCircuit: Option[Circuit] = None) + extends ComposableOptions { + + require( + !(emitOneFilePerModule && outputFileNameOverride.nonEmpty), + "Cannot both specify the output filename and emit one file per module!!!" + ) def infoMode: InfoMode = { infoModeName match { - case "use" => UseInfo + case "use" => UseInfo case "ignore" => IgnoreInfo - case "gen" => GenInfo(inputFileNameOverride) + case "gen" => GenInfo(inputFileNameOverride) case "append" => AppendInfo(inputFileNameOverride) - case other => UseInfo + case other => UseInfo } } def compiler: Compiler = { compilerName match { - case "none" => new NoneCompiler() - case "high" => new HighFirrtlCompiler() - case "low" => new LowFirrtlCompiler() - case "middle" => new MiddleFirrtlCompiler() - case "verilog" => new VerilogCompiler() - case "mverilog" => new MinimumVerilogCompiler() - case "sverilog" => new SystemVerilogCompiler() + case "none" => new NoneCompiler() + case "high" => new HighFirrtlCompiler() + case "low" => new LowFirrtlCompiler() + case "middle" => new MiddleFirrtlCompiler() + case "verilog" => new VerilogCompiler() + case "mverilog" => new MinimumVerilogCompiler() + case "sverilog" => new SystemVerilogCompiler() } } @@ -255,6 +279,7 @@ extends ComposableOptions { if (inputFileNameOverride.nonEmpty) inputFileNameOverride else optionsManager.getBuildFileName("fir", inputFileNameOverride) } + /** Get the user-specified [[OutputConfig]] * * @param optionsManager this is needed to access build function and its common options @@ -264,6 +289,7 @@ extends ComposableOptions { if (emitOneFilePerModule) OneFilePerModule(optionsManager.targetDirName) else SingleFile(optionsManager.getBuildFileName(outputSuffix, outputFileNameOverride)) } + /** Get the user-specified targetFile assuming [[OutputConfig]] is [[SingleFile]] * * @param optionsManager this is needed to access build function and its common options @@ -272,9 +298,10 @@ extends ComposableOptions { def getTargetFile(optionsManager: ExecutionOptionsManager): String = { getOutputConfig(optionsManager) match { case SingleFile(targetFile) => targetFile - case other => throw new Exception("OutputConfig is not SingleFile!") + case other => throw new Exception("OutputConfig is not SingleFile!") } } + /** Gives annotations based on the output configuration * * @param optionsManager this is needed to access build function and its common options @@ -283,19 +310,20 @@ extends ComposableOptions { def getEmitterAnnos(optionsManager: ExecutionOptionsManager): Seq[Annotation] = { // TODO should this be a public function? val emitter = compilerName match { - case "none" => classOf[ChirrtlEmitter] - case "high" => classOf[HighFirrtlEmitter] - case "middle" => classOf[MiddleFirrtlEmitter] - case "low" => classOf[LowFirrtlEmitter] - case "verilog" => classOf[VerilogEmitter] + case "none" => classOf[ChirrtlEmitter] + case "high" => classOf[HighFirrtlEmitter] + case "middle" => classOf[MiddleFirrtlEmitter] + case "low" => classOf[LowFirrtlEmitter] + case "verilog" => classOf[VerilogEmitter] case "mverilog" => classOf[MinimumVerilogEmitter] case "sverilog" => classOf[VerilogEmitter] } getOutputConfig(optionsManager) match { - case SingleFile(_) => Seq(EmitCircuitAnnotation(emitter)) + case SingleFile(_) => Seq(EmitCircuitAnnotation(emitter)) case OneFilePerModule(_) => Seq(EmitAllModulesAnnotation(emitter)) } } + /** * build the annotation file name, taking overriding parameters * @@ -313,23 +341,28 @@ extends ComposableOptions { } List() ++ (if (inputFileNameOverride.nonEmpty) Seq(FirrtlFileAnnotation(inputFileNameOverride)) else Seq()) ++ - (if (outputFileNameOverride.nonEmpty) { Some(OutputFileAnnotation(outputFileNameOverride)) } else { None }) ++ + (if (outputFileNameOverride.nonEmpty) { Some(OutputFileAnnotation(outputFileNameOverride)) } + else { None }) ++ Some(CompilerAnnotation(compilerName)) ++ Some(InfoModeAnnotation(infoModeName)) ++ firrtlSource.map(FirrtlSourceAnnotation(_)) ++ customTransforms.map(t => RunFirrtlTransformAnnotation(t)) ++ annotations ++ - (if (annotationFileNameOverride.nonEmpty) { Some(InputAnnotationFileAnnotation(annotationFileNameOverride)) } else { None }) ++ - (if (outputAnnotationFileName.nonEmpty) { Some(OutputAnnotationFileAnnotation(outputAnnotationFileName)) } else { None }) ++ - (if (emitOneFilePerModule) { Some(EmitOneFilePerModuleAnnotation) } else { None }) ++ - (if (dontCheckCombLoops) { Some(DontCheckCombLoopsAnnotation) } else { None }) ++ - (if (noDCE) { Some(NoDCEAnnotation) } else { None }) ++ + (if (annotationFileNameOverride.nonEmpty) { Some(InputAnnotationFileAnnotation(annotationFileNameOverride)) } + else { None }) ++ + (if (outputAnnotationFileName.nonEmpty) { Some(OutputAnnotationFileAnnotation(outputAnnotationFileName)) } + else { None }) ++ + (if (emitOneFilePerModule) { Some(EmitOneFilePerModuleAnnotation) } + else { None }) ++ + (if (dontCheckCombLoops) { Some(DontCheckCombLoopsAnnotation) } + else { None }) ++ + (if (noDCE) { Some(NoDCEAnnotation) } + else { None }) ++ annotationFileNames.map(InputAnnotationFileAnnotation(_)) ++ firrtlCircuit.map(FirrtlCircuitAnnotation(_)) } } - @deprecated("Specify command line arguments in an Annotation mixing in HasScoptOptions", "1.2") trait HasFirrtlOptions { self: ExecutionOptionsManager => @@ -337,16 +370,19 @@ trait HasFirrtlOptions { parser.note("firrtl options") - parser.opt[String]("input-file") + parser + .opt[String]("input-file") .abbr("i") - .valueName ("<firrtl-source>") + .valueName("<firrtl-source>") .foreach { x => firrtlOptions = firrtlOptions.copy(inputFileNameOverride = x) - }.text { + } + .text { "use this to override the default input file name , default is empty" } - parser.opt[String]("output-file") + parser + .opt[String]("output-file") .abbr("o") .valueName("<output>") .validate { x => @@ -356,40 +392,47 @@ trait HasFirrtlOptions { } .foreach { x => firrtlOptions = firrtlOptions.copy(outputFileNameOverride = x) - }.text { - "use this to override the default output file name, default is empty" - } + } + .text { + "use this to override the default output file name, default is empty" + } - parser.opt[String]("annotation-file") + parser + .opt[String]("annotation-file") .abbr("faf") .unbounded() .valueName("<input-anno-file>") .foreach { x => val annoFiles = x +: firrtlOptions.annotationFileNames firrtlOptions = firrtlOptions.copy(annotationFileNames = annoFiles) - }.text("Used to specify annotation files (can appear multiple times)") + } + .text("Used to specify annotation files (can appear multiple times)") - parser.opt[Unit]("force-append-anno-file") + parser + .opt[Unit]("force-append-anno-file") .abbr("ffaaf") .hidden() .foreach { _ => val msg = "force-append-anno-file is deprecated and will soon be removed\n" + - (" "*9) + "(It does not do anything anymore)" + (" " * 9) + "(It does not do anything anymore)" StageUtils.dramaticWarning(msg) } - parser.opt[String]("output-annotation-file") + parser + .opt[String]("output-annotation-file") .abbr("foaf") - .valueName ("<output-anno-file>") + .valueName("<output-anno-file>") .foreach { x => firrtlOptions = firrtlOptions.copy(outputAnnotationFileName = x) - }.text { - "use this to set the annotation output file" - } + } + .text { + "use this to set the annotation output file" + } - parser.opt[String]("compiler") + parser + .opt[String]("compiler") .abbr("X") - .valueName ("<high|middle|low|verilog|mverilog|sverilog|none>") + .valueName("<high|middle|low|verilog|mverilog|sverilog|none>") .foreach { x => firrtlOptions = firrtlOptions.copy(compilerName = x) } @@ -399,12 +442,14 @@ trait HasFirrtlOptions { } else { parser.failure(s"$x not a legal compiler") } - }.text { + } + .text { s"compiler to use, default is ${firrtlOptions.compilerName}" } - parser.opt[String]("info-mode") - .valueName ("<ignore|use|gen|append>") + parser + .opt[String]("info-mode") + .valueName("<ignore|use|gen|append>") .foreach { x => firrtlOptions = firrtlOptions.copy(infoModeName = x.toLowerCase) } @@ -416,13 +461,14 @@ trait HasFirrtlOptions { s"specifies the source info handling, default is ${firrtlOptions.infoModeName}" } - parser.opt[Seq[String]]("custom-transforms") + parser + .opt[Seq[String]]("custom-transforms") .abbr("fct") - .valueName ("<package>.<class>") + .valueName("<package>.<class>") .foreach { customTransforms: Seq[String] => firrtlOptions = firrtlOptions.copy( customTransforms = firrtlOptions.customTransforms ++ - (customTransforms map { x: String => + (customTransforms.map { x: String => Class.forName(x).asInstanceOf[Class[_ <: Transform]].newInstance() }) ) @@ -431,10 +477,10 @@ trait HasFirrtlOptions { """runs these custom transforms during compilation.""" } - - parser.opt[Seq[String]]("inline") + parser + .opt[Seq[String]]("inline") .abbr("fil") - .valueName ("<circuit>[.<module>[.<instance>]][,..],") + .valueName("<circuit>[.<module>[.<instance>]][,..],") .foreach { x => val newAnnotations = x.map { value => value.split('.') match { @@ -455,20 +501,23 @@ trait HasFirrtlOptions { """Inline one or more module (comma separated, no spaces) module looks like "MyModule" or "MyModule.myinstance""" } - parser.opt[Unit]("infer-rw") + parser + .opt[Unit]("infer-rw") .abbr("firw") .foreach { x => firrtlOptions = firrtlOptions.copy( annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation, customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.InferReadWrite ) - }.text { + } + .text { "Enable readwrite port inference for the target circuit" } - parser.opt[String]("repl-seq-mem") + parser + .opt[String]("repl-seq-mem") .abbr("frsq") - .valueName ("-c:<circuit>:-i:<filename>:-o:<filename>") + .valueName("-c:<circuit>:-i:<filename>:-o:<filename>") .foreach { x => firrtlOptions = firrtlOptions.copy( annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation.parse(x), @@ -479,9 +528,10 @@ trait HasFirrtlOptions { "Replace sequential memories with blackboxes + configuration file" } - parser.opt[String]("list-clocks") + parser + .opt[String]("list-clocks") .abbr("clks") - .valueName ("-c:<circuit>:-m:<module>:-o:<filename>") + .valueName("-c:<circuit>:-m:<module>:-o:<filename>") .foreach { x => firrtlOptions = firrtlOptions.copy( annotations = firrtlOptions.annotations :+ ClockListAnnotation.parse(x), @@ -492,7 +542,8 @@ trait HasFirrtlOptions { "List which signal drives each clock of every descendent of specified module" } - parser.opt[Unit]("split-modules") + parser + .opt[Unit]("split-modules") .abbr("fsm") .validate { x => if (firrtlOptions.outputFileNameOverride.nonEmpty) @@ -501,32 +552,39 @@ trait HasFirrtlOptions { } .foreach { _ => firrtlOptions = firrtlOptions.copy(emitOneFilePerModule = true) - }.text { + } + .text { "Emit each module to its own file in the target directory." } - parser.opt[Unit]("no-check-comb-loops") + parser + .opt[Unit]("no-check-comb-loops") .foreach { _ => firrtlOptions = firrtlOptions.copy(dontCheckCombLoops = true) - }.text { + } + .text { "Do NOT check for combinational loops (not recommended)" } - parser.opt[Unit]("no-dce") + parser + .opt[Unit]("no-dce") .foreach { _ => firrtlOptions = firrtlOptions.copy(noDCE = true) - }.text { + } + .text { "Do NOT run dead code elimination" } - parser.opt[Unit]("no-dedup") + parser + .opt[Unit]("no-dedup") .foreach { _ => firrtlOptions = firrtlOptions.copy( annotations = firrtlOptions.annotations :+ NoCircuitDedupAnnotation ) - }.text { - "Do NOT dedup modules" - } + } + .text { + "Do NOT dedup modules" + } parser.note("") } @@ -537,16 +595,16 @@ sealed trait FirrtlExecutionResult @deprecated("Use FirrtlStage and examine the output AnnotationSeq directly", "1.2") object FirrtlExecutionSuccess { def apply( - emitType : String, - emitted : String, + emitType: String, + emitted: String, circuitState: CircuitState ): FirrtlExecutionSuccess = new FirrtlExecutionSuccess(emitType, emitted, circuitState) - def unapply(arg: FirrtlExecutionSuccess): Option[(String, String)] = { Some((arg.emitType, arg.emitted)) } } + /** * Indicates a successful execution of the firrtl compiler, returning the compiled result and * the type of compile @@ -557,10 +615,10 @@ object FirrtlExecutionSuccess { */ @deprecated("Use FirrtlStage and examine the output AnnotationSeq directly", "1.2") class FirrtlExecutionSuccess( - val emitType: String, - val emitted : String, - val circuitState: CircuitState -) extends FirrtlExecutionResult + val emitType: String, + val emitted: String, + val circuitState: CircuitState) + extends FirrtlExecutionResult /** * The firrtl compilation failed. @@ -571,7 +629,6 @@ class FirrtlExecutionSuccess( case class FirrtlExecutionFailure(message: String) extends FirrtlExecutionResult /** - * * @param applicationName The name shown in the usage */ @deprecated("Use new FirrtlStage infrastructure", "1.2") @@ -607,7 +664,7 @@ class ExecutionOptionsManager(val applicationName: String) extends HasParser(app commonOptions = commonOptions.copy(topName = newTopName) } def setTopNameIfNotSet(newTopName: String): Unit = { - if(commonOptions.topName.isEmpty) { + if (commonOptions.topName.isEmpty) { setTopName(newTopName) } } @@ -627,21 +684,19 @@ class ExecutionOptionsManager(val applicationName: String) extends HasParser(app def getBuildFileName(suffix: String, fileNameOverride: String = ""): String = { makeTargetDir() - val baseName = if(fileNameOverride.nonEmpty) fileNameOverride else topName + val baseName = if (fileNameOverride.nonEmpty) fileNameOverride else topName val directoryName = { - if(fileNameOverride.nonEmpty) { + if (fileNameOverride.nonEmpty) { "" - } - else if(baseName.startsWith("./") || baseName.startsWith("/")) { + } else if (baseName.startsWith("./") || baseName.startsWith("/")) { "" - } - else { - if(targetDirName.endsWith("/")) targetDirName else targetDirName + "/" + } else { + if (targetDirName.endsWith("/")) targetDirName else targetDirName + "/" } } val normalizedSuffix = { - val dottedSuffix = if(suffix.startsWith(".")) suffix else s".$suffix" - if(baseName.endsWith(dottedSuffix)) "" else dottedSuffix + val dottedSuffix = if (suffix.startsWith(".")) suffix else s".$suffix" + if (baseName.endsWith(dottedSuffix)) "" else dottedSuffix } val path = directoryName + baseName.split("/").dropRight(1).mkString("/") FileUtils.makeDirectory(path) diff --git a/src/main/scala/firrtl/FileUtils.scala b/src/main/scala/firrtl/FileUtils.scala index 8e73b4f9..3db86b7c 100644 --- a/src/main/scala/firrtl/FileUtils.scala +++ b/src/main/scala/firrtl/FileUtils.scala @@ -7,7 +7,7 @@ import java.io.File import firrtl.options.StageUtils import scala.collection.Seq -import scala.sys.process.{BasicIO, ProcessLogger, stringSeqToProcess} +import scala.sys.process.{stringSeqToProcess, BasicIO, ProcessLogger} object FileUtils { @@ -17,7 +17,7 @@ object FileUtils { */ def makeDirectory(directoryName: String): Boolean = { val dirFile = new File(directoryName) - if(dirFile.exists()) { + if (dirFile.exists()) { dirFile.isDirectory } else { dirFile.mkdirs() @@ -33,6 +33,7 @@ object FileUtils { def deleteDirectoryHierarchy(directoryPathName: String): Boolean = { deleteDirectoryHierarchy(new File(directoryPathName)) } + /** * recursively delete all directories in a relative path * DO NOT DELETE absolute paths @@ -40,18 +41,18 @@ object FileUtils { * @param file: a directory hierarchy to delete */ def deleteDirectoryHierarchy(file: File, atTop: Boolean = true): Boolean = { - if(file.getPath.split("/").last.isEmpty || + if ( + file.getPath.split("/").last.isEmpty || file.getAbsolutePath == "/" || - file.getPath.startsWith("/")) { + file.getPath.startsWith("/") + ) { StageUtils.dramaticError(s"delete directory ${file.getPath} will not delete absolute paths") false - } - else { + } else { val result = { - if(file.isDirectory) { - file.listFiles().forall( f => deleteDirectoryHierarchy(f)) && file.delete() - } - else { + if (file.isDirectory) { + file.listFiles().forall(f => deleteDirectoryHierarchy(f)) && file.delete() + } else { file.delete() } } @@ -81,7 +82,7 @@ object FileUtils { * @param cmd the command/executable (without any arguments). * @return true if ```cmd``` returns a 0 exit status. */ - def isCommandAvailable(cmd:String): Boolean = { + def isCommandAvailable(cmd: String): Boolean = { isCommandAvailable(Seq(cmd)) } @@ -90,7 +91,7 @@ object FileUtils { * Instead we try to run the executable itself (with innocuous arguments) and interpret any errors/exceptions * as an indication that the executable is unavailable. */ - lazy val isVCSAvailable: Boolean = isCommandAvailable(Seq("vcs", "-platform")) + lazy val isVCSAvailable: Boolean = isCommandAvailable(Seq("vcs", "-platform")) /** Read a text file and return it as a Seq of strings * Closes the file after read to avoid dangling file handles diff --git a/src/main/scala/firrtl/FirrtlException.scala b/src/main/scala/firrtl/FirrtlException.scala index 20d984a1..6f98fda3 100644 --- a/src/main/scala/firrtl/FirrtlException.scala +++ b/src/main/scala/firrtl/FirrtlException.scala @@ -18,7 +18,7 @@ object FIRRTLException { } @deprecated("External users should use either FirrtlUserException or their own hierarchy", "1.2") class FIRRTLException(val str: String, cause: Throwable = null) - extends RuntimeException(FIRRTLException.defaultMessage(str, cause), cause) + extends RuntimeException(FIRRTLException.defaultMessage(str, cause), cause) /** Exception indicating user error * @@ -26,7 +26,8 @@ class FIRRTLException(val str: String, cause: Throwable = null) * This can be extended by custom transform writers. */ class FirrtlUserException(message: String, cause: Throwable = null) - extends RuntimeException(message, cause) with NoStackTrace + extends RuntimeException(message, cause) + with NoStackTrace /** Wraps exceptions from CustomTransforms so they can be reported appropriately */ case class CustomTransformException(cause: Throwable) extends Exception("", cause) @@ -40,4 +41,4 @@ case class CustomTransformException(cause: Throwable) extends Exception("", caus * transforms are treated differently and should thus have their own structure */ private[firrtl] class FirrtlInternalException(message: String, cause: Throwable = null) - extends Exception(message, cause) + extends Exception(message, cause) diff --git a/src/main/scala/firrtl/Implicits.scala b/src/main/scala/firrtl/Implicits.scala index ec1cf3d6..fd732917 100644 --- a/src/main/scala/firrtl/Implicits.scala +++ b/src/main/scala/firrtl/Implicits.scala @@ -7,19 +7,19 @@ import Utils.trim import firrtl.constraint.Constraint object Implicits { - implicit def int2WInt(i: Int): WrappedInt = WrappedInt(BigInt(i)) - implicit def bigint2WInt(i: BigInt): WrappedInt = WrappedInt(i) + implicit def int2WInt(i: Int): WrappedInt = WrappedInt(BigInt(i)) + implicit def bigint2WInt(i: BigInt): WrappedInt = WrappedInt(i) implicit def constraint2bound(c: Constraint): Bound = c match { case x: Bound => x case x => CalcBound(x) } implicit def constraint2width(c: Constraint): Width = c match { case Closed(x) if trim(x).isWhole => IntWidth(x.toBigInt) - case x => CalcWidth(x) + case x => CalcWidth(x) } implicit def width2constraint(w: Width): Constraint = w match { case CalcWidth(x: Constraint) => x - case IntWidth(x) => Closed(BigDecimal(x)) + case IntWidth(x) => Closed(BigDecimal(x)) case UnknownWidth => UnknownBound case v: Constraint => v } diff --git a/src/main/scala/firrtl/LexerHelper.scala b/src/main/scala/firrtl/LexerHelper.scala index cc17ac46..3ddfc5b9 100644 --- a/src/main/scala/firrtl/LexerHelper.scala +++ b/src/main/scala/firrtl/LexerHelper.scala @@ -15,7 +15,7 @@ import firrtl.antlr.FIRRTLParser abstract class LexerHelper { - import FIRRTLParser.{NEWLINE, INDENT, DEDENT} + import FIRRTLParser.{DEDENT, INDENT, NEWLINE} private val tokenBuffer = mutable.Queue.empty[Token] private val indentations = mutable.Stack[Int]() @@ -58,9 +58,9 @@ abstract class LexerHelper { def handleNewlineToken(token: Token): Token = { @tailrec - def nonNewline(token: Token) : (Token, Token) = { + def nonNewline(token: Token): (Token, Token) = { val nextNext = pullToken() - if(nextNext.getType == NEWLINE) + if (nextNext.getType == NEWLINE) nonNewline(nextNext) else (token, nextNext) @@ -94,10 +94,11 @@ abstract class LexerHelper { } } - val t = if (tokenBuffer.isEmpty) - pullToken() - else - tokenBuffer.dequeue + val t = + if (tokenBuffer.isEmpty) + pullToken() + else + tokenBuffer.dequeue if (reachedEof) t @@ -117,8 +118,8 @@ abstract class LexerHelper { setType(tokenType) tokenType match { case `NEWLINE` => setText("<NEWLINE>") - case `INDENT` => setText("<INDENT>") - case `DEDENT` => setText("<DEDENT>") + case `INDENT` => setText("<INDENT>") + case `DEDENT` => setText("<DEDENT>") } } diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 19e7d8c6..90881a57 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -91,10 +91,10 @@ class LowFirrtlOptimization extends CoreTransform { } /** Runs runs only the optimization passes needed for Verilog emission */ - @deprecated( - "Use 'new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm)'. This will be removed in 1.4.", - "FIRRTL 1.3" - ) +@deprecated( + "Use 'new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm)'. This will be removed in 1.4.", + "FIRRTL 1.3" +) class MinimumLowFirrtlOptimization extends CoreTransform { def inputForm = LowForm def outputForm = LowForm diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala index 3bf89885..e9a698ae 100644 --- a/src/main/scala/firrtl/Mappers.scala +++ b/src/main/scala/firrtl/Mappers.scala @@ -12,36 +12,35 @@ object Mappers { } private object PortMagnet { implicit def forType(f: Type => Type): PortMagnet = new PortMagnet { - override def map(port: Port): Port = port mapType f + override def map(port: Port): Port = port.mapType(f) } implicit def forString(f: String => String): PortMagnet = new PortMagnet { - override def map(port: Port): Port = port mapString f + override def map(port: Port): Port = port.mapString(f) } } implicit class PortMap(val _port: Port) extends AnyVal { def map[T](f: T => T)(implicit magnet: (T => T) => PortMagnet): Port = magnet(f).map(_port) } - // ********** Stmt Mappers ********** private trait StmtMagnet { def map(stmt: Statement): Statement } private object StmtMagnet { implicit def forStmt(f: Statement => Statement): StmtMagnet = new StmtMagnet { - override def map(stmt: Statement): Statement = stmt mapStmt f + override def map(stmt: Statement): Statement = stmt.mapStmt(f) } implicit def forExp(f: Expression => Expression): StmtMagnet = new StmtMagnet { - override def map(stmt: Statement): Statement = stmt mapExpr f + override def map(stmt: Statement): Statement = stmt.mapExpr(f) } implicit def forType(f: Type => Type): StmtMagnet = new StmtMagnet { - override def map(stmt: Statement) : Statement = stmt mapType f + override def map(stmt: Statement): Statement = stmt.mapType(f) } implicit def forString(f: String => String): StmtMagnet = new StmtMagnet { - override def map(stmt: Statement): Statement = stmt mapString f + override def map(stmt: Statement): Statement = stmt.mapString(f) } implicit def forInfo(f: Info => Info): StmtMagnet = new StmtMagnet { - override def map(stmt: Statement): Statement = stmt mapInfo f + override def map(stmt: Statement): Statement = stmt.mapInfo(f) } } implicit class StmtMap(val _stmt: Statement) extends AnyVal { @@ -55,13 +54,13 @@ object Mappers { } private object ExprMagnet { implicit def forExpr(f: Expression => Expression): ExprMagnet = new ExprMagnet { - override def map(expr: Expression): Expression = expr mapExpr f + override def map(expr: Expression): Expression = expr.mapExpr(f) } implicit def forType(f: Type => Type): ExprMagnet = new ExprMagnet { - override def map(expr: Expression): Expression = expr mapType f + override def map(expr: Expression): Expression = expr.mapType(f) } implicit def forWidth(f: Width => Width): ExprMagnet = new ExprMagnet { - override def map(expr: Expression): Expression = expr mapWidth f + override def map(expr: Expression): Expression = expr.mapWidth(f) } } implicit class ExprMap(val _expr: Expression) extends AnyVal { @@ -74,10 +73,10 @@ object Mappers { } private object TypeMagnet { implicit def forType(f: Type => Type): TypeMagnet = new TypeMagnet { - override def map(tpe: Type): Type = tpe mapType f + override def map(tpe: Type): Type = tpe.mapType(f) } implicit def forWidth(f: Width => Width): TypeMagnet = new TypeMagnet { - override def map(tpe: Type): Type = tpe mapWidth f + override def map(tpe: Type): Type = tpe.mapWidth(f) } } implicit class TypeMap(val _tpe: Type) extends AnyVal { @@ -91,7 +90,7 @@ object Mappers { private object WidthMagnet { implicit def forWidth(f: Width => Width): WidthMagnet = new WidthMagnet { override def map(width: Width): Width = width match { - case mapable: HasMapWidth => mapable mapWidth f // WIR + case mapable: HasMapWidth => mapable.mapWidth(f) // WIR case other => other // Standard IR nodes } } @@ -106,21 +105,21 @@ object Mappers { } private object ModuleMagnet { implicit def forStmt(f: Statement => Statement): ModuleMagnet = new ModuleMagnet { - override def map(module: DefModule): DefModule = module mapStmt f + override def map(module: DefModule): DefModule = module.mapStmt(f) } implicit def forPorts(f: Port => Port): ModuleMagnet = new ModuleMagnet { - override def map(module: DefModule): DefModule = module mapPort f + override def map(module: DefModule): DefModule = module.mapPort(f) } implicit def forString(f: String => String): ModuleMagnet = new ModuleMagnet { - override def map(module: DefModule): DefModule = module mapString f + override def map(module: DefModule): DefModule = module.mapString(f) } implicit def forInfo(f: Info => Info): ModuleMagnet = new ModuleMagnet { - override def map(module: DefModule): DefModule = module mapInfo f + override def map(module: DefModule): DefModule = module.mapInfo(f) } } implicit class ModuleMap(val _module: DefModule) extends AnyVal { def map[T](f: T => T)(implicit magnet: (T => T) => ModuleMagnet): DefModule = magnet(f).map(_module) - } + } // ********** Circuit Mappers ********** private trait CircuitMagnet { @@ -128,16 +127,16 @@ object Mappers { } private object CircuitMagnet { implicit def forModules(f: DefModule => DefModule): CircuitMagnet = new CircuitMagnet { - override def map(circuit: Circuit): Circuit = circuit mapModule f + override def map(circuit: Circuit): Circuit = circuit.mapModule(f) } implicit def forString(f: String => String): CircuitMagnet = new CircuitMagnet { - override def map(circuit: Circuit): Circuit = circuit mapString f + override def map(circuit: Circuit): Circuit = circuit.mapString(f) } implicit def forInfo(f: Info => Info): CircuitMagnet = new CircuitMagnet { - override def map(circuit: Circuit): Circuit = circuit mapInfo f + override def map(circuit: Circuit): Circuit = circuit.mapInfo(f) } } implicit class CircuitMap(val _circuit: Circuit) extends AnyVal { def map[T](f: T => T)(implicit magnet: (T => T) => CircuitMagnet): Circuit = magnet(f).map(_circuit) - } + } } diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index bb358be6..196539c8 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -29,8 +29,7 @@ class Namespace private { do { str = s"${value}_$idx" idx += 1 - } - while (!(tryName(str))) + } while (!(tryName(str))) indices(value) = idx str } @@ -55,10 +54,10 @@ object Namespace { def buildNamespaceStmt(s: Statement): Seq[String] = s match { case s: IsDeclaration => Seq(s.name) case s: Conditionally => buildNamespaceStmt(s.conseq) ++ buildNamespaceStmt(s.alt) - case s: Block => s.stmts flatMap buildNamespaceStmt + case s: Block => s.stmts.flatMap(buildNamespaceStmt) case _ => Nil } - namespace.namespace ++= m.ports map (_.name) + namespace.namespace ++= m.ports.map(_.name) m match { case in: Module => namespace.namespace ++= buildNamespaceStmt(in.body) @@ -71,11 +70,11 @@ object Namespace { /** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */ def apply(c: Circuit): Namespace = { val namespace = new Namespace - namespace.namespace ++= c.modules map (_.name) + namespace.namespace ++= c.modules.map(_.name) namespace } - /** Initializes a [[Namespace]] from arbitrary strings **/ + /** Initializes a [[Namespace]] from arbitrary strings * */ def apply(names: Seq[String] = Nil): Namespace = { val namespace = new Namespace namespace.namespace ++= names diff --git a/src/main/scala/firrtl/Parser.scala b/src/main/scala/firrtl/Parser.scala index d3075cbb..40eaa88f 100644 --- a/src/main/scala/firrtl/Parser.scala +++ b/src/main/scala/firrtl/Parser.scala @@ -17,7 +17,6 @@ case class InvalidStringLitException(message: String) extends ParserException(me case class InvalidEscapeCharException(message: String) extends ParserException(message) case class SyntaxErrorsException(message: String) extends ParserException(message) - object Parser extends LazyLogging { /** Parses a file in a given filename and returns a parsed [[firrtl.ir.Circuit Circuit]] */ @@ -57,13 +56,13 @@ object Parser extends LazyLogging { ast } + /** Takes Iterator over lines of FIRRTL, returns FirrtlNode (root node is Circuit) */ def parse(lines: Iterator[String], infoMode: InfoMode = UseInfo): Circuit = parseString(lines.mkString("\n"), infoMode) def parse(lines: Seq[String]): Circuit = parseString(lines.mkString("\n"), UseInfo) - /** Parse the concrete syntax of a FIRRTL [[firrtl.ir.Circuit]], e.g. * {{{ * """circuit Top: @@ -106,7 +105,7 @@ object Parser extends LazyLogging { def parse(lines: Seq[String], infoMode: InfoMode): Circuit = parse(lines.iterator, infoMode) - def parse(text: String, infoMode: InfoMode): Circuit = parse(text split "\n", infoMode) + def parse(text: String, infoMode: InfoMode): Circuit = parse(text.split("\n"), infoMode) /** Parse the concrete syntax of a FIRRTL [[firrtl.ir.Expression]], e.g. * "add(x, y)" becomes: diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 883692c8..baa8638a 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -15,14 +15,14 @@ object PrimOps extends LazyLogging { def w1(e: DoPrim): Width = getWidth(t1(e)) def w2(e: DoPrim): Width = getWidth(t2(e)) def p1(e: DoPrim): Width = t1(e) match { - case FixedType(w, p) => p + case FixedType(w, p) => p case IntervalType(min, max, p) => p - case _ => sys.error(s"Cannot get binary point from ${t1(e)}") + case _ => sys.error(s"Cannot get binary point from ${t1(e)}") } def p2(e: DoPrim): Width = t2(e) match { - case FixedType(w, p) => p + case FixedType(w, p) => p case IntervalType(min, max, p) => p - case _ => sys.error(s"Cannot get binary point from ${t1(e)}") + case _ => sys.error(s"Cannot get binary point from ${t1(e)}") } def c1(e: DoPrim) = IntWidth(e.consts.head) def c2(e: DoPrim) = IntWidth(e.consts(1)) @@ -37,8 +37,16 @@ object PrimOps extends LazyLogging { (t1(e), t2(e)) match { case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) - case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))), IntWidth(1)), IsMax(p1(e), p2(e))) - case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, l2), IsAdd(u1, u2), IsMax(p1, p2)) + case (_: FixedType, _: FixedType) => + FixedType( + IsAdd( + IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))), + IntWidth(1) + ), + IsMax(p1(e), p2(e)) + ) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + IntervalType(IsAdd(l1, l2), IsAdd(u1, u2), IsMax(p1, p2)) case _ => UnknownType } } @@ -49,8 +57,13 @@ object PrimOps extends LazyLogging { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1))) - case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)),IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))),IntWidth(1)), IsMax(p1(e), p2(e))) - case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, IsNeg(u2)), IsAdd(u1, IsNeg(l2)), IsMax(p1, p2)) + case (_: FixedType, _: FixedType) => + FixedType( + IsAdd(IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))), IntWidth(1)), + IsMax(p1(e), p2(e)) + ) + case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => + IntervalType(IsAdd(l1, IsNeg(u2)), IsAdd(u1, IsNeg(l2)), IsMax(p1, p2)) case _ => UnknownType } override def toString = "sub" @@ -70,7 +83,8 @@ object PrimOps extends LazyLogging { ) case _ => UnknownType } - override def toString = "mul" } + override def toString = "mul" + } /** Division */ case object Div extends PrimOp { @@ -79,7 +93,8 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: SIntType) => SIntType(IsAdd(w1(e), IntWidth(1))) case _ => UnknownType } - override def toString = "div" } + override def toString = "div" + } /** Remainder */ case object Rem extends PrimOp { @@ -88,7 +103,9 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: SIntType) => SIntType(MIN(w1(e), w2(e))) case _ => UnknownType } - override def toString = "rem" } + override def toString = "rem" + } + /** Less Than */ case object Lt extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -98,7 +115,9 @@ object PrimOps extends LazyLogging { case (_: IntervalType, _: IntervalType) => Utils.BoolType case _ => UnknownType } - override def toString = "lt" } + override def toString = "lt" + } + /** Less Than Or Equal To */ case object Leq extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -108,7 +127,9 @@ object PrimOps extends LazyLogging { case (_: IntervalType, _: IntervalType) => Utils.BoolType case _ => UnknownType } - override def toString = "leq" } + override def toString = "leq" + } + /** Greater Than */ case object Gt extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -118,7 +139,9 @@ object PrimOps extends LazyLogging { case (_: IntervalType, _: IntervalType) => Utils.BoolType case _ => UnknownType } - override def toString = "gt" } + override def toString = "gt" + } + /** Greater Than Or Equal To */ case object Geq extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -128,7 +151,9 @@ object PrimOps extends LazyLogging { case (_: IntervalType, _: IntervalType) => Utils.BoolType case _ => UnknownType } - override def toString = "geq" } + override def toString = "geq" + } + /** Equal To */ case object Eq extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -138,7 +163,9 @@ object PrimOps extends LazyLogging { case (_: IntervalType, _: IntervalType) => Utils.BoolType case _ => UnknownType } - override def toString = "eq" } + override def toString = "eq" + } + /** Not Equal To */ case object Neq extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -148,31 +175,42 @@ object PrimOps extends LazyLogging { case (_: IntervalType, _: IntervalType) => Utils.BoolType case _ => UnknownType } - override def toString = "neq" } + override def toString = "neq" + } + /** Padding */ case object Pad extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => UIntType(IsMax(w1(e), c1(e))) - case _: SIntType => SIntType(IsMax(w1(e), c1(e))) + case _: UIntType => UIntType(IsMax(w1(e), c1(e))) + case _: SIntType => SIntType(IsMax(w1(e), c1(e))) case _: FixedType => FixedType(IsMax(w1(e), c1(e)), p1(e)) case _ => UnknownType } - override def toString = "pad" } + override def toString = "pad" + } + /** Static Shift Left */ case object Shl extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => UIntType(IsAdd(w1(e), c1(e))) - case _: SIntType => SIntType(IsAdd(w1(e), c1(e))) - case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), p1(e)) - case IntervalType(l, u, p) => IntervalType(IsMul(l, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), IsMul(u, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), p) + case _: UIntType => UIntType(IsAdd(w1(e), c1(e))) + case _: SIntType => SIntType(IsAdd(w1(e), c1(e))) + case _: FixedType => FixedType(IsAdd(w1(e), c1(e)), p1(e)) + case IntervalType(l, u, p) => + IntervalType( + IsMul(l, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), + IsMul(u, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), + p + ) case _ => UnknownType } - override def toString = "shl" } + override def toString = "shl" + } + /** Static Shift Right */ case object Shr extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => UIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1))) - case _: SIntType => SIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1))) + case _: UIntType => UIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1))) + case _: SIntType => SIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1))) case _: FixedType => FixedType(IsMax(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)), p1(e)), p1(e)) case IntervalType(l, u, IntWidth(p)) => val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt)) @@ -187,11 +225,12 @@ object PrimOps extends LazyLogging { } override def toString = "shr" } + /** Dynamic Shift Left */ case object Dshl extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => UIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1)))) - case _: SIntType => SIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1)))) + case _: UIntType => UIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1)))) + case _: SIntType => SIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1)))) case _: FixedType => FixedType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))), p1(e)) case IntervalType(l, u, p) => val maxShiftAmt = IsAdd(IsPow(w2(e)), Closed(-1)) @@ -206,18 +245,20 @@ object PrimOps extends LazyLogging { } override def toString = "dshl" } + /** Dynamic Shift Right */ case object Dshr extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => UIntType(w1(e)) - case _: SIntType => SIntType(w1(e)) + case _: UIntType => UIntType(w1(e)) + case _: SIntType => SIntType(w1(e)) case _: FixedType => FixedType(w1(e), p1(e)) // Decreasing magnitude -- don't need more bits case IntervalType(l, u, p) => IntervalType(l, u, p) - case _ => UnknownType + case _ => UnknownType } override def toString = "dshr" } + /** Arithmetic Convert to Signed */ case object Cvt extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -227,6 +268,7 @@ object PrimOps extends LazyLogging { } override def toString = "cvt" } + /** Negate */ case object Neg extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -236,6 +278,7 @@ object PrimOps extends LazyLogging { } override def toString = "neg" } + /** Bitwise Complement */ case object Not extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -245,6 +288,7 @@ object PrimOps extends LazyLogging { } override def toString = "not" } + /** Bitwise And */ case object And extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -253,6 +297,7 @@ object PrimOps extends LazyLogging { } override def toString = "and" } + /** Bitwise Or */ case object Or extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -261,6 +306,7 @@ object PrimOps extends LazyLogging { } override def toString = "or" } + /** Bitwise Exclusive Or */ case object Xor extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { @@ -269,6 +315,7 @@ object PrimOps extends LazyLogging { } override def toString = "xor" } + /** Bitwise And Reduce */ case object Andr extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -277,6 +324,7 @@ object PrimOps extends LazyLogging { } override def toString = "andr" } + /** Bitwise Or Reduce */ case object Orr extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -285,6 +333,7 @@ object PrimOps extends LazyLogging { } override def toString = "orr" } + /** Bitwise Exclusive Or Reduce */ case object Xorr extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -293,22 +342,30 @@ object PrimOps extends LazyLogging { } override def toString = "xorr" } + /** Concatenate */ case object Cat extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { - case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType, _: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(w1(e), w2(e))) + case ( + _: UIntType | _: SIntType | _: FixedType | _: IntervalType, + _: UIntType | _: SIntType | _: FixedType | _: IntervalType + ) => + UIntType(IsAdd(w1(e), w2(e))) case (t1, t2) => UnknownType } override def toString = "cat" } + /** Bit Extraction */ case object Bits extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(IsAdd(c1(e), IsNeg(c2(e))), IntWidth(1))) + case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => + UIntType(IsAdd(IsAdd(c1(e), IsNeg(c2(e))), IntWidth(1))) case _ => UnknownType } override def toString = "bits" } + /** Head */ case object Head extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -317,6 +374,7 @@ object PrimOps extends LazyLogging { } override def toString = "head" } + /** Tail */ case object Tail extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { @@ -325,20 +383,22 @@ object PrimOps extends LazyLogging { } override def toString = "tail" } - /** Increase Precision **/ + + /** Increase Precision * */ case object IncP extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), IsAdd(p1(e), c1(e))) + case _: FixedType => FixedType(IsAdd(w1(e), c1(e)), IsAdd(p1(e), c1(e))) // Keeps the same exact value, but adds more precision for the future i.e. aaa.bbb -> aaa.bbb00 case IntervalType(l, u, p) => IntervalType(l, u, IsAdd(p, c1(e))) - case _ => UnknownType + case _ => UnknownType } override def toString = "incp" } - /** Decrease Precision **/ + + /** Decrease Precision * */ case object DecP extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: FixedType => FixedType(IsAdd(w1(e),IsNeg(c1(e))), IsAdd(p1(e), IsNeg(c1(e)))) + case _: FixedType => FixedType(IsAdd(w1(e), IsNeg(c1(e))), IsAdd(p1(e), IsNeg(c1(e)))) case IntervalType(l, u, IntWidth(p)) => val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt)) // BP is inferred at this point @@ -355,7 +415,8 @@ object PrimOps extends LazyLogging { } override def toString = "decp" } - /** Set Precision **/ + + /** Set Precision * */ case object SetP extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { case _: FixedType => FixedType(IsAdd(c1(e), IsAdd(w1(e), IsNeg(p1(e)))), c1(e)) @@ -369,84 +430,98 @@ object PrimOps extends LazyLogging { } override def toString = "setp" } + /** Interpret As UInt */ case object AsUInt extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => UIntType(w1(e)) - case _: SIntType => UIntType(w1(e)) + case _: UIntType => UIntType(w1(e)) + case _: SIntType => UIntType(w1(e)) case _: FixedType => UIntType(w1(e)) - case ClockType => UIntType(IntWidth(1)) + case ClockType => UIntType(IntWidth(1)) case AsyncResetType => UIntType(IntWidth(1)) - case ResetType => UIntType(IntWidth(1)) - case AnalogType(w) => UIntType(w1(e)) + case ResetType => UIntType(IntWidth(1)) + case AnalogType(w) => UIntType(w1(e)) case _: IntervalType => UIntType(w1(e)) case _ => UnknownType } override def toString = "asUInt" } + /** Interpret As SInt */ case object AsSInt extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => SIntType(w1(e)) - case _: SIntType => SIntType(w1(e)) + case _: UIntType => SIntType(w1(e)) + case _: SIntType => SIntType(w1(e)) case _: FixedType => SIntType(w1(e)) - case ClockType => SIntType(IntWidth(1)) + case ClockType => SIntType(IntWidth(1)) case AsyncResetType => SIntType(IntWidth(1)) - case ResetType => SIntType(IntWidth(1)) - case _: AnalogType => SIntType(w1(e)) + case ResetType => SIntType(IntWidth(1)) + case _: AnalogType => SIntType(w1(e)) case _: IntervalType => SIntType(w1(e)) case _ => UnknownType } override def toString = "asSInt" } + /** Interpret As Clock */ case object AsClock extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { case _: UIntType => ClockType case _: SIntType => ClockType - case ClockType => ClockType + case ClockType => ClockType case AsyncResetType => ClockType - case ResetType => ClockType - case _: AnalogType => ClockType + case ResetType => ClockType + case _: AnalogType => ClockType case _: IntervalType => ClockType case _ => UnknownType } override def toString = "asClock" } + /** Interpret As AsyncReset */ case object AsAsyncReset extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType | _: IntervalType | _: FixedType => AsyncResetType + case _: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType | _: IntervalType | + _: FixedType => + AsyncResetType case _ => UnknownType } override def toString = "asAsyncReset" } - /** Interpret as Fixed Point **/ + + /** Interpret as Fixed Point * */ case object AsFixedPoint extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { - case _: UIntType => FixedType(w1(e), c1(e)) - case _: SIntType => FixedType(w1(e), c1(e)) + case _: UIntType => FixedType(w1(e), c1(e)) + case _: SIntType => FixedType(w1(e), c1(e)) case _: FixedType => FixedType(w1(e), c1(e)) case ClockType => FixedType(IntWidth(1), c1(e)) case _: AnalogType => FixedType(w1(e), c1(e)) case AsyncResetType => FixedType(IntWidth(1), c1(e)) - case ResetType => FixedType(IntWidth(1), c1(e)) + case ResetType => FixedType(IntWidth(1), c1(e)) case _: IntervalType => FixedType(w1(e), c1(e)) case _ => UnknownType } override def toString = "asFixedPoint" } - /** Interpret as Interval (closed lower bound, closed upper bound, binary point) **/ + + /** Interpret as Interval (closed lower bound, closed upper bound, binary point) * */ case object AsInterval extends PrimOp { override def propagateType(e: DoPrim): Type = t1(e) match { // Chisel shifts up and rounds first. - case _: UIntType | _: SIntType | _: FixedType | ClockType | AsyncResetType | ResetType | _: AnalogType | _: IntervalType => - IntervalType(Closed(BigDecimal(o1(e))/BigDecimal(BigInt(1) << o3(e).toInt)), Closed(BigDecimal(o2(e))/BigDecimal(BigInt(1) << o3(e).toInt)), IntWidth(o3(e))) + case _: UIntType | _: SIntType | _: FixedType | ClockType | AsyncResetType | ResetType | _: AnalogType | + _: IntervalType => + IntervalType( + Closed(BigDecimal(o1(e)) / BigDecimal(BigInt(1) << o3(e).toInt)), + Closed(BigDecimal(o2(e)) / BigDecimal(BigInt(1) << o3(e).toInt)), + IntWidth(o3(e)) + ) case _ => UnknownType } override def toString = "asInterval" } - /** Try to fit the first argument into the type of the smaller argument **/ + + /** Try to fit the first argument into the type of the smaller argument * */ case object Squeeze extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => @@ -457,15 +532,17 @@ object PrimOps extends LazyLogging { } override def toString = "squz" } - /** Wrap First Operand Around Range/Width of Second Operand **/ + + /** Wrap First Operand Around Range/Width of Second Operand * */ case object Wrap extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => IntervalType(l2, u2, p1) - case _ => UnknownType + case _ => UnknownType } override def toString = "wrap" } - /** Clip First Operand At Range/Width of Second Operand **/ + + /** Clip First Operand At Range/Width of Second Operand * */ case object Clip extends PrimOp { override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match { case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => @@ -485,19 +562,20 @@ object PrimOps extends LazyLogging { ) // format: on private lazy val strToPrimOp: Map[String, PrimOp] = { - builtinPrimOps.map { case op : PrimOp=> op.toString -> op }.toMap + builtinPrimOps.map { case op: PrimOp => op.toString -> op }.toMap } /** Seq of String representations of [[ir.PrimOp]]s */ - lazy val listing: Seq[String] = builtinPrimOps map (_.toString) + lazy val listing: Seq[String] = builtinPrimOps.map(_.toString) + /** Gets the corresponding [[ir.PrimOp]] from its String representation */ def fromString(op: String): PrimOp = strToPrimOp(op) // Width Constraint Functions - def PLUS(w1: Width, w2: Width): Constraint = IsAdd(w1, w2) - def MAX(w1: Width, w2: Width): Constraint = IsMax(w1, w2) + def PLUS(w1: Width, w2: Width): Constraint = IsAdd(w1, w2) + def MAX(w1: Width, w2: Width): Constraint = IsMax(w1, w2) def MINUS(w1: Width, w2: Width): Constraint = IsAdd(w1, IsNeg(w2)) - def MIN(w1: Width, w2: Width): Constraint = IsMin(w1, w2) + def MIN(w1: Width, w2: Width): Constraint = IsMin(w1, w2) def set_primop_type(e: DoPrim): DoPrim = DoPrim(e.op, e.args, e.consts, e.op.propagateType(e)) } diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala index 9c848bca..d85998b5 100644 --- a/src/main/scala/firrtl/RenameMap.scala +++ b/src/main/scala/firrtl/RenameMap.scala @@ -38,9 +38,9 @@ object RenameMap { */ // TODO This should probably be refactored into immutable and mutable versions final class RenameMap private ( - val underlying: mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]](), - val chained: Option[RenameMap] = None -) { + val underlying: mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] = + mutable.HashMap[CompleteTarget, Seq[CompleteTarget]](), + val chained: Option[RenameMap] = None) { /** Chain a [[RenameMap]] with this [[RenameMap]] * @param next the map to chain with this map @@ -100,7 +100,7 @@ final class RenameMap private ( * $noteDistinct */ def recordAll(map: collection.Map[CompleteTarget, Seq[CompleteTarget]]): Unit = - map.foreach{ + map.foreach { case (from: IsComponent, tos: Seq[_]) => completeRename(from, tos) case (from: IsModule, tos: Seq[_]) => completeRename(from, tos) case (from: CircuitTarget, tos: Seq[_]) => completeRename(from, tos) @@ -128,7 +128,7 @@ final class RenameMap private ( * @param key Target referencing the original circuit * @return Optionally return sequence of targets that key remaps to */ - def get(key: CircuitTarget): Option[Seq[CircuitTarget]] = completeGet(key).map( _.map { case x: CircuitTarget => x } ) + def get(key: CircuitTarget): Option[Seq[CircuitTarget]] = completeGet(key).map(_.map { case x: CircuitTarget => x }) /** Get renames of a [[firrtl.annotations.IsMember IsMember]] * @param key Target referencing the original member of the circuit @@ -136,12 +136,11 @@ final class RenameMap private ( */ def get(key: IsMember): Option[Seq[IsMember]] = completeGet(key).map { _.map { case x: IsMember => x } } - /** Create new [[RenameMap]] that merges this and renameMap * @param renameMap * @return */ - def ++ (renameMap: RenameMap): RenameMap = { + def ++(renameMap: RenameMap): RenameMap = { val newChained = if (chained.nonEmpty && renameMap.chained.nonEmpty) { Some(chained.get ++ renameMap.chained.get) } else { @@ -168,7 +167,7 @@ final class RenameMap private ( def getReverseRenameMap: RenameMap = { val reverseMap = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]]() - underlying.keysIterator.foreach{ key => + underlying.keysIterator.foreach { key => apply(key).foreach { v => reverseMap(v) = key +: reverseMap.getOrElse(v, Nil) } @@ -181,8 +180,9 @@ final class RenameMap private ( /** Serialize the underlying remapping of keys to new targets * @return */ - def serialize: String = underlying.map { case (k, v) => - k.serialize + "=>" + v.map(_.serialize).mkString(", ") + def serialize: String = underlying.map { + case (k, v) => + k.serialize + "=>" + v.map(_.serialize).mkString(", ") }.mkString("\n") /** Records which local InstanceTargets will require modification. @@ -229,7 +229,8 @@ final class RenameMap private ( val hereRet = (chainedRet.flatMap { target => hereCompleteGet(target).getOrElse(Seq(target)) }).distinct - if (hereRet.size == 1 && hereRet.head == key) { None } else { Some(hereRet) } + if (hereRet.size == 1 && hereRet.head == key) { None } + else { Some(hereRet) } } } else { hereCompleteGet(key) @@ -238,10 +239,11 @@ final class RenameMap private ( private def hereCompleteGet(key: CompleteTarget): Option[Seq[CompleteTarget]] = { val errors = mutable.ArrayBuffer[String]() - val ret = if(hasChanges) { + val ret = if (hasChanges) { val ret = recursiveGet(errors)(key) - if(errors.nonEmpty) { throw IllegalRenameException(errors.mkString("\n")) } - if(ret.size == 1 && ret.head == key) { None } else { Some(ret) } + if (errors.nonEmpty) { throw IllegalRenameException(errors.mkString("\n")) } + if (ret.size == 1 && ret.head == key) { None } + else { Some(ret) } } else { None } ret } @@ -266,50 +268,54 @@ final class RenameMap private ( * @return Renamed targets if a match is found, otherwise None */ private def referenceGet(errors: mutable.ArrayBuffer[String])(key: ReferenceTarget): Option[Seq[IsComponent]] = { - def traverseTokens(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseTokensCache.getOrElseUpdate(key, { - if (underlying.contains(key)) { - Some(underlying(key).flatMap { - case comp: IsComponent => Some(comp) - case other => - errors += s"reference ${key.targetParent} cannot be renamed to a non-component ${other}" - None - }) - } else { - key match { - case t: ReferenceTarget if t.component.nonEmpty => - val last = t.component.last - val parent = t.copy(component = t.component.dropRight(1)) - traverseTokens(parent).map(_.flatMap { x => - (x, last) match { - case (t2: InstanceTarget, Field(f)) => Some(t2.ref(f)) - case (t2: ReferenceTarget, Field(f)) => Some(t2.field(f)) - case (t2: ReferenceTarget, Index(i)) => Some(t2.index(i)) - case other => - errors += s"Illegal rename: ${key.targetParent} cannot be renamed to ${other._1} - must rename $key directly" - None - } - }) - case t: ReferenceTarget => None + def traverseTokens(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseTokensCache.getOrElseUpdate( + key, { + if (underlying.contains(key)) { + Some(underlying(key).flatMap { + case comp: IsComponent => Some(comp) + case other => + errors += s"reference ${key.targetParent} cannot be renamed to a non-component ${other}" + None + }) + } else { + key match { + case t: ReferenceTarget if t.component.nonEmpty => + val last = t.component.last + val parent = t.copy(component = t.component.dropRight(1)) + traverseTokens(parent).map(_.flatMap { x => + (x, last) match { + case (t2: InstanceTarget, Field(f)) => Some(t2.ref(f)) + case (t2: ReferenceTarget, Field(f)) => Some(t2.field(f)) + case (t2: ReferenceTarget, Index(i)) => Some(t2.index(i)) + case other => + errors += s"Illegal rename: ${key.targetParent} cannot be renamed to ${other._1} - must rename $key directly" + None + } + }) + case t: ReferenceTarget => None + } } } - }) - - def traverseHierarchy(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseHierarchyCache.getOrElseUpdate(key, { - val tokenRenamed = traverseTokens(key) - if (tokenRenamed.nonEmpty) { - tokenRenamed - } else { - key match { - case t: ReferenceTarget if t.isLocal => None - case t: ReferenceTarget => - val encapsulatingInstance = t.path.head._1.value - val stripped = t.stripHierarchy(1) - traverseHierarchy(stripped).map(_.map { - _.addHierarchy(t.module, encapsulatingInstance) - }) + ) + + def traverseHierarchy(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseHierarchyCache.getOrElseUpdate( + key, { + val tokenRenamed = traverseTokens(key) + if (tokenRenamed.nonEmpty) { + tokenRenamed + } else { + key match { + case t: ReferenceTarget if t.isLocal => None + case t: ReferenceTarget => + val encapsulatingInstance = t.path.head._1.value + val stripped = t.stripHierarchy(1) + traverseHierarchy(stripped).map(_.map { + _.addHierarchy(t.module, encapsulatingInstance) + }) + } } } - }) + ) traverseHierarchy(key) } @@ -335,64 +341,73 @@ final class RenameMap private ( * @return Renamed targets if a match is found, otherwise None */ private def instanceGet(errors: mutable.ArrayBuffer[String])(key: InstanceTarget): Option[Seq[IsModule]] = { - def traverseLeft(key: InstanceTarget): Option[Seq[IsModule]] = traverseLeftCache.getOrElseUpdate(key, { - val getOpt = underlying.get(key) - - if (getOpt.nonEmpty) { - getOpt.map(_.flatMap { - case isMod: IsModule => Some(isMod) - case other => - errors += s"IsModule: $key cannot be renamed to non-IsModule $other" - None - }) - } else { - key match { - case t: InstanceTarget if t.isLocal => None - case t: InstanceTarget => - val (Instance(outerInst), OfModule(outerMod)) = t.path.head - val stripped = t.copy(path = t.path.tail, module = outerMod) - traverseLeft(stripped).map(_.map { - case absolute if absolute.circuit == absolute.module => absolute - case relative => relative.addHierarchy(t.module, outerInst) - }) + def traverseLeft(key: InstanceTarget): Option[Seq[IsModule]] = traverseLeftCache.getOrElseUpdate( + key, { + val getOpt = underlying.get(key) + + if (getOpt.nonEmpty) { + getOpt.map(_.flatMap { + case isMod: IsModule => Some(isMod) + case other => + errors += s"IsModule: $key cannot be renamed to non-IsModule $other" + None + }) + } else { + key match { + case t: InstanceTarget if t.isLocal => None + case t: InstanceTarget => + val (Instance(outerInst), OfModule(outerMod)) = t.path.head + val stripped = t.copy(path = t.path.tail, module = outerMod) + traverseLeft(stripped).map(_.map { + case absolute if absolute.circuit == absolute.module => absolute + case relative => relative.addHierarchy(t.module, outerInst) + }) + } } } - }) - - def traverseRight(key: InstanceTarget): Option[Seq[IsModule]] = traverseRightCache.getOrElseUpdate(key, { - val findLeft = traverseLeft(key) - if (findLeft.isDefined) { - findLeft - } else { - key match { - case t: InstanceTarget if t.isLocal => None - case t: InstanceTarget => - val (Instance(i), OfModule(m)) = t.path.last - val parent = t.copy(path = t.path.dropRight(1), instance = i, ofModule = m) - traverseRight(parent).map(_.map(_.instOf(t.instance, t.ofModule))) + ) + + def traverseRight(key: InstanceTarget): Option[Seq[IsModule]] = traverseRightCache.getOrElseUpdate( + key, { + val findLeft = traverseLeft(key) + if (findLeft.isDefined) { + findLeft + } else { + key match { + case t: InstanceTarget if t.isLocal => None + case t: InstanceTarget => + val (Instance(i), OfModule(m)) = t.path.last + val parent = t.copy(path = t.path.dropRight(1), instance = i, ofModule = m) + traverseRight(parent).map(_.map(_.instOf(t.instance, t.ofModule))) + } } } - }) + ) traverseRight(key) } private def circuitGet(errors: mutable.ArrayBuffer[String])(key: CircuitTarget): Seq[CircuitTarget] = { - underlying.get(key).map(_.flatMap { - case c: CircuitTarget => Some(c) - case other => - errors += s"Illegal rename: $key cannot be renamed to non-circuit target: $other" - None - }).getOrElse(Seq(key)) + underlying + .get(key) + .map(_.flatMap { + case c: CircuitTarget => Some(c) + case other => + errors += s"Illegal rename: $key cannot be renamed to non-circuit target: $other" + None + }) + .getOrElse(Seq(key)) } private def moduleGet(errors: mutable.ArrayBuffer[String])(key: ModuleTarget): Option[Seq[IsModule]] = { - underlying.get(key).map(_.flatMap { - case mod: IsModule => Some(mod) - case other => - errors += s"Illegal rename: $key cannot be renamed to non-module target: $other" - None - }) + underlying + .get(key) + .map(_.flatMap { + case mod: IsModule => Some(mod) + case other => + errors += s"Illegal rename: $key cannot be renamed to non-module target: $other" + None + }) } // the possible results returned by ofModuleGet @@ -438,10 +453,11 @@ final class RenameMap private ( private def ofModuleGet(errors: mutable.ArrayBuffer[String])(key: IsComponent): OfModuleRenameResult = { val circuit = key.circuit def renameOfModules( - path: Seq[(Instance, OfModule)], - foundRename: Boolean, + path: Seq[(Instance, OfModule)], + foundRename: Boolean, newCircuitOpt: Option[String], - children: Seq[(Instance, OfModule)]): OfModuleRenameResult = { + children: Seq[(Instance, OfModule)] + ): OfModuleRenameResult = { if (path.isEmpty && foundRename) { RenamedOfModules(children) } else if (path.isEmpty) { @@ -489,15 +505,15 @@ final class RenameMap private ( * @return Renamed targets */ private def recursiveGet(errors: mutable.ArrayBuffer[String])(key: CompleteTarget): Seq[CompleteTarget] = { - if(getCache.contains(key)) { + if (getCache.contains(key)) { getCache(key) } else { // rename just the component portion; path/ref/component for ReferenceTargets or path/instance for InstanceTargets val componentRename = key match { - case t: CircuitTarget => None - case t: ModuleTarget => None - case t: InstanceTarget => instanceGet(errors)(t) + case t: CircuitTarget => None + case t: ModuleTarget => None + case t: InstanceTarget => instanceGet(errors)(t) case ref: ReferenceTarget if ref.isLocal => referenceGet(errors)(ref) case ref @ ReferenceTarget(c, m, p, r, t) => val (Instance(inst), OfModule(ofMod)) = p.last @@ -510,7 +526,6 @@ final class RenameMap private ( } } - // if no component rename was found, look for Module renames; root module/OfModules in path val moduleRename = if (componentRename.isDefined) { componentRename @@ -522,7 +537,8 @@ final class RenameMap private ( ofModuleGet(errors)(t) match { case AbsoluteOfModule(absolute) => t match { - case ref: ReferenceTarget => Some(Seq(ref.copy(circuit = absolute.circuit, module = absolute.module, path = absolute.asPath))) + case ref: ReferenceTarget => + Some(Seq(ref.copy(circuit = absolute.circuit, module = absolute.module, path = absolute.asPath))) case inst: InstanceTarget => Some(Seq(absolute)) } case RenamedOfModules(children) => @@ -532,14 +548,16 @@ final class RenameMap private ( val newPath = mod.asPath ++ children t match { - case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath) + case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath) case inst: InstanceTarget => val (Instance(newInst), OfModule(newOfMod)) = newPath.last - inst.copy(circuit = mod.circuit, + inst.copy( + circuit = mod.circuit, module = mod.module, path = newPath.dropRight(1), instance = newInst, - ofModule = newOfMod) + ofModule = newOfMod + ) } } Some(result) @@ -551,14 +569,16 @@ final class RenameMap private ( val newPath = mod.asPath ++ children t match { - case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath) + case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath) case inst: InstanceTarget => val (Instance(newInst), OfModule(newOfMod)) = newPath.last - inst.copy(circuit = mod.circuit, + inst.copy( + circuit = mod.circuit, module = mod.module, path = newPath.dropRight(1), instance = newInst, - ofModule = newOfMod) + ofModule = newOfMod + ) } }) } @@ -579,8 +599,8 @@ final class RenameMap private ( circuitGet(errors)(CircuitTarget(t.circuit)).map { case CircuitTarget(c) => t match { - case ref: ReferenceTarget => ref.copy(circuit = c) - case inst: InstanceTarget => inst.copy(circuit = c) + case ref: ReferenceTarget => ref.copy(circuit = c) + case inst: InstanceTarget => inst.copy(circuit = c) } } } @@ -597,7 +617,7 @@ final class RenameMap private ( * @param tos */ private def completeRename(from: CompleteTarget, tos: Seq[CompleteTarget]): Unit = { - tos.foreach{recordSensitivity(from, _)} + tos.foreach { recordSensitivity(from, _) } val existing = underlying.getOrElse(from, Vector.empty) val updated = (existing ++ tos).distinct underlying(from) = updated @@ -625,29 +645,30 @@ final class RenameMap private ( def delete(name: ComponentName): Unit = underlying(name) = Seq.empty def addMap(map: collection.Map[Named, Seq[Named]]): Unit = - recordAll(map.map { case (key, values) => (Target.convertNamed2Target(key), values.map(Target.convertNamed2Target)) }) + recordAll(map.map { + case (key, values) => (Target.convertNamed2Target(key), values.map(Target.convertNamed2Target)) + }) def get(key: CircuitName): Option[Seq[CircuitName]] = { - get(Target.convertCircuitName2CircuitTarget(key)).map(_.collect{ case c: CircuitTarget => c.toNamed }) + get(Target.convertCircuitName2CircuitTarget(key)).map(_.collect { case c: CircuitTarget => c.toNamed }) } def get(key: ModuleName): Option[Seq[ModuleName]] = { - get(Target.convertModuleName2ModuleTarget(key)).map(_.collect{ case m: ModuleTarget => m.toNamed }) + get(Target.convertModuleName2ModuleTarget(key)).map(_.collect { case m: ModuleTarget => m.toNamed }) } def get(key: ComponentName): Option[Seq[ComponentName]] = { - get(Target.convertComponentName2ReferenceTarget(key)).map(_.collect{ case c: IsComponent => c.toNamed }) + get(Target.convertComponentName2ReferenceTarget(key)).map(_.collect { case c: IsComponent => c.toNamed }) } def get(key: Named): Option[Seq[Named]] = key match { case t: CompleteTarget => get(t) - case other => get(key.toTarget).map(_.collect{ case c: IsComponent => c.toNamed }) + case other => get(key.toTarget).map(_.collect { case c: IsComponent => c.toNamed }) } - // Mutable helpers - APIs that set these are deprecated! private var circuitName: String = "" - private var moduleName: String = "" + private var moduleName: String = "" /** Sets mutable state to record current module we are visiting * @param module @@ -673,7 +694,7 @@ final class RenameMap private ( def rename(from: String, tos: Seq[String]): Unit = { val mn = ModuleName(moduleName, CircuitName(circuitName)) val fromName = ComponentName(from, mn).toTarget - val tosName = tos map { to => ComponentName(to, mn).toTarget } + val tosName = tos.map { to => ComponentName(to, mn).toTarget } record(fromName, tosName) } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index e9af3365..bc285ef3 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -21,24 +21,22 @@ object seqCat { case 1 => args.head case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth)) case _ => - val (high, low) = args splitAt (args.length / 2) + val (high, low) = args.splitAt(args.length / 2) DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth)) } } /** Given an expression, return an expression consisting of all sub-expressions - * concatenated (or flattened). - */ + * concatenated (or flattened). + */ object toBits { def apply(e: Expression): Expression = e match { case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex) case t => Utils.error(s"Invalid operand expression for toBits: $e") } private def hiercat(e: Expression): Expression = e.tpe match { - case t: VectorType => seqCat((0 until t.size).reverse map (i => - hiercat(WSubIndex(e, i, t.tpe, UnknownFlow)))) - case t: BundleType => seqCat(t.fields map (f => - hiercat(WSubField(e, f.name, f.tpe, UnknownFlow)))) + case t: VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UnknownFlow)))) + case t: BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UnknownFlow)))) case t: GroundType => DoPrim(AsUInt, Seq(e), Seq.empty, UnknownType) case t => Utils.error(s"Unknown type encountered in toBits: $e") } @@ -53,12 +51,12 @@ object getWidth { } object bitWidth { - def apply(dt: Type): BigInt = widthOf(dt) + def apply(dt: Type): BigInt = widthOf(dt) private def widthOf(dt: Type): BigInt = dt match { case t: VectorType => t.size * bitWidth(t.tpe) - case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) + case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_ + _) case GroundType(IntWidth(width)) => width - case t => Utils.error(s"Unknown type encountered in bitWidth: $dt") + case t => Utils.error(s"Unknown type encountered in bitWidth: $dt") } } @@ -88,32 +86,28 @@ object fromBits { } Block(fbits._2) } - private def getPartGround(lhs: Expression, - lhst: Type, - rhs: Expression, - offset: BigInt): (BigInt, Seq[Statement]) = { + private def getPartGround(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = { val intWidth = bitWidth(lhst) val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType) val rhsConnect = castRhs(lhst, sel) (offset + intWidth, Seq(Connect(NoInfo, lhs, rhsConnect))) } - private def getPart(lhs: Expression, - lhst: Type, - rhs: Expression, - offset: BigInt): (BigInt, Seq[Statement]) = + private def getPart(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = lhst match { - case t: VectorType => (0 until t.size foldLeft( (offset, Seq[Statement]()) )) { - case ((curOffset, stmts), i) => - val subidx = WSubIndex(lhs, i, t.tpe, UnknownFlow) - val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) - (tmpOffset, stmts ++ substmts) - } - case t: BundleType => (t.fields foldRight( (offset, Seq[Statement]()) )) { - case (f, (curOffset, stmts)) => - val subfield = WSubField(lhs, f.name, f.tpe, UnknownFlow) - val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) - (tmpOffset, stmts ++ substmts) - } + case t: VectorType => + ((0 until t.size).foldLeft((offset, Seq[Statement]()))) { + case ((curOffset, stmts), i) => + val subidx = WSubIndex(lhs, i, t.tpe, UnknownFlow) + val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) + } + case t: BundleType => + (t.fields.foldRight((offset, Seq[Statement]()))) { + case (f, (curOffset, stmts)) => + val subfield = WSubField(lhs, f.name, f.tpe, UnknownFlow) + val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset) + (tmpOffset, stmts ++ substmts) + } case t: GroundType => getPartGround(lhs, t, rhs, offset) case t => Utils.error(s"Unknown type encountered in fromBits: $lhst") } @@ -129,6 +123,7 @@ object flattenType { } object Utils extends LazyLogging { + /** Unwind the causal chain until we hit the initial exception (which may be the first). * * @param maybeException - possible exception triggering the error, @@ -157,13 +152,16 @@ object Utils extends LazyLogging { * * @param message - possible string to emit, * @param exception - possible exception triggering the error. - */ + */ def throwInternalError(message: String = "", exception: Option[Exception] = None) = { // We'll get the first exception in the chain, keeping it intact. val first = true val throwable = getThrowable(exception, true) val string = if (message.nonEmpty) message + "\n" else message - error("Internal Error! %sPlease file an issue at https://github.com/ucb-bar/firrtl/issues".format(string), throwable) + error( + "Internal Error! %sPlease file an issue at https://github.com/ucb-bar/firrtl/issues".format(string), + throwable + ) } def time[R](block: => R): (Double, R) = { @@ -177,9 +175,9 @@ object Utils extends LazyLogging { /** Removes all [[firrtl.ir.EmptyStmt]] statements and condenses * [[firrtl.ir.Block]] statements. */ - def squashEmpty(s: Statement): Statement = s map squashEmpty match { + def squashEmpty(s: Statement): Statement = s.map(squashEmpty) match { case Block(stmts) => - val newStmts = stmts filter (_ != EmptyStmt) + val newStmts = stmts.filter(_ != EmptyStmt) newStmts.size match { case 0 => EmptyStmt case 1 => newStmts.head @@ -191,43 +189,46 @@ object Utils extends LazyLogging { /** Returns true if PrimOp is a cast, false otherwise */ def isCast(op: PrimOp): Boolean = op match { case AsUInt | AsSInt | AsClock | AsAsyncReset | AsFixedPoint => true - case _ => false + case _ => false } + /** Returns true if Expression is a casting PrimOp, false otherwise */ def isCast(expr: Expression): Boolean = expr match { - case DoPrim(op, _,_,_) if isCast(op) => true - case _ => false + case DoPrim(op, _, _, _) if isCast(op) => true + case _ => false } /** Returns true if PrimOp is a BitExtraction, false otherwise */ def isBitExtract(op: PrimOp): Boolean = op match { case Bits | Head | Tail | Shr => true - case _ => false + case _ => false } + /** Returns true if Expression is a Bits PrimOp, false otherwise */ def isBitExtract(expr: Expression): Boolean = expr match { - case DoPrim(op, _,_, UIntType(_)) if isBitExtract(op) => true - case _ => false + case DoPrim(op, _, _, UIntType(_)) if isBitExtract(op) => true + case _ => false } - /** Provide a nice name to create a temporary **/ + /** Provide a nice name to create a temporary * */ def niceName(e: Expression): String = niceName(1)(e) def niceName(depth: Int)(e: Expression): String = { e match { case Reference(name, _, _, _) if name(0) == '_' => name - case Reference(name, _, _, _) => "_" + name + case Reference(name, _, _, _) => "_" + name case SubAccess(expr, index, _, _) if depth <= 0 => niceName(depth)(expr) - case SubAccess(expr, index, _, _) => niceName(depth)(expr) + niceName(depth - 1)(index) - case SubField(expr, field, _, _) => niceName(depth)(expr) + "_" + field - case SubIndex(expr, index, _, _) => niceName(depth)(expr) + "_" + index - case DoPrim(op, args, consts, _) if depth <= 0 => "_" + op - case DoPrim(op, args, consts, _) => "_" + op + (args.map(niceName(depth - 1)) ++ consts.map("_" + _)).mkString("") - case Mux(cond, tval, fval, _) if depth <= 0 => "_mux" - case Mux(cond, tval, fval, _) => "_mux" + Seq(cond, tval, fval).map(niceName(depth - 1)).mkString("") - case UIntLiteral(value, _) => "_" + value - case SIntLiteral(value, _) => "_" + value + case SubAccess(expr, index, _, _) => niceName(depth)(expr) + niceName(depth - 1)(index) + case SubField(expr, field, _, _) => niceName(depth)(expr) + "_" + field + case SubIndex(expr, index, _, _) => niceName(depth)(expr) + "_" + index + case DoPrim(op, args, consts, _) if depth <= 0 => "_" + op + case DoPrim(op, args, consts, _) => "_" + op + (args.map(niceName(depth - 1)) ++ consts.map("_" + _)).mkString("") + case Mux(cond, tval, fval, _) if depth <= 0 => "_mux" + case Mux(cond, tval, fval, _) => "_mux" + Seq(cond, tval, fval).map(niceName(depth - 1)).mkString("") + case UIntLiteral(value, _) => "_" + value + case SIntLiteral(value, _) => "_" + value } } + /** Maps node name to value */ type NodeMap = mutable.HashMap[String, Expression] @@ -235,18 +236,18 @@ object Utils extends LazyLogging { /** Indent the results of [[ir.FirrtlNode.serialize]] */ @deprecated("Use ther new firrt.ir.Serializer instead.", "FIRRTL 1.4") - def indent(str: String) = str replaceAllLiterally ("\n", "\n ") - - implicit def toWrappedExpression (x:Expression): WrappedExpression = new WrappedExpression(x) - def getSIntWidth(s: BigInt): Int = s.bitLength + 1 - def getUIntWidth(u: BigInt): Int = u.bitLength - def dec2string(v: BigDecimal): String = v.underlying().stripTrailingZeros().toPlainString - def trim(v: BigDecimal): BigDecimal = BigDecimal(dec2string(v)) - def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b - def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a - def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 + def indent(str: String) = str.replaceAllLiterally("\n", "\n ") + + implicit def toWrappedExpression(x: Expression): WrappedExpression = new WrappedExpression(x) + def getSIntWidth(s: BigInt): Int = s.bitLength + 1 + def getUIntWidth(u: BigInt): Int = u.bitLength + def dec2string(v: BigDecimal): String = v.underlying().stripTrailingZeros().toPlainString + def trim(v: BigDecimal): BigDecimal = BigDecimal(dec2string(v)) + def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b + def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a + def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 val BoolType = UIntType(IntWidth(1)) - val one = UIntLiteral(1) + val one = UIntLiteral(1) val zero = UIntLiteral(0) def create_exps(n: String, t: Type): Seq[Expression] = @@ -255,16 +256,18 @@ object Utils extends LazyLogging { case ex: Mux => val e1s = create_exps(ex.tval) val e2s = create_exps(ex.fval) - e1s zip e2s map {case (e1, e2) => - Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2)) + e1s.zip(e2s).map { + case (e1, e2) => + Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2)) + } + case ex: ValidIf => create_exps(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe)) + case ex => + ex.tpe match { + case (_: GroundType) => Seq(ex) + case t: BundleType => + t.fields.flatMap(f => create_exps(WSubField(ex, f.name, f.tpe, times(flow(ex), f.flip)))) + case t: VectorType => (0 until t.size).flatMap(i => create_exps(WSubIndex(ex, i, t.tpe, flow(ex)))) } - case ex: ValidIf => create_exps(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe)) - case ex => ex.tpe match { - case (_: GroundType) => Seq(ex) - case t: BundleType => - t.fields.flatMap(f => create_exps(WSubField(ex, f.name, f.tpe,times(flow(ex), f.flip)))) - case t: VectorType => (0 until t.size).flatMap(i => create_exps(WSubIndex(ex, i, t.tpe,flow(ex)))) - } } /** Like create_exps, but returns intermediate Expressions as well @@ -275,26 +278,28 @@ object Utils extends LazyLogging { case ex: Mux => val e1s = expandRef(ex.tval) val e2s = expandRef(ex.fval) - e1s zip e2s map {case (e1, e2) => - Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2)) + e1s.zip(e2s).map { + case (e1, e2) => + Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2)) + } + case ex: ValidIf => expandRef(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe)) + case ex => + ex.tpe match { + case (_: GroundType) => Seq(ex) + case (t: BundleType) => + ex +: t.fields.flatMap(f => expandRef(WSubField(ex, f.name, f.tpe, times(flow(ex), f.flip)))) + case (t: VectorType) => + ex +: (0 until t.size).flatMap(i => expandRef(WSubIndex(ex, i, t.tpe, flow(ex)))) } - case ex: ValidIf => expandRef(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe)) - case ex => ex.tpe match { - case (_: GroundType) => Seq(ex) - case (t: BundleType) => - ex +: t.fields.flatMap(f => expandRef(WSubField(ex, f.name, f.tpe, times(flow(ex), f.flip)))) - case (t: VectorType) => - ex +: (0 until t.size).flatMap(i => expandRef(WSubIndex(ex, i, t.tpe, flow(ex)))) - } } def toTarget(main: String, module: String)(expression: Expression): ReferenceTarget = { val tokens = mutable.ArrayBuffer[TargetToken]() var ref = "???" def onExp(expr: Expression): Expression = { - expr map onExp match { + expr.map(onExp) match { case e: Reference => ref = e.name - case e: SubField => tokens += TargetToken.Field(e.name) - case e: SubIndex => tokens += TargetToken.Index(e.value) + case e: SubField => tokens += TargetToken.Field(e.name) + case e: SubIndex => tokens += TargetToken.Index(e.value) case other => throwInternalError("Cannot call Utils.toTarget on non-referencing expression") } expr @@ -302,39 +307,42 @@ object Utils extends LazyLogging { onExp(expression) ReferenceTarget(main, module, Nil, ref, tokens.toSeq) } - @deprecated("get_flip is fundamentally slow, use to_flip(flow(expr))", "1.2") - def get_flip(t: Type, i: Int, f: Orientation): Orientation = { - if (i >= get_size(t)) throwInternalError(s"get_flip: shouldn't be here - $i >= get_size($t)") - t match { - case (_: GroundType) => f - case (tx: BundleType) => - val (_, flip) = tx.fields.foldLeft( (i, None: Option[Orientation]) ) { - case ((n, ret), x) if n < get_size(x.tpe) => ret match { - case None => (n, Some(get_flip(x.tpe, n, times(x.flip, f)))) - case Some(_) => (n, ret) - } - case ((n, ret), x) => (n - get_size(x.tpe), ret) - } - flip.get - case (tx: VectorType) => - val (_, flip) = (0 until tx.size).foldLeft( (i, None: Option[Orientation]) ) { - case ((n, ret), x) if n < get_size(tx.tpe) => ret match { - case None => (n, Some(get_flip(tx.tpe, n, f))) - case Some(_) => (n, ret) - } - case ((n, ret), x) => (n - get_size(tx.tpe), ret) - } - flip.get - } - } - - def get_point (e:Expression) : Int = e match { - case (e: WRef) => 0 - case (e: WSubField) => e.expr.tpe match {case b: BundleType => - (b.fields takeWhile (_.name != e.name) foldLeft 0)( - (point, f) => point + get_size(f.tpe)) + @deprecated("get_flip is fundamentally slow, use to_flip(flow(expr))", "1.2") + def get_flip(t: Type, i: Int, f: Orientation): Orientation = { + if (i >= get_size(t)) throwInternalError(s"get_flip: shouldn't be here - $i >= get_size($t)") + t match { + case (_: GroundType) => f + case (tx: BundleType) => + val (_, flip) = tx.fields.foldLeft((i, None: Option[Orientation])) { + case ((n, ret), x) if n < get_size(x.tpe) => + ret match { + case None => (n, Some(get_flip(x.tpe, n, times(x.flip, f)))) + case Some(_) => (n, ret) + } + case ((n, ret), x) => (n - get_size(x.tpe), ret) + } + flip.get + case (tx: VectorType) => + val (_, flip) = (0 until tx.size).foldLeft((i, None: Option[Orientation])) { + case ((n, ret), x) if n < get_size(tx.tpe) => + ret match { + case None => (n, Some(get_flip(tx.tpe, n, f))) + case Some(_) => (n, ret) + } + case ((n, ret), x) => (n - get_size(tx.tpe), ret) + } + flip.get } - case (e: WSubIndex) => e.value * get_size(e.tpe) + } + + def get_point(e: Expression): Int = e match { + case (e: WRef) => 0 + case (e: WSubField) => + e.expr.tpe match { + case b: BundleType => + (b.fields.takeWhile(_.name != e.name).foldLeft(0))((point, f) => point + get_size(f.tpe)) + } + case (e: WSubIndex) => e.value * get_size(e.tpe) case (e: WSubAccess) => get_point(e.expr) } @@ -345,8 +353,8 @@ object Utils extends LazyLogging { */ def hasFlip(t: Type): Boolean = t match { case t: BundleType => - (t.fields exists (_.flip == Flip)) || - (t.fields exists (f => hasFlip(f.tpe))) + (t.fields.exists(_.flip == Flip)) || + (t.fields.exists(f => hasFlip(f.tpe))) case t: VectorType => hasFlip(t.tpe) case _ => false } @@ -358,17 +366,17 @@ object Utils extends LazyLogging { kids += e e } - e map addKids + e.map(addKids) kids.toSeq } /** Walks two expression trees and returns a sequence of tuples of where they differ */ def diff(e1: Expression, e2: Expression): Seq[(Expression, Expression)] = { - if(weq(e1, e2)) Nil + if (weq(e1, e2)) Nil else { val (e1Kids, e2Kids) = (getKids(e1), getKids(e2)) - if(e1Kids == Nil || e2Kids == Nil || e1Kids.size != e2Kids.size) Seq((e1, e2)) + if (e1Kids == Nil || e2Kids == Nil || e1Kids.size != e2Kids.size) Seq((e1, e2)) else { e1Kids.zip(e2Kids).flatMap { case (e1k, e2k) => diff(e1k, e2k) } } @@ -378,65 +386,67 @@ object Utils extends LazyLogging { /** Returns an inlined expression (replacing node references with values), * stopping on a stopping condition or until the reference is not a node */ - def inline(nodeMap: NodeMap, stop: String => Boolean = {x: String => false})(e: Expression): Expression = { - def onExp(e: Expression): Expression = e map onExp match { + def inline(nodeMap: NodeMap, stop: String => Boolean = { x: String => false })(e: Expression): Expression = { + def onExp(e: Expression): Expression = e.map(onExp) match { case Reference(name, _, _, _) if nodeMap.contains(name) && !stop(name) => onExp(nodeMap(name)) - case other => other + case other => other } onExp(e) } def mux_type(e1: Expression, e2: Expression): Type = mux_type(e1.tpe, e2.tpe) - def mux_type(t1: Type, t2: Type): Type = (t1, t2) match { - case (ClockType, ClockType) => ClockType + def mux_type(t1: Type, t2: Type): Type = (t1, t2) match { + case (ClockType, ClockType) => ClockType case (AsyncResetType, AsyncResetType) => AsyncResetType case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth) case (t1: SIntType, t2: SIntType) => SIntType(UnknownWidth) case (t1: FixedType, t2: FixedType) => FixedType(UnknownWidth, UnknownWidth) case (t1: IntervalType, t2: IntervalType) => IntervalType(UnknownBound, UnknownBound, UnknownWidth) case (t1: VectorType, t2: VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size) - case (t1: BundleType, t2: BundleType) => BundleType(t1.fields zip t2.fields map { - case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe)) - }) + case (t1: BundleType, t2: BundleType) => + BundleType(t1.fields.zip(t2.fields).map { + case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe)) + }) case _ => UnknownType } - def mux_type_and_widths(e1: Expression,e2: Expression): Type = + def mux_type_and_widths(e1: Expression, e2: Expression): Type = mux_type_and_widths(e1.tpe, e2.tpe) def mux_type_and_widths(t1: Type, t2: Type): Type = { def wmax(w1: Width, w2: Width): Width = (w1, w2) match { - case (w1x: IntWidth, w2x: IntWidth) => IntWidth(w1x.width max w2x.width) + case (w1x: IntWidth, w2x: IntWidth) => IntWidth(w1x.width.max(w2x.width)) case (w1x, w2x) => IsMax(w1x, w2x) } (t1, t2) match { - case (ClockType, ClockType) => ClockType + case (ClockType, ClockType) => ClockType case (AsyncResetType, AsyncResetType) => AsyncResetType case (t1x: UIntType, t2x: UIntType) => UIntType(IsMax(t1x.width, t2x.width)) case (t1x: SIntType, t2x: SIntType) => SIntType(IsMax(t1x.width, t2x.width)) case (FixedType(w1, p1), FixedType(w2, p2)) => - FixedType(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2)) + FixedType(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2)) case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsMin(l1, l2), constraint.IsMax(u1, u2), MAX(p1, p2)) - case (t1x: VectorType, t2x: VectorType) => VectorType( - mux_type_and_widths(t1x.tpe, t2x.tpe), t1x.size) - case (t1x: BundleType, t2x: BundleType) => BundleType(t1x.fields zip t2x.fields map { - case (f1, f2) => Field(f1.name, f1.flip, mux_type_and_widths(f1.tpe, f2.tpe)) - }) + case (t1x: VectorType, t2x: VectorType) => VectorType(mux_type_and_widths(t1x.tpe, t2x.tpe), t1x.size) + case (t1x: BundleType, t2x: BundleType) => + BundleType(t1x.fields.zip(t2x.fields).map { + case (f1, f2) => Field(f1.name, f1.flip, mux_type_and_widths(f1.tpe, f2.tpe)) + }) case _ => UnknownType } } - def module_type(m: DefModule): BundleType = BundleType(m.ports map { + def module_type(m: DefModule): BundleType = BundleType(m.ports.map { case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) }) def sub_type(v: Type): Type = v match { case vx: VectorType => vx.tpe case vx => UnknownType } - def field_type(v: Type, s: String) : Type = v match { - case vx: BundleType => vx.fields find (_.name == s) match { - case Some(f) => f.tpe - case None => UnknownType - } + def field_type(v: Type, s: String): Type = v match { + case vx: BundleType => + vx.fields.find(_.name == s) match { + case Some(f) => f.tpe + case None => UnknownType + } case vx => UnknownType } @@ -445,13 +455,12 @@ object Utils extends LazyLogging { //// =============== EXPANSION FUNCTIONS ================ def get_size(t: Type): Int = t match { - case tx: BundleType => (tx.fields foldLeft 0)( - (sum, f) => sum + get_size(f.tpe)) + case tx: BundleType => (tx.fields.foldLeft(0))((sum, f) => sum + get_size(f.tpe)) case tx: VectorType => tx.size * get_size(tx.tpe) case tx => 1 } - def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int,Int)] = { + def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int, Int)] = { import passes.CheckTypes.legalResetType //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) (t1, t2) match { @@ -461,27 +470,39 @@ object Utils extends LazyLogging { case (_: AnalogType, _: AnalogType) => if (flip1 == flip2) Seq((0, 0)) else Nil case (t1x: BundleType, t2x: BundleType) => def emptyMap = Map[String, (Type, Orientation, Int)]() - val t1_fields = t1x.fields.foldLeft( (emptyMap, 0) ) { case ((map, ilen), f1) => - (map + (f1.name ->( (f1.tpe, f1.flip, ilen) )), ilen + get_size(f1.tpe)) - }._1 - t2x.fields.foldLeft( (Seq[(Int, Int)](), 0) ) { case ((points, jlen), f2) => - t1_fields get f2.name match { - case None => (points, jlen + get_size(f2.tpe)) - case Some((f1_tpe, f1_flip, ilen)) => - val f1_times = times(flip1, f1_flip) - val f2_times = times(flip2, f2.flip) - val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times) - (points ++ (ls map { case (x, y) => (x + ilen, y + jlen) }), jlen + get_size(f2.tpe)) + val t1_fields = t1x.fields + .foldLeft((emptyMap, 0)) { + case ((map, ilen), f1) => + (map + (f1.name -> ((f1.tpe, f1.flip, ilen))), ilen + get_size(f1.tpe)) + } + ._1 + t2x.fields + .foldLeft((Seq[(Int, Int)](), 0)) { + case ((points, jlen), f2) => + t1_fields.get(f2.name) match { + case None => (points, jlen + get_size(f2.tpe)) + case Some((f1_tpe, f1_flip, ilen)) => + val f1_times = times(flip1, f1_flip) + val f2_times = times(flip2, f2.flip) + val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times) + (points ++ (ls.map { case (x, y) => (x + ilen, y + jlen) }), jlen + get_size(f2.tpe)) + } } - }._1 + ._1 case (t1x: VectorType, t2x: VectorType) => val size = math.min(t1x.size, t2x.size) - (0 until size).foldLeft( (Seq[(Int, Int)](), 0, 0) ) { case ((points, ilen, jlen), _) => - val ls = get_valid_points(t1x.tpe, t2x.tpe, flip1, flip2) - (points ++ (ls map { case (x, y) => (x + ilen, y + jlen) }), - ilen + get_size(t1x.tpe), jlen + get_size(t2x.tpe)) - }._1 - case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil + (0 until size) + .foldLeft((Seq[(Int, Int)](), 0, 0)) { + case ((points, ilen, jlen), _) => + val ls = get_valid_points(t1x.tpe, t2x.tpe, flip1, flip2) + ( + points ++ (ls.map { case (x, y) => (x + ilen, y + jlen) }), + ilen + get_size(t1x.tpe), + jlen + get_size(t2x.tpe) + ) + } + ._1 + case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil case (AsyncResetType, AsyncResetType) => if (flip1 == flip2) Seq((0, 0)) else Nil // The following two cases handle driving ResetType from other legal reset types // Flippedness is important here because ResetType can be driven by other reset types, but it @@ -495,112 +516,114 @@ object Utils extends LazyLogging { } // =========== FLOW/FLIP UTILS ============ - def swap(g: Flow) : Flow = g match { + def swap(g: Flow): Flow = g match { case UnknownFlow => UnknownFlow - case SourceFlow => SinkFlow - case SinkFlow => SourceFlow - case DuplexFlow => DuplexFlow + case SourceFlow => SinkFlow + case SinkFlow => SourceFlow + case DuplexFlow => DuplexFlow } - def swap(d: Direction) : Direction = d match { + def swap(d: Direction): Direction = d match { case Output => Input - case Input => Output + case Input => Output } - def swap(f: Orientation) : Orientation = f match { + def swap(f: Orientation): Orientation = f match { case Default => Flip - case Flip => Default + case Flip => Default } // Input <-> SourceFlow <-> Flip // Output <-> SinkFlow <-> Default def to_dir(g: Flow): Direction = g match { case SourceFlow => Input - case SinkFlow => Output + case SinkFlow => Output } def to_dir(o: Orientation): Direction = o match { - case Flip => Input + case Flip => Input case Default => Output } def to_flow(d: Direction): Flow = d match { - case Input => SourceFlow + case Input => SourceFlow case Output => SinkFlow } def to_flip(d: Direction): Orientation = d match { - case Input => Flip + case Input => Flip case Output => Default } def to_flip(g: Flow): Orientation = g match { case SourceFlow => Flip - case SinkFlow => Default + case SinkFlow => Default } def field_flip(v: Type, s: String): Orientation = v match { - case vx: BundleType => vx.fields find (_.name == s) match { - case Some(ft) => ft.flip - case None => Default - } + case vx: BundleType => + vx.fields.find(_.name == s) match { + case Some(ft) => ft.flip + case None => Default + } case vx => Default } def get_field(v: Type, s: String): Field = v match { - case vx: BundleType => vx.fields find (_.name == s) match { - case Some(ft) => ft - case None => throwInternalError(s"get_field: shouldn't be here - $v.$s") - } + case vx: BundleType => + vx.fields.find(_.name == s) match { + case Some(ft) => ft + case None => throwInternalError(s"get_field: shouldn't be here - $v.$s") + } case vx => throwInternalError(s"get_field: shouldn't be here - $v") } - def times(d: Direction,flip: Orientation): Direction = flip match { + def times(d: Direction, flip: Orientation): Direction = flip match { case Default => d - case Flip => swap(d) + case Flip => swap(d) } - def times(g: Flow, d: Direction): Direction = times(d, g) + def times(g: Flow, d: Direction): Direction = times(d, g) def times(d: Direction, g: Flow): Direction = g match { - case SinkFlow => d + case SinkFlow => d case SourceFlow => swap(d) // SourceFlow == INPUT == REVERSE } - def times(g: Flow, flip: Orientation): Flow = times(flip, g) + def times(g: Flow, flip: Orientation): Flow = times(flip, g) def times(flip: Orientation, g: Flow): Flow = flip match { case Default => g - case Flip => swap(g) + case Flip => swap(g) } def times(f1: Orientation, f2: Orientation): Orientation = f2 match { case Default => f1 - case Flip => swap(f1) + case Flip => swap(f1) } // =========== ACCESSORS ========= def kind(e: Expression): Kind = e match { - case ex: WRef => ex.kind - case ex: WSubField => kind(ex.expr) - case ex: WSubIndex => kind(ex.expr) + case ex: WRef => ex.kind + case ex: WSubField => kind(ex.expr) + case ex: WSubIndex => kind(ex.expr) case ex: WSubAccess => kind(ex.expr) case ex => ExpKind } def flow(e: Expression): Flow = e match { - case ex: WRef => ex.flow - case ex: WSubField => ex.flow - case ex: WSubIndex => ex.flow - case ex: WSubAccess => ex.flow - case ex: DoPrim => SourceFlow + case ex: WRef => ex.flow + case ex: WSubField => ex.flow + case ex: WSubIndex => ex.flow + case ex: WSubAccess => ex.flow + case ex: DoPrim => SourceFlow case ex: UIntLiteral => SourceFlow case ex: SIntLiteral => SourceFlow - case ex: Mux => SourceFlow - case ex: ValidIf => SourceFlow + case ex: Mux => SourceFlow + case ex: ValidIf => SourceFlow case WInvalid => SourceFlow - case ex => throwInternalError(s"flow: shouldn't be here - $e") + case ex => throwInternalError(s"flow: shouldn't be here - $e") } def get_flow(s: Statement): Flow = s match { - case sx: DefWire => DuplexFlow - case sx: DefRegister => DuplexFlow - case sx: WDefInstance => SourceFlow - case sx: DefNode => SourceFlow - case sx: DefInstance => SourceFlow - case sx: DefMemory => SourceFlow - case sx: Block => UnknownFlow - case sx: Connect => UnknownFlow + case sx: DefWire => DuplexFlow + case sx: DefRegister => DuplexFlow + case sx: WDefInstance => SourceFlow + case sx: DefNode => SourceFlow + case sx: DefInstance => SourceFlow + case sx: DefMemory => SourceFlow + case sx: Block => UnknownFlow + case sx: Connect => UnknownFlow case sx: PartialConnect => UnknownFlow - case sx: Stop => UnknownFlow - case sx: Print => UnknownFlow - case sx: IsInvalid => UnknownFlow + case sx: Stop => UnknownFlow + case sx: Print => UnknownFlow + case sx: IsInvalid => UnknownFlow case EmptyStmt => UnknownFlow } def get_flow(p: Port): Flow = if (p.direction == Input) SourceFlow else SinkFlow @@ -630,7 +653,7 @@ object Utils extends LazyLogging { val (root, tail) = splitRef(e.expr) tail match { case EmptyExpression => (root, WRef(e.name, e.tpe, root.kind, e.flow)) - case exp => (root, WSubField(tail, e.name, e.tpe, e.flow)) + case exp => (root, WSubField(tail, e.name, e.tpe, e.flow)) } } @@ -657,28 +680,28 @@ object Utils extends LazyLogging { def getDeclaration(m: Module, expr: Expression): IsDeclaration = { def getRootDecl(name: String)(s: Statement): Option[IsDeclaration] = s match { case decl: IsDeclaration => if (decl.name == name) Some(decl) else None - case c: Conditionally => + case c: Conditionally => val m = (getRootDecl(name)(c.conseq), getRootDecl(name)(c.alt)) (m: @unchecked) match { case (Some(decl), None) => Some(decl) case (None, Some(decl)) => Some(decl) - case (None, None) => None + case (None, None) => None } case begin: Block => - val stmts = begin.stmts flatMap getRootDecl(name) // can we short circuit? + val stmts = begin.stmts.flatMap(getRootDecl(name)) // can we short circuit? if (stmts.nonEmpty) Some(stmts.head) else None case _ => None } expr match { case (_: WRef | _: WSubIndex | _: WSubField) => val (root, tail) = splitRef(expr) - val rootDecl = m.ports find (_.name == root.name) match { + val rootDecl = m.ports.find(_.name == root.name) match { case Some(decl) => decl case None => getRootDecl(root.name)(m.body) match { case Some(decl) => decl - case None => throw new DeclarationNotFoundException( - s"[module ${m.name}] Reference ${expr.serialize} not declared!") + case None => + throw new DeclarationNotFoundException(s"[module ${m.name}] Reference ${expr.serialize} not declared!") } } rootDecl @@ -771,7 +794,7 @@ object Utils extends LazyLogging { .findAllMatchIn(name) .map(_.end - 1) .toSeq - .foldLeft(Seq[String]()){ case (seq, id) => seq :+ name.splitAt(id)._1 } + .foldLeft(Seq[String]()) { case (seq, id) => seq :+ name.splitAt(id)._1 } } /** Returns the value masked with the width. @@ -785,14 +808,14 @@ object Utils extends LazyLogging { } object MemoizedHash { - implicit def convertTo[T](e: T): MemoizedHash[T] = new MemoizedHash(e) + implicit def convertTo[T](e: T): MemoizedHash[T] = new MemoizedHash(e) implicit def convertFrom[T](f: MemoizedHash[T]): T = f.t } class MemoizedHash[T](val t: T) { override lazy val hashCode = t.hashCode override def equals(that: Any) = that match { - case x: MemoizedHash[_] => t equals x.t + case x: MemoizedHash[_] => t.equals(x.t) case _ => false } } @@ -833,13 +856,12 @@ class ModuleGraph { def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = { nodes.get(child) match { case Some(children) => - if(children(parent)) { + if (children(parent)) { parent :: path - } - else { + } else { children.foreach { grandchild => val newPath = pathExists(grandchild, parent, grandchild :: path) - if(newPath.nonEmpty) { + if (newPath.nonEmpty) { return newPath } } diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 502d021d..b14c39c7 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -13,7 +13,6 @@ import Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo} import firrtl.ir._ import Utils.throwInternalError - class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] with ParseTreeVisitor[FirrtlNode] { // Strip file path private def stripPath(filename: String) = filename.drop(filename.lastIndexOf("/") + 1) @@ -21,7 +20,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w // Check if identifier is made of legal characters private def legalId(id: String) = { val legalChars = ('A' to 'Z').toSet ++ ('a' to 'z').toSet ++ ('0' to '9').toSet ++ Set('_', '$') - id forall legalChars + id.forall(legalChars) } def visit(ctx: CircuitContext): Circuit = visitCircuit(ctx) @@ -37,22 +36,22 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def string2BigInt(s: String): BigInt = { // private define legal patterns s match { - case ZeroPattern(_*) => BigInt(0) - case HexPattern(hexdigits) => BigInt(hexdigits, 16) - case OctalPattern(octaldigits) => BigInt(octaldigits, 8) + case ZeroPattern(_*) => BigInt(0) + case HexPattern(hexdigits) => BigInt(hexdigits, 16) + case OctalPattern(octaldigits) => BigInt(octaldigits, 8) case BinaryPattern(binarydigits) => BigInt(binarydigits, 2) - case DecPattern(num) => BigInt(num, 10) - case _ => throw new Exception("Invalid String for conversion to BigInt " + s) + case DecPattern(num) => BigInt(num, 10) + case _ => throw new Exception("Invalid String for conversion to BigInt " + s) } } private def string2BigDecimal(s: String): BigDecimal = { // private define legal patterns s match { - case ZeroPattern(_*) => BigDecimal(0) - case DecPattern(num) => BigDecimal(num) + case ZeroPattern(_*) => BigDecimal(0) + case DecPattern(num) => BigDecimal(num) case DecimalPattern(num) => BigDecimal(num) - case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s) + case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s) } } @@ -64,7 +63,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w parentCtx.getStart.getCharPositionInLine lazy val useInfo: String = ctx match { case Some(info) => info.getText.drop(2).init // remove surrounding @[ ... ] - case None => "" + case None => "" } infoMode match { case UseInfo => @@ -88,14 +87,19 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def visitModule(ctx: ModuleContext): DefModule = { val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0).getText match { - case "module" => Module(info, ctx.id.getText, ctx.port.asScala.map(visitPort).toSeq, - if (ctx.moduleBlock() != null) - visitBlock(ctx.moduleBlock()) - else EmptyStmt) + case "module" => + Module( + info, + ctx.id.getText, + ctx.port.asScala.map(visitPort).toSeq, + if (ctx.moduleBlock() != null) + visitBlock(ctx.moduleBlock()) + else EmptyStmt + ) case "extmodule" => val defname = if (ctx.defname != null) ctx.defname.id.getText else ctx.id.getText - val ports = ctx.port.asScala map visitPort - val params = ctx.parameter.asScala map visitParameter + val ports = ctx.port.asScala.map(visitPort) + val params = ctx.parameter.asScala.map(visitParameter) ExtModule(info, ctx.id.getText, ports.toSeq, defname, params.toSeq) } } @@ -111,22 +115,22 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w case (null, str, null, null) => StringParam(name, visitStringLit(str)) case (null, null, dbl, null) => DoubleParam(name, dbl.getText.toDouble) case (null, null, null, raw) => RawStringParam(name, raw.getText.tail.init.replace("\\'", "'")) // Remove "\'"s - case _ => throwInternalError(s"visiting impossible parameter ${ctx.getText}") + case _ => throwInternalError(s"visiting impossible parameter ${ctx.getText}") } } private def visitDir(ctx: DirContext): Direction = ctx.getText match { - case "input" => Input + case "input" => Input case "output" => Output } private def visitMdir(ctx: MdirContext): MPortDir = ctx.getText match { case "infer" => MInfer - case "read" => MRead + case "read" => MRead case "write" => MWrite - case "rdwr" => MReadWrite + case "rdwr" => MReadWrite } // Match on a type instead of on strings? @@ -135,47 +139,53 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w ctx.getChild(0) match { case term: TerminalNode => term.getText match { - case "UInt" => if (ctx.getChildCount > 1) UIntType(getWidth(ctx.intLit(0))) - else UIntType(UnknownWidth) - case "SInt" => if (ctx.getChildCount > 1) SIntType(getWidth(ctx.intLit(0))) - else SIntType(UnknownWidth) - case "Fixed" => ctx.intLit.size match { - case 0 => FixedType(UnknownWidth, UnknownWidth) - case 1 => ctx.getChild(2).getText match { - case "<" => FixedType(UnknownWidth, getWidth(ctx.intLit(0))) - case _ => FixedType(getWidth(ctx.intLit(0)), UnknownWidth) + case "UInt" => + if (ctx.getChildCount > 1) UIntType(getWidth(ctx.intLit(0))) + else UIntType(UnknownWidth) + case "SInt" => + if (ctx.getChildCount > 1) SIntType(getWidth(ctx.intLit(0))) + else SIntType(UnknownWidth) + case "Fixed" => + ctx.intLit.size match { + case 0 => FixedType(UnknownWidth, UnknownWidth) + case 1 => + ctx.getChild(2).getText match { + case "<" => FixedType(UnknownWidth, getWidth(ctx.intLit(0))) + case _ => FixedType(getWidth(ctx.intLit(0)), UnknownWidth) + } + case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1))) } - case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1))) - } - case "Interval" => ctx.boundValue.size match { - case 0 => - val point = ctx.intLit.size match { - case 0 => UnknownWidth - case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) - } - IntervalType(UnknownBound, UnknownBound, point) - case 2 => - val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match { - case (_, "?") => UnknownBound - case ("(", v) => Open(string2BigDecimal(v)) - case ("[", v) => Closed(string2BigDecimal(v)) - } - val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match { - case (_, "?") => UnknownBound - case (")", v) => Open(string2BigDecimal(v)) - case ("]", v) => Closed(string2BigDecimal(v)) - } - val point = ctx.intLit.size match { - case 0 => UnknownWidth - case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) - } - IntervalType(lower, upper, point) - } - case "Clock" => ClockType + case "Interval" => + ctx.boundValue.size match { + case 0 => + val point = ctx.intLit.size match { + case 0 => UnknownWidth + case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) + } + IntervalType(UnknownBound, UnknownBound, point) + case 2 => + val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match { + case (_, "?") => UnknownBound + case ("(", v) => Open(string2BigDecimal(v)) + case ("[", v) => Closed(string2BigDecimal(v)) + } + val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match { + case (_, "?") => UnknownBound + case (")", v) => Open(string2BigDecimal(v)) + case ("]", v) => Closed(string2BigDecimal(v)) + } + val point = ctx.intLit.size match { + case 0 => UnknownWidth + case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText)) + } + IntervalType(lower, upper, point) + } + case "Clock" => ClockType case "AsyncReset" => AsyncResetType - case "Reset" => ResetType - case "Analog" => if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0))) - else AnalogType(UnknownWidth) + case "Reset" => ResetType + case "Analog" => + if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0))) + else AnalogType(UnknownWidth) case "{" => BundleType(ctx.field.asScala.map(visitField).toSeq) } case typeContext: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.intLit(0).getText)) @@ -208,11 +218,12 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w private def visitRuw(ctx: Option[RuwContext]): ReadUnderWrite.Value = ctx match { case None => ReadUnderWrite.Undefined - case Some(ctx) => ctx.getText match { - case "undefined" => ReadUnderWrite.Undefined - case "old" => ReadUnderWrite.Old - case "new" => ReadUnderWrite.New - } + case Some(ctx) => + ctx.getText match { + case "undefined" => ReadUnderWrite.Undefined + case "old" => ReadUnderWrite.Old + case "new" => ReadUnderWrite.New + } } // Memories are fairly complicated to translate thus have a dedicated method @@ -220,7 +231,11 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w val readers = mutable.ArrayBuffer.empty[String] val writers = mutable.ArrayBuffer.empty[String] val readwriters = mutable.ArrayBuffer.empty[String] - case class ParamValue(typ: Option[Type] = None, lit: Option[BigInt] = None, ruw: ReadUnderWrite.Value = ReadUnderWrite.Undefined, unique: Boolean = true) + case class ParamValue( + typ: Option[Type] = None, + lit: Option[BigInt] = None, + ruw: ReadUnderWrite.Value = ReadUnderWrite.Undefined, + unique: Boolean = true) val fieldMap = mutable.HashMap[String, ParamValue]() val memName = ctx.id(0).getText def parseMemFields(memFields: Seq[MemFieldContext]): Unit = @@ -228,14 +243,14 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w val fieldName = field.children.asScala(0).getText fieldName match { - case "reader" => readers ++= field.id().asScala.map(_.getText) - case "writer" => writers ++= field.id().asScala.map(_.getText) + case "reader" => readers ++= field.id().asScala.map(_.getText) + case "writer" => writers ++= field.id().asScala.map(_.getText) case "readwriter" => readwriters ++= field.id().asScala.map(_.getText) case _ => val paramDef = fieldName match { - case "data-type" => ParamValue(typ = Some(visitType(field.`type`()))) + case "data-type" => ParamValue(typ = Some(visitType(field.`type`()))) case "read-under-write" => ParamValue(ruw = visitRuw(Option(field.ruw))) - case _ => ParamValue(lit = Some(BigInt(field.intLit().getText))) + case _ => ParamValue(lit = Some(BigInt(field.intLit().getText))) } if (fieldMap.contains(fieldName)) throw new ParameterRedefinedException(s"Redefinition of $fieldName in FIRRTL line:${field.start.getLine}") @@ -255,20 +270,26 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w } // Check for required fields - Seq("data-type", "depth", "read-latency", "write-latency") foreach { field => - fieldMap.getOrElse(field, throw new ParameterNotSpecifiedException(s"[$info] Required mem field $field not found")) + Seq("data-type", "depth", "read-latency", "write-latency").foreach { field => + fieldMap.getOrElse( + field, + throw new ParameterNotSpecifiedException(s"[$info] Required mem field $field not found") + ) } def lit(param: String) = fieldMap(param).lit.get val ruw = fieldMap.get("read-under-write").map(_.ruw).getOrElse(ir.ReadUnderWrite.Undefined) - DefMemory(info, + DefMemory( + info, name = memName, dataType = fieldMap("data-type").typ.get, depth = lit("depth"), writeLatency = lit("write-latency").toInt, readLatency = lit("read-latency").toInt, - readers = readers.toSeq, writers = writers.toSeq, readwriters = readwriters.toSeq, + readers = readers.toSeq, + writers = writers.toSeq, + readwriters = readwriters.toSeq, readUnderWrite = ruw ) } @@ -299,56 +320,88 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0) match { case when: WhenContext => visitWhen(when) - case term: TerminalNode => term.getText match { - case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`())) - case "reg" => - val name = ctx.id(0).getText - val tpe = visitType(ctx.`type`()) - val (reset, init, rinfo) = { - val rb = ctx.reset_block() - if (rb != null) { - val sr = rb.simple_reset.simple_reset0() - val innerInfo = if (info == NoInfo) visitInfo(Option(rb.info), ctx) else info - (visitExp(sr.exp(0)), visitExp(sr.exp(1)), innerInfo) + case term: TerminalNode => + term.getText match { + case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`())) + case "reg" => + val name = ctx.id(0).getText + val tpe = visitType(ctx.`type`()) + val (reset, init, rinfo) = { + val rb = ctx.reset_block() + if (rb != null) { + val sr = rb.simple_reset.simple_reset0() + val innerInfo = if (info == NoInfo) visitInfo(Option(rb.info), ctx) else info + (visitExp(sr.exp(0)), visitExp(sr.exp(1)), innerInfo) + } else + (UIntLiteral(0, IntWidth(1)), Reference(name, tpe), info) } - else - (UIntLiteral(0, IntWidth(1)), Reference(name, tpe), info) - } - DefRegister(rinfo, name, tpe, visitExp(ctx_exp(0)), reset, init) - case "mem" => visitMem(ctx) - case "cmem" => - val (tpe, size) = visitCMemType(ctx.`type`()) - CDefMemory(info, ctx.id(0).getText, tpe, size, seq = false) - case "smem" => - val (tpe, size) = visitCMemType(ctx.`type`()) - CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true, readUnderWrite = visitRuw(Option(ctx.ruw))) - case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText) - case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0))) - - case "stop(" => Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) - case "attach" => Attach(info, ctx_exp.map(visitExp).toSeq) - case "printf(" => Print(info, visitStringLit(ctx.StringLit), ctx_exp.drop(2).map(visitExp).toSeq, - visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) - // formal - case "assert" => Verification(Formal.Assert, info, visitExp(ctx_exp(0)), - visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit)) - case "assume" => Verification(Formal.Assume, info, visitExp(ctx_exp(0)), - visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit)) - case "cover" => Verification(Formal.Cover, info, visitExp(ctx_exp(0)), - visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), - visitStringLit(ctx.StringLit)) - // end formal - case "skip" => EmptyStmt - } + DefRegister(rinfo, name, tpe, visitExp(ctx_exp(0)), reset, init) + case "mem" => visitMem(ctx) + case "cmem" => + val (tpe, size) = visitCMemType(ctx.`type`()) + CDefMemory(info, ctx.id(0).getText, tpe, size, seq = false) + case "smem" => + val (tpe, size) = visitCMemType(ctx.`type`()) + CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true, readUnderWrite = visitRuw(Option(ctx.ruw))) + case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText) + case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0))) + + case "stop(" => Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) + case "attach" => Attach(info, ctx_exp.map(visitExp).toSeq) + case "printf(" => + Print( + info, + visitStringLit(ctx.StringLit), + ctx_exp.drop(2).map(visitExp).toSeq, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)) + ) + // formal + case "assert" => + Verification( + Formal.Assert, + info, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)), + visitExp(ctx_exp(2)), + visitStringLit(ctx.StringLit) + ) + case "assume" => + Verification( + Formal.Assume, + info, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)), + visitExp(ctx_exp(2)), + visitStringLit(ctx.StringLit) + ) + case "cover" => + Verification( + Formal.Cover, + info, + visitExp(ctx_exp(0)), + visitExp(ctx_exp(1)), + visitExp(ctx_exp(2)), + visitStringLit(ctx.StringLit) + ) + // end formal + case "skip" => EmptyStmt + } // If we don't match on the first child, try the next one case _ => ctx.getChild(1).getText match { case "<=" => Connect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) case "<-" => PartialConnect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1))) case "is" => IsInvalid(info, visitExp(ctx_exp(0))) - case "mport" => CDefMPort(info, ctx.id(0).getText, UnknownType, ctx.id(1).getText, Seq(visitExp(ctx_exp(0)), visitExp(ctx_exp(1))), visitMdir(ctx.mdir)) + case "mport" => + CDefMPort( + info, + ctx.id(0).getText, + UnknownType, + ctx.id(1).getText, + Seq(visitExp(ctx_exp(0)), visitExp(ctx_exp(1))), + visitMdir(ctx.mdir) + ) } } } @@ -379,10 +432,12 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w new SubAccess(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) } case _: PrimopContext => - DoPrim(visitPrimop(ctx.primop), - ctx_exp.map(visitExp).toSeq, - ctx.intLit.asScala.map(x => string2BigInt(x.getText)).toSeq, - UnknownType) + DoPrim( + visitPrimop(ctx.primop), + ctx_exp.map(visitExp).toSeq, + ctx.intLit.asScala.map(x => string2BigInt(x.getText)).toSeq, + UnknownType + ) case _ => ctx.getChild(0).getText match { case "UInt" => @@ -405,7 +460,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w SIntLiteral(value) } case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType) - case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) + case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType) } } } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 95b24ad0..4153fc74 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -27,96 +27,110 @@ case object DuplexFlow extends Flow case object UnknownFlow extends Flow object WRef { + /** Creates a WRef from a Wire */ def apply(wire: DefWire): WRef = new WRef(wire.name, wire.tpe, WireKind, UnknownFlow) + /** Creates a WRef from a Register */ def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UnknownFlow) + /** Creates a WRef from a Node */ def apply(node: DefNode): WRef = new WRef(node.name, node.value.tpe, NodeKind, SourceFlow) + /** Creates a WRef from a Port */ def apply(port: Port): WRef = new WRef(port.name, port.tpe, PortKind, UnknownFlow) + /** Creates a WRef from a WDefInstance */ def apply(wi: WDefInstance): WRef = new WRef(wi.name, wi.tpe, InstanceKind, UnknownFlow) + /** Creates a WRef from a DefMemory */ def apply(mem: DefMemory): WRef = new WRef(mem.name, passes.MemPortUtils.memType(mem), MemKind, UnknownFlow) + /** Creates a WRef from an arbitrary string name */ def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = Reference(n, t, k, UnknownFlow) - def apply(name: String, tpe: Type , kind: Kind, flow: Flow): WRef = Reference(name, tpe, kind, flow) + def apply(name: String, tpe: Type, kind: Kind, flow: Flow): WRef = Reference(name, tpe, kind, flow) def unapply(ref: Reference): Option[(String, Type, Kind, Flow)] = Some((ref.name, ref.tpe, ref.kind, ref.flow)) } object WSubField { - def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UnknownFlow) - def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UnknownFlow) - def apply(expr: Expression, name: String, tpe: Type, flow: Flow): WSubField = new WSubField(expr, name, tpe, flow) + def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UnknownFlow) + def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UnknownFlow) + def apply(expr: Expression, name: String, tpe: Type, flow: Flow): WSubField = new WSubField(expr, name, tpe, flow) def unapply(wsf: WSubField): Option[(Expression, String, Type, Flow)] = Some((wsf.expr, wsf.name, wsf.tpe, wsf.flow)) } object WSubIndex { - def apply(expr: Expression, value: Int, tpe: Type, flow: Flow): WSubIndex = new WSubIndex(expr, value, tpe, flow) + def apply(expr: Expression, value: Int, tpe: Type, flow: Flow): WSubIndex = new WSubIndex(expr, value, tpe, flow) def unapply(wsi: WSubIndex): Option[(Expression, Int, Type, Flow)] = Some((wsi.expr, wsi.value, wsi.tpe, wsi.flow)) } object WSubAccess { - def apply(expr: Expression, index: Expression, tpe: Type, flow: Flow): WSubAccess = new WSubAccess(expr, index, tpe, flow) - def unapply(wsa: WSubAccess): Option[(Expression, Expression, Type, Flow)] = Some((wsa.expr, wsa.index, wsa.tpe, wsa.flow)) + def apply(expr: Expression, index: Expression, tpe: Type, flow: Flow): WSubAccess = + new WSubAccess(expr, index, tpe, flow) + def unapply(wsa: WSubAccess): Option[(Expression, Expression, Type, Flow)] = Some( + (wsa.expr, wsa.index, wsa.tpe, wsa.flow) + ) } case object WVoid extends Expression with UseSerializer { def tpe = UnknownType - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = () + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = () } case object WInvalid extends Expression with UseSerializer { def tpe = UnknownType - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = () + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = () } // Useful for splitting then remerging references case object EmptyExpression extends Expression with UseSerializer { def tpe = UnknownType - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = () + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = () } object WDefInstance { def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType) - def apply(info: Info, name: String, module: String, tpe: Type): WDefInstance = new WDefInstance(info, name, module, tpe) + def apply(info: Info, name: String, module: String, tpe: Type): WDefInstance = + new WDefInstance(info, name, module, tpe) def unapply(wi: WDefInstance): Option[(Info, String, String, Type)] = { Some((wi.info, wi.name, wi.module, wi.tpe)) } } case class WDefInstanceConnector( - info: Info, - name: String, - module: String, - tpe: Type, - portCons: Seq[(Expression, Expression)]) extends Statement with IsDeclaration with UseSerializer { + info: Info, + name: String, + module: String, + tpe: Type, + portCons: Seq[(Expression, Expression)]) + extends Statement + with IsDeclaration + with UseSerializer { def mapExpr(f: Expression => Expression): Statement = - this.copy(portCons = portCons map { case (e1, e2) => (f(e1), f(e2)) }) - def mapStmt(f: Statement => Statement): Statement = this - def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def mapInfo(f: Info => Info): Statement = this.copy(f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = portCons foreach { case (e1, e2) => (f(e1), f(e2)) } - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + this.copy(portCons = portCons.map { case (e1, e2) => (f(e1), f(e2)) }) + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = portCons.foreach { case (e1, e2) => (f(e1), f(e2)) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } // Resultant width is the same as the maximum input width @@ -172,12 +186,12 @@ case object Dshlw extends PrimOp { * @note This is not allowed to leak from any transform */ private[firrtl] case class InfoExpr(info: Info, expr: Expression) extends Expression { - def foreachExpr(f: Expression => Unit): Unit = f(expr) - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = () - def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(this.expr)) - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = () + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(this.expr)) + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = this def tpe: Type = expr.tpe // Members declared in firrtl.ir.FirrtlNode @@ -198,33 +212,35 @@ private[firrtl] object InfoExpr { // TODO this the right name? def map(expr: Expression)(f: Expression => Expression): Expression = expr match { case ie: InfoExpr => ie.mapExpr(f) - case e => f(e) + case e => f(e) } } object WrappedExpression { def apply(e: Expression) = new WrappedExpression(e) - def we(e: Expression) = new WrappedExpression(e) - def weq(e1: Expression, e2: Expression) = we(e1) == we(e2) + def we(e: Expression) = new WrappedExpression(e) + def weq(e1: Expression, e2: Expression) = we(e1) == we(e2) } class WrappedExpression(val e1: Expression) { override def equals(we: Any) = we match { - case (we: WrappedExpression) => (e1,we.e1) match { - case (e1x: UIntLiteral, e2x: UIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width) - case (e1x: SIntLiteral, e2x: SIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width) - case (e1x: WRef, e2x: WRef) => e1x.name equals e2x.name - case (e1x: WSubField, e2x: WSubField) => (e1x.name equals e2x.name) && weq(e1x.expr,e2x.expr) - case (e1x: WSubIndex, e2x: WSubIndex) => (e1x.value == e2x.value) && weq(e1x.expr,e2x.expr) - case (e1x: WSubAccess, e2x: WSubAccess) => weq(e1x.index,e2x.index) && weq(e1x.expr,e2x.expr) - case (WVoid, WVoid) => true - case (WInvalid, WInvalid) => true - case (e1x: DoPrim, e2x: DoPrim) => e1x.op == e2x.op && - ((e1x.consts zip e2x.consts) forall {case (x, y) => x == y}) && - ((e1x.args zip e2x.args) forall {case (x, y) => weq(x, y)}) - case (e1x: Mux, e2x: Mux) => weq(e1x.cond,e2x.cond) && weq(e1x.tval,e2x.tval) && weq(e1x.fval,e2x.fval) - case (e1x: ValidIf, e2x: ValidIf) => weq(e1x.cond,e2x.cond) && weq(e1x.value,e2x.value) - case (e1x, e2x) => false - } + case (we: WrappedExpression) => + (e1, we.e1) match { + case (e1x: UIntLiteral, e2x: UIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width) + case (e1x: SIntLiteral, e2x: SIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width) + case (e1x: WRef, e2x: WRef) => e1x.name.equals(e2x.name) + case (e1x: WSubField, e2x: WSubField) => (e1x.name.equals(e2x.name)) && weq(e1x.expr, e2x.expr) + case (e1x: WSubIndex, e2x: WSubIndex) => (e1x.value == e2x.value) && weq(e1x.expr, e2x.expr) + case (e1x: WSubAccess, e2x: WSubAccess) => weq(e1x.index, e2x.index) && weq(e1x.expr, e2x.expr) + case (WVoid, WVoid) => true + case (WInvalid, WInvalid) => true + case (e1x: DoPrim, e2x: DoPrim) => + e1x.op == e2x.op && + ((e1x.consts.zip(e2x.consts)).forall { case (x, y) => x == y }) && + ((e1x.args.zip(e2x.args)).forall { case (x, y) => weq(x, y) }) + case (e1x: Mux, e2x: Mux) => weq(e1x.cond, e2x.cond) && weq(e1x.tval, e2x.tval) && weq(e1x.fval, e2x.fval) + case (e1x: ValidIf, e2x: ValidIf) => weq(e1x.cond, e2x.cond) && weq(e1x.value, e2x.value) + case (e1x, e2x) => false + } case _ => false } override def hashCode = e1.serialize.hashCode @@ -237,7 +253,7 @@ private[firrtl] sealed trait HasMapWidth { object WrappedType { def apply(t: Type) = new WrappedType(t) - def wt(t: Type) = apply(t) + def wt(t: Type) = apply(t) // Check if it is legal for the source type to drive the sink type // Which is which matters because ResetType can be driven by itself, Bool, or AsyncResetType, but // it cannot drive Bool nor AsyncResetType @@ -245,10 +261,10 @@ object WrappedType { (sink, source) match { case (_: UIntType, _: UIntType) => true case (_: SIntType, _: SIntType) => true - case (ClockType, ClockType) => true + case (ClockType, ClockType) => true case (AsyncResetType, AsyncResetType) => true - case (ResetType, tpe) => legalResetType(tpe) - case (tpe, ResetType) => legalResetType(tpe) + case (ResetType, tpe) => legalResetType(tpe) + case (tpe, ResetType) => legalResetType(tpe) case (_: FixedType, _: FixedType) => true case (_: IntervalType, _: IntervalType) => true // Analog totally skips out of the Firrtl type system. @@ -260,13 +276,14 @@ object WrappedType { sink.size == source.size && compare(sink.tpe, source.tpe) case (sink: BundleType, source: BundleType) => (sink.fields.size == source.fields.size) && - sink.fields.zip(source.fields).forall { case (f1, f2) => - (f1.flip == f2.flip) && (f1.name == f2.name) && (f1.flip match { - case Default => compare(f1.tpe, f2.tpe) - // We allow UInt<1> and AsyncReset to drive Reset but not the other way around - case Flip => compare(f2.tpe, f1.tpe) - }) - } + sink.fields.zip(source.fields).forall { + case (f1, f2) => + (f1.flip == f2.flip) && (f1.name == f2.name) && (f1.flip match { + case Default => compare(f1.tpe, f2.tpe) + // We allow UInt<1> and AsyncReset to drive Reset but not the other way around + case Flip => compare(f2.tpe, f1.tpe) + }) + } case _ => false } } @@ -287,7 +304,7 @@ object WrappedWidth { def eqw(w1: Width, w2: Width): Boolean = new WrappedWidth(w1) == new WrappedWidth(w2) } -class WrappedWidth (val w: Width) { +class WrappedWidth(val w: Width) { def ww(w: Width): WrappedWidth = new WrappedWidth(w) override def toString = w match { case (w: VarWidth) => w.name @@ -295,12 +312,13 @@ class WrappedWidth (val w: Width) { case UnknownWidth => "?" } override def equals(o: Any): Boolean = o match { - case (w2: WrappedWidth) => (w, w2.w) match { - case (w1: VarWidth, w2: VarWidth) => w1.name.equals(w2.name) - case (w1: IntWidth, w2: IntWidth) => w1.width == w2.width - case (UnknownWidth, UnknownWidth) => true - case _ => false - } + case (w2: WrappedWidth) => + (w, w2.w) match { + case (w1: VarWidth, w2: VarWidth) => w1.name.equals(w2.name) + case (w1: IntWidth, w2: IntWidth) => w1.width == w2.width + case (UnknownWidth, UnknownWidth) => true + case _ => false + } case _ => false } } @@ -320,37 +338,38 @@ case object MReadWrite extends MPortDir { } case class CDefMemory( - info: Info, - name: String, - tpe: Type, - size: BigInt, - seq: Boolean, - readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) extends Statement with HasInfo with UseSerializer { - def mapExpr(f: Expression => Expression): Statement = this - def mapStmt(f: Statement => Statement): Statement = this - def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def mapInfo(f: Info => Info): Statement = this.copy(f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + info: Info, + name: String, + tpe: Type, + size: BigInt, + seq: Boolean, + readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) + extends Statement + with HasInfo + with UseSerializer { + def mapExpr(f: Expression => Expression): Statement = this + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } -case class CDefMPort(info: Info, - name: String, - tpe: Type, - mem: String, - exps: Seq[Expression], - direction: MPortDir) extends Statement with HasInfo with UseSerializer { - def mapExpr(f: Expression => Expression): Statement = this.copy(exps = exps map f) - def mapStmt(f: Statement => Statement): Statement = this - def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def mapInfo(f: Info => Info): Statement = this.copy(f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = exps.foreach(f) - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) +case class CDefMPort(info: Info, name: String, tpe: Type, mem: String, exps: Seq[Expression], direction: MPortDir) + extends Statement + with HasInfo + with UseSerializer { + def mapExpr(f: Expression => Expression): Statement = this.copy(exps = exps.map(f)) + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = exps.foreach(f) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/analyses/CircuitGraph.scala b/src/main/scala/firrtl/analyses/CircuitGraph.scala index 506bba57..a1fb0f19 100644 --- a/src/main/scala/firrtl/analyses/CircuitGraph.scala +++ b/src/main/scala/firrtl/analyses/CircuitGraph.scala @@ -80,9 +80,10 @@ class CircuitGraph private[analyses] (connectionGraph: ConnectionGraph) { * @return */ def absolutePaths(mt: ModuleTarget): Seq[IsModule] = instanceGraph.findInstancesInHierarchy(mt.module).map { - case seq if seq.nonEmpty => seq.foldLeft(CircuitTarget(circuit.main).module(circuit.main): IsModule) { - case (it, InstanceKey(instance, ofModule)) => it.instOf(instance, ofModule) - } + case seq if seq.nonEmpty => + seq.foldLeft(CircuitTarget(circuit.main).module(circuit.main): IsModule) { + case (it, InstanceKey(instance, ofModule)) => it.instOf(instance, ofModule) + } } /** Return the sequence of nodes from source to sink, inclusive diff --git a/src/main/scala/firrtl/analyses/ConnectionGraph.scala b/src/main/scala/firrtl/analyses/ConnectionGraph.scala index 0e13711a..f98cf14c 100644 --- a/src/main/scala/firrtl/analyses/ConnectionGraph.scala +++ b/src/main/scala/firrtl/analyses/ConnectionGraph.scala @@ -16,22 +16,24 @@ import scala.collection.mutable * @param circuit firrtl AST of this graph. * @param digraph Directed graph of ReferenceTarget in the AST. * @param irLookup [[IRLookup]] instance of circuit graph. - * */ -class ConnectionGraph protected(val circuit: Circuit, - val digraph: DiGraph[ReferenceTarget], - val irLookup: IRLookup) - extends DiGraph[ReferenceTarget](digraph.getEdgeMap.asInstanceOf[mutable.LinkedHashMap[ReferenceTarget, mutable.LinkedHashSet[ReferenceTarget]]]) { + */ +class ConnectionGraph protected (val circuit: Circuit, val digraph: DiGraph[ReferenceTarget], val irLookup: IRLookup) + extends DiGraph[ReferenceTarget]( + digraph.getEdgeMap.asInstanceOf[mutable.LinkedHashMap[ReferenceTarget, mutable.LinkedHashSet[ReferenceTarget]]] + ) { lazy val serialize: String = s"""{ - |${getEdgeMap.map { case (k, vs) => + |${getEdgeMap.map { + case (k, vs) => s""" "$k": { - | "kind": "${irLookup.kind(k)}", - | "type": "${irLookup.tpe(k)}", - | "expr": "${irLookup.expr(k, irLookup.flow(k))}", - | "sinks": [${vs.map { v => s""""$v"""" }.mkString(", ")}], - | "declaration": "${irLookup.declaration(k)}" - | }""".stripMargin }.mkString(",\n")} - |}""".stripMargin + | "kind": "${irLookup.kind(k)}", + | "type": "${irLookup.tpe(k)}", + | "expr": "${irLookup.expr(k, irLookup.flow(k))}", + | "sinks": [${vs.map { v => s""""$v"""" }.mkString(", ")}], + | "declaration": "${irLookup.declaration(k)}" + | }""".stripMargin + }.mkString(",\n")} + |}""".stripMargin /** Used by BFS to map each visited node to the list of instance inputs visited thus far * @@ -134,7 +136,10 @@ class ConnectionGraph protected(val circuit: Circuit, /** @return a new, reversed connection graph where edges point from sinks to sources. */ def reverseConnectionGraph: ConnectionGraph = new ConnectionGraph(circuit, digraph.reverse, irLookup) - override def BFS(root: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): collection.Map[ReferenceTarget, ReferenceTarget] = { + override def BFS( + root: ReferenceTarget, + blacklist: collection.Set[ReferenceTarget] + ): collection.Map[ReferenceTarget, ReferenceTarget] = { val prev = new mutable.LinkedHashMap[ReferenceTarget, ReferenceTarget]() val ordering = new Ordering[ReferenceTarget] { override def compare(x: ReferenceTarget, y: ReferenceTarget): Int = x.path.size - y.path.size @@ -216,7 +221,6 @@ class ConnectionGraph protected(val circuit: Circuit, bfsShortCuts.get(localSource) match { case Some(set) => set.map { x => x.setPathTarget(source.pathTarget) } case None => - val pathlessEdges = super.getEdges(localSource) val ret = pathlessEdges.flatMap { @@ -246,7 +250,9 @@ class ConnectionGraph protected(val circuit: Circuit, // Exiting to parent, but had unresolved trip through child, so don't update shortcut portConnectivityStack(localSink) = localSource +: currentStack } - Set[ReferenceTarget](localSink.setPathTarget(source.noComponents.targetParent.asInstanceOf[IsComponent].pathTarget)) + Set[ReferenceTarget]( + localSink.setPathTarget(source.noComponents.targetParent.asInstanceOf[IsComponent].pathTarget) + ) case localSink if enteringChildInstance(source)(localSink) => portConnectivityStack(localSink) = localSource +: portConnectivityStack.getOrElse(localSource, Nil) @@ -265,24 +271,31 @@ class ConnectionGraph protected(val circuit: Circuit, } - override def path(start: ReferenceTarget, end: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): Seq[ReferenceTarget] = { + override def path( + start: ReferenceTarget, + end: ReferenceTarget, + blacklist: collection.Set[ReferenceTarget] + ): Seq[ReferenceTarget] = { insertShortCuts(super.path(start, end, blacklist)) } private def insertShortCuts(path: Seq[ReferenceTarget]): Seq[ReferenceTarget] = { val soFar = mutable.HashSet[ReferenceTarget]() if (path.size > 1) { - path.head +: path.sliding(2).flatMap { - case Seq(from, to) => - getShortCut(from) match { - case Some(set) if set.contains(to) && soFar.contains(from.pathlessTarget) => - soFar += from.pathlessTarget - Seq(from.pathTarget.ref("..."), to) - case _ => - soFar += from.pathlessTarget - Seq(to) - } - }.toSeq + path.head +: path + .sliding(2) + .flatMap { + case Seq(from, to) => + getShortCut(from) match { + case Some(set) if set.contains(to) && soFar.contains(from.pathlessTarget) => + soFar += from.pathlessTarget + Seq(from.pathTarget.ref("..."), to) + case _ => + soFar += from.pathlessTarget + Seq(to) + } + } + .toSeq } else path } @@ -325,16 +338,16 @@ object ConnectionGraph { * @return */ def asTarget(m: ModuleTarget, tagger: TokenTagger)(e: FirrtlNode): ReferenceTarget = e match { - case l: Literal => m.ref(tagger.getRef(l.value.toString)) + case l: Literal => m.ref(tagger.getRef(l.value.toString)) case r: Reference => m.ref(r.name) - case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value) - case s: SubField => asTarget(m, tagger)(s.expr).field(s.name) - case d: DoPrim => m.ref(tagger.getRef(d.op.serialize)) - case _: Mux => m.ref(tagger.getRef("mux")) - case _: ValidIf => m.ref(tagger.getRef("validif")) + case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value) + case s: SubField => asTarget(m, tagger)(s.expr).field(s.name) + case d: DoPrim => m.ref(tagger.getRef(d.op.serialize)) + case _: Mux => m.ref(tagger.getRef("mux")) + case _: ValidIf => m.ref(tagger.getRef("validif")) case WInvalid => m.ref(tagger.getRef("invalid")) case _: Print => m.ref(tagger.getRef("print")) - case _: Stop => m.ref(tagger.getRef("print")) + case _: Stop => m.ref(tagger.getRef("print")) case other => sys.error(s"Unsupported: $other") } @@ -354,30 +367,31 @@ object ConnectionGraph { def enteringNonParentInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = { source.path.nonEmpty && - (source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule != localSink.module || - localSink.ref != source.path.last._1.value) + (source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule != localSink.module || + localSink.ref != source.path.last._1.value) } def enteringChildInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match { case ReferenceTarget(_, _, _, _, TargetToken.Field(port) +: comps) - if port == localSink.ref && comps == localSink.component => true + if port == localSink.ref && comps == localSink.component => + true case _ => false } def leavingRootInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match { case ReferenceTarget(_, _, Seq(), port, comps) - if port == localSink.component.head.value && comps == localSink.component.tail => true + if port == localSink.component.head.value && comps == localSink.component.tail => + true case _ => false } - private def buildCircuitGraph(circuit: Circuit): ConnectionGraph = { val mdg = new MutableDiGraph[ReferenceTarget]() val declarations = mutable.LinkedHashMap[ModuleTarget, mutable.LinkedHashMap[ReferenceTarget, FirrtlNode]]() val circuitTarget = CircuitTarget(circuit.main) val moduleMap = circuit.modules.map { m => circuitTarget.module(m.name) -> m }.toMap - circuit map buildModule(circuitTarget) + circuit.map(buildModule(circuitTarget)) def addLabeledVertex(v: ReferenceTarget, f: FirrtlNode): Unit = { mdg.addVertex(v) @@ -386,7 +400,7 @@ object ConnectionGraph { def buildModule(c: CircuitTarget)(module: DefModule): DefModule = { val m = c.module(module.name) - module map buildPort(m) map buildStatement(m, new TokenTagger()) + module.map(buildPort(m)).map(buildStatement(m, new TokenTagger())) } def buildPort(m: ModuleTarget)(port: Port): Port = { @@ -412,7 +426,7 @@ object ConnectionGraph { (Utils.flow(instExp), Utils.flow(modExp)) match { case (SourceFlow, SinkFlow) => mdg.addPairWithEdge(it, mt) case (SinkFlow, SourceFlow) => mdg.addPairWithEdge(mt, it) - case _ => sys.error("Something went wrong...") + case _ => sys.error("Something went wrong...") } } } @@ -461,13 +475,14 @@ object ConnectionGraph { // Connect each subTarget to the corresponding init subTarget val allRegTargets = regTarget.leafSubTargets(d.tpe) val allInitTargets = initTarget.leafSubTargets(d.tpe).zip(Utils.create_exps(d.init)) - allRegTargets.zip(allInitTargets).foreach { case (r, (i, e)) => - mdg.addVertex(i) - mdg.addVertex(r) - mdg.addEdge(clockTarget, r) - mdg.addEdge(resetTarget, r) - mdg.addEdge(i, r) - buildExpression(m, tagger, i)(e) + allRegTargets.zip(allInitTargets).foreach { + case (r, (i, e)) => + mdg.addVertex(i) + mdg.addVertex(r) + mdg.addEdge(clockTarget, r) + mdg.addEdge(resetTarget, r) + mdg.addEdge(i, r) + buildExpression(m, tagger, i)(e) } } @@ -480,9 +495,10 @@ object ConnectionGraph { val sinkTarget = m.ref(d.name) addLabeledVertex(sinkTarget, stmt) val nodeTargets = sinkTarget.leafSubTargets(d.value.tpe) - nodeTargets.zip(Utils.create_exps(d.value)).foreach { case (n, e) => - mdg.addVertex(n) - buildExpression(m, tagger, n)(e) + nodeTargets.zip(Utils.create_exps(d.value)).foreach { + case (n, e) => + mdg.addVertex(n) + buildExpression(m, tagger, n)(e) } case c: Connect => @@ -512,10 +528,10 @@ object ConnectionGraph { addLabeledVertex(m.ref(d.name), d) buildMemory(m, d) - /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.passes.ExpandWhensAndCheck]]*/ + /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.passes.ExpandWhensAndCheck]] */ case _: Conditionally => sys.error("Unsupported! Only works on Middle Firrtl") - case s: Block => s map buildStatement(m, tagger) + case s: Block => s.map(buildStatement(m, tagger)) case a: Attach => val attachTargets = a.exprs.map { r => @@ -523,18 +539,25 @@ object ConnectionGraph { mdg.addVertex(at) at } - attachTargets.combinations(2).foreach { case Seq(l, r) => - mdg.addEdge(l, r) - mdg.addEdge(r, l) + attachTargets.combinations(2).foreach { + case Seq(l, r) => + mdg.addEdge(l, r) + mdg.addEdge(r, l) } case p: Print => addLabeledVertex(asTarget(m, tagger)(p), p) - case s: Stop => addLabeledVertex(asTarget(m, tagger)(s), s) + case s: Stop => addLabeledVertex(asTarget(m, tagger)(s), s) case EmptyStmt => } stmt } - def buildExpression(m: ModuleTarget, tagger: TokenTagger, sinkTarget: ReferenceTarget)(expr: Expression): Expression = { + def buildExpression( + m: ModuleTarget, + tagger: TokenTagger, + sinkTarget: ReferenceTarget + )(expr: Expression + ): Expression = { + /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.stage.Forms.Resolved]]. */ val sourceTarget = asTarget(m, tagger)(expr) mdg.addVertex(sourceTarget) @@ -542,7 +565,7 @@ object ConnectionGraph { expr match { case _: DoPrim | _: Mux | _: ValidIf | _: Literal => addLabeledVertex(sourceTarget, expr) - expr map buildExpression(m, tagger, sourceTarget) + expr.map(buildExpression(m, tagger, sourceTarget)) case _ => } expr @@ -552,7 +575,6 @@ object ConnectionGraph { } } - /** Used for obtaining a tag for a given label unnamed Target. */ class TokenTagger { private val counterMap = mutable.HashMap[String, Int]() diff --git a/src/main/scala/firrtl/analyses/IRLookup.scala b/src/main/scala/firrtl/analyses/IRLookup.scala index f9819ebd..b8528a95 100644 --- a/src/main/scala/firrtl/analyses/IRLookup.scala +++ b/src/main/scala/firrtl/analyses/IRLookup.scala @@ -6,7 +6,22 @@ import firrtl.annotations.TargetToken._ import firrtl.annotations._ import firrtl.ir._ import firrtl.passes.MemPortUtils -import firrtl.{DuplexFlow, ExpKind, Flow, InstanceKind, Kind, MemKind, PortKind, RegKind, SinkFlow, SourceFlow, UnknownFlow, Utils, WInvalid, WireKind} +import firrtl.{ + DuplexFlow, + ExpKind, + Flow, + InstanceKind, + Kind, + MemKind, + PortKind, + RegKind, + SinkFlow, + SourceFlow, + UnknownFlow, + Utils, + WInvalid, + WireKind +} import scala.collection.mutable @@ -19,26 +34,33 @@ object IRLookup { * @param declarations Maps references (not subreferences) to declarations * @param modules Maps module targets to modules */ -class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map[ReferenceTarget, FirrtlNode]], - private val modules: Map[ModuleTarget, DefModule]) { +class IRLookup private[analyses] ( + private val declarations: Map[ModuleTarget, Map[ReferenceTarget, FirrtlNode]], + private val modules: Map[ModuleTarget, DefModule]) { private val flowCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Flow]]() private val kindCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Kind]]() private val tpeCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Type]]() private val exprCache = mutable.HashMap[ModuleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]]() - private val refCache = mutable.HashMap[ModuleTarget, mutable.LinkedHashMap[Kind, mutable.ArrayBuffer[ReferenceTarget]]]() - + private val refCache = + mutable.HashMap[ModuleTarget, mutable.LinkedHashMap[Kind, mutable.ArrayBuffer[ReferenceTarget]]]() /** @example Given ~Top|MyModule/inst:Other>foo.bar, returns ~Top|Other>foo * @return the target converted to its local reference */ def asLocalRef(t: ReferenceTarget): ReferenceTarget = t.pathlessTarget.copy(component = Nil) - def flow(t: ReferenceTarget): Flow = flowCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Flow]()).getOrElseUpdate(t.pathlessTarget, Utils.flow(expr(t.pathlessTarget))) + def flow(t: ReferenceTarget): Flow = flowCache + .getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Flow]()) + .getOrElseUpdate(t.pathlessTarget, Utils.flow(expr(t.pathlessTarget))) - def kind(t: ReferenceTarget): Kind = kindCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Kind]()).getOrElseUpdate(t.pathlessTarget, Utils.kind(expr(t.pathlessTarget))) + def kind(t: ReferenceTarget): Kind = kindCache + .getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Kind]()) + .getOrElseUpdate(t.pathlessTarget, Utils.kind(expr(t.pathlessTarget))) - def tpe(t: ReferenceTarget): Type = tpeCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Type]()).getOrElseUpdate(t.pathlessTarget, expr(t.pathlessTarget).tpe) + def tpe(t: ReferenceTarget): Type = tpeCache + .getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Type]()) + .getOrElseUpdate(t.pathlessTarget, expr(t.pathlessTarget).tpe) /** get expression of the target. * It can return None for many reasons, including @@ -54,7 +76,7 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map val pathless = t.pathlessTarget inCache(pathless, flow) match { - case e@Some(_) => return e + case e @ Some(_) => return e case None => val mt = pathless.moduleTarget val emt = t.encapsulatingModuleTarget @@ -62,36 +84,50 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map declarations(emt)(asLocalRef(t)) match { case e: Expression => require(e.tpe.isInstanceOf[GroundType]) - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).getOrElseUpdate((pathless, Utils.flow(e)), e) - case d: IsDeclaration => d match { - case n: DefNode => - updateExpr(mt, Reference(n.name, n.value.tpe, ExpKind, SourceFlow)) - case p: Port => - updateExpr(mt, Reference(p.name, p.tpe, PortKind, Utils.get_flow(p))) - case w: DefInstance => - updateExpr(mt, Reference(w.name, w.tpe, InstanceKind, SourceFlow)) - case w: DefWire => - updateExpr(mt, Reference(w.name, w.tpe, WireKind, SourceFlow)) - updateExpr(mt, Reference(w.name, w.tpe, WireKind, SinkFlow)) - updateExpr(mt, Reference(w.name, w.tpe, WireKind, DuplexFlow)) - case r: DefRegister if pathless.tokens.last == Clock => - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.clock - case r: DefRegister if pathless.tokens.isDefinedAt(1) && pathless.tokens(1) == Init => - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.init - updateExpr(pathless, r.init) - case r: DefRegister if pathless.tokens.last == Reset => - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.reset - case r: DefRegister => - updateExpr(mt, Reference(r.name, r.tpe, RegKind, SourceFlow)) - updateExpr(mt, Reference(r.name, r.tpe, RegKind, SinkFlow)) - updateExpr(mt, Reference(r.name, r.tpe, RegKind, DuplexFlow)) - case m: DefMemory => - updateExpr(mt, Reference(m.name, MemPortUtils.memType(m), MemKind, SourceFlow)) - case other => - sys.error(s"Cannot call expr with: $t, given declaration $other") - } + exprCache + .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()) + .getOrElseUpdate((pathless, Utils.flow(e)), e) + case d: IsDeclaration => + d match { + case n: DefNode => + updateExpr(mt, Reference(n.name, n.value.tpe, ExpKind, SourceFlow)) + case p: Port => + updateExpr(mt, Reference(p.name, p.tpe, PortKind, Utils.get_flow(p))) + case w: DefInstance => + updateExpr(mt, Reference(w.name, w.tpe, InstanceKind, SourceFlow)) + case w: DefWire => + updateExpr(mt, Reference(w.name, w.tpe, WireKind, SourceFlow)) + updateExpr(mt, Reference(w.name, w.tpe, WireKind, SinkFlow)) + updateExpr(mt, Reference(w.name, w.tpe, WireKind, DuplexFlow)) + case r: DefRegister if pathless.tokens.last == Clock => + exprCache.getOrElseUpdate( + pathless.moduleTarget, + mutable.HashMap[(ReferenceTarget, Flow), Expression]() + )((pathless, SourceFlow)) = r.clock + case r: DefRegister if pathless.tokens.isDefinedAt(1) && pathless.tokens(1) == Init => + exprCache.getOrElseUpdate( + pathless.moduleTarget, + mutable.HashMap[(ReferenceTarget, Flow), Expression]() + )((pathless, SourceFlow)) = r.init + updateExpr(pathless, r.init) + case r: DefRegister if pathless.tokens.last == Reset => + exprCache.getOrElseUpdate( + pathless.moduleTarget, + mutable.HashMap[(ReferenceTarget, Flow), Expression]() + )((pathless, SourceFlow)) = r.reset + case r: DefRegister => + updateExpr(mt, Reference(r.name, r.tpe, RegKind, SourceFlow)) + updateExpr(mt, Reference(r.name, r.tpe, RegKind, SinkFlow)) + updateExpr(mt, Reference(r.name, r.tpe, RegKind, DuplexFlow)) + case m: DefMemory => + updateExpr(mt, Reference(m.name, MemPortUtils.memType(m), MemKind, SourceFlow)) + case other => + sys.error(s"Cannot call expr with: $t, given declaration $other") + } case _: IsInvalid => - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = WInvalid + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( + (pathless, SourceFlow) + ) = WInvalid } } } @@ -118,7 +154,8 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map * * @param moduleTarget [[firrtl.annotations.ModuleTarget]] to be queried. * @param kind [[firrtl.Kind]] to be find. - * @return all [[firrtl.annotations.ReferenceTarget]] in this node. */ + * @return all [[firrtl.annotations.ReferenceTarget]] in this node. + */ def kindFinder(moduleTarget: ModuleTarget, kind: Kind): Seq[ReferenceTarget] = { def updateRefs(kind: Kind, rt: ReferenceTarget): Unit = refCache .getOrElseUpdate(rt.moduleTarget, mutable.LinkedHashMap.empty[Kind, mutable.ArrayBuffer[ReferenceTarget]]) @@ -136,7 +173,11 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map case (rt, _: Port) => updateRefs(PortKind, rt) case _ => } - refCache.get(moduleTarget).map(_.getOrElse(kind, Seq.empty[ReferenceTarget])).getOrElse(Seq.empty[ReferenceTarget]).toSeq + refCache + .get(moduleTarget) + .map(_.getOrElse(kind, Seq.empty[ReferenceTarget])) + .getOrElse(Seq.empty[ReferenceTarget]) + .toSeq } } @@ -181,7 +222,7 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map def moduleLeafPortTargets(m: ModuleTarget): (Seq[(ReferenceTarget, Type)], Seq[(ReferenceTarget, Type)]) = modules(m).ports.flatMap { case Port(_, name, Output, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SourceFlow)) - case Port(_, name, Input, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow)) + case Port(_, name, Input, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow)) }.foldLeft((Vector.empty[(ReferenceTarget, Type)], Vector.empty[(ReferenceTarget, Type)])) { case ((inputs, outputs), e) if Utils.flow(e) == SourceFlow => (inputs, outputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe)) @@ -189,7 +230,6 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map (inputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe), outputs) } - /** @param t [[firrtl.annotations.ReferenceTarget]] to be queried. * @return whether a ReferenceTarget is contained in this IRLookup */ @@ -213,10 +253,10 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map val all = i.pathAsTargets :+ i.encapsulatingModuleTarget.instOf(i.instance, i.ofModule) all.map { x => declarations.contains(x.moduleTarget) && declarations(x.moduleTarget).contains(x.asReference) && - (declarations(x.moduleTarget)(x.asReference) match { - case DefInstance(_, _, of, _) if of == x.ofModule => validPath(x.ofModuleTarget) - case _ => false - }) + (declarations(x.moduleTarget)(x.asReference) match { + case DefInstance(_, _, of, _) if of == x.ofModule => validPath(x.ofModuleTarget) + case _ => false + }) }.reduce(_ && _) } } @@ -248,17 +288,54 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map /** Optionally returns the expression corresponding to the target if contained in the expression cache. */ private def inCache(pathless: ReferenceTarget, flow: Flow): Option[Expression] = { - (flow, - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SourceFlow)), - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SinkFlow)), - exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains(pathless, DuplexFlow) + ( + flow, + exprCache + .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()) + .contains((pathless, SourceFlow)), + exprCache + .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()) + .contains((pathless, SinkFlow)), + exprCache + .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()) + .contains(pathless, DuplexFlow) ) match { - case (SourceFlow, true, _, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow))) - case (SinkFlow, _, true, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow))) - case (DuplexFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow))) - case (UnknownFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow))) - case (UnknownFlow, true, false, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow))) - case (UnknownFlow, false, true, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SinkFlow))) + case (SourceFlow, true, _, _) => + Some( + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( + (pathless, flow) + ) + ) + case (SinkFlow, _, true, _) => + Some( + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( + (pathless, flow) + ) + ) + case (DuplexFlow, _, _, true) => + Some( + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( + (pathless, DuplexFlow) + ) + ) + case (UnknownFlow, _, _, true) => + Some( + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( + (pathless, DuplexFlow) + ) + ) + case (UnknownFlow, true, false, false) => + Some( + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( + (pathless, SourceFlow) + ) + ) + case (UnknownFlow, false, true, false) => + Some( + exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())( + (pathless, SinkFlow) + ) + ) case _ => None } } diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala index f994b39a..4aab9a3a 100644 --- a/src/main/scala/firrtl/analyses/InstanceGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala @@ -10,7 +10,6 @@ import firrtl.Utils._ import firrtl.traversals.Foreachers._ import firrtl.annotations.TargetToken._ - /** A class representing the instance hierarchy of a working IR Circuit * * @constructor constructs an instance graph from a Circuit @@ -29,7 +28,7 @@ import firrtl.annotations.TargetToken._ class InstanceGraph(c: Circuit) { @deprecated("Use InstanceKeyGraph.moduleMap instead.", "FIRRTL 1.4") - val moduleMap = c.modules.map({m => (m.name,m) }).toMap + val moduleMap = c.modules.map({ m => (m.name, m) }).toMap private val instantiated = new mutable.LinkedHashSet[String] private val childInstances = new mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]] @@ -43,7 +42,7 @@ class InstanceGraph(c: Circuit) { private val instanceQueue = new mutable.Queue[DefInstance] for (subTop <- c.modules.view.map(_.name).filterNot(instantiated)) { - val topInstance = DefInstance(subTop,subTop) + val topInstance = DefInstance(subTop, subTop) instanceQueue.enqueue(topInstance) while (instanceQueue.nonEmpty) { val current = instanceQueue.dequeue @@ -53,7 +52,7 @@ class InstanceGraph(c: Circuit) { instanceQueue.enqueue(child) instanceGraph.addVertex(child) } - instanceGraph.addEdge(current,child) + instanceGraph.addEdge(current, child) } } } @@ -73,7 +72,7 @@ class InstanceGraph(c: Circuit) { * of all module instances in the Circuit. */ @deprecated("Use InstanceKeyGraph.fullHierarchy instead.", "FIRRTL 1.4") - lazy val fullHierarchy: mutable.LinkedHashMap[DefInstance,Seq[Seq[DefInstance]]] = graph.pathsInDAG(trueTopInstance) + lazy val fullHierarchy: mutable.LinkedHashMap[DefInstance, Seq[Seq[DefInstance]]] = graph.pathsInDAG(trueTopInstance) /** A count of the *static* number of instances of each module. For any module other than the top (main) module, this is * equivalent to the number of inst statements in the circuit instantiating each module, irrespective of the number @@ -85,7 +84,7 @@ class InstanceGraph(c: Circuit) { lazy val staticInstanceCount: Map[OfModule, Int] = { val foo = mutable.LinkedHashMap.empty[OfModule, Int] childInstances.keys.foreach { - case main if main == c.main => foo += main.OfModule -> 1 + case main if main == c.main => foo += main.OfModule -> 1 case other => foo += other.OfModule -> 0 } childInstances.values.flatten.map(_.OfModule).foreach { @@ -106,7 +105,7 @@ class InstanceGraph(c: Circuit) { @deprecated("Use InstanceKeyGraph.findInstancesInHierarchy instead (now with caching of vertices!).", "FIRRTL 1.4") def findInstancesInHierarchy(module: String): Seq[Seq[DefInstance]] = { val instances = graph.getVertices.filter(_.module == module).toSeq - instances flatMap { i => fullHierarchy.getOrElse(i, Nil) } + instances.flatMap { i => fullHierarchy.getOrElse(i, Nil) } } /** An [[firrtl.graph.EulerTour EulerTour]] representation of the [[firrtl.graph.DiGraph DiGraph]] */ @@ -117,8 +116,7 @@ class InstanceGraph(c: Circuit) { * a design */ @deprecated("Use InstanceKeyGraph and EulerTour(iGraph.graph, iGraph.top).rmq(moduleA, moduleB).", "FIRRTL 1.4") - def lowestCommonAncestor(moduleA: Seq[DefInstance], - moduleB: Seq[DefInstance]): Seq[DefInstance] = { + def lowestCommonAncestor(moduleA: Seq[DefInstance], moduleB: Seq[DefInstance]): Seq[DefInstance] = { tour.rmq(moduleA, moduleB) } @@ -131,10 +129,9 @@ class InstanceGraph(c: Circuit) { graph.transformNodes(_.module).linearize.map(moduleMap(_)) } - /** Given a circuit, returns a map from module name to children - * instance/module definitions - */ + * instance/module definitions + */ @deprecated("Use InstanceKeyGraph.getChildInstances instead.", "FIRRTL 1.4") def getChildrenInstances: mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]] = childInstances @@ -172,7 +169,7 @@ class InstanceGraph(c: Circuit) { /** The set of all modules *not* reachable in the circuit */ @deprecated("Use InstanceKeyGraph.unreachableModules instead.", "FIRRTL 1.4") - lazy val unreachableModules: collection.Set[OfModule] = modules diff reachableModules + lazy val unreachableModules: collection.Set[OfModule] = modules.diff(reachableModules) } @@ -186,10 +183,9 @@ object InstanceGraph { * @return */ @deprecated("Use InstanceKeyGraph.collectInstances instead.", "FIRRTL 1.4") - def collectInstances(insts: mutable.Set[DefInstance]) - (s: Statement): Unit = s match { - case i: DefInstance => insts += i - case i: DefInstance => throwInternalError("Expecting DefInstance, found a DefInstance!") + def collectInstances(insts: mutable.Set[DefInstance])(s: Statement): Unit = s match { + case i: DefInstance => insts += i + case i: DefInstance => throwInternalError("Expecting DefInstance, found a DefInstance!") case i: WDefInstanceConnector => throwInternalError("Expecting DefInstance, found a DefInstanceConnector!") case _ => s.foreach(collectInstances(insts)) } diff --git a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala index 761315dc..5354888d 100644 --- a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala +++ b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala @@ -14,10 +14,10 @@ import scala.collection.mutable * pairs of InstanceName and Module name as vertex keys instead of using WDefInstance * which will hash the instance type causing some performance issues. */ -class InstanceKeyGraph private(c: ir.Circuit) { +class InstanceKeyGraph private (c: ir.Circuit) { import InstanceKeyGraph._ - private val nameToModule: Map[String, ir.DefModule] = c.modules.map({m => (m.name,m) }).toMap + private val nameToModule: Map[String, ir.DefModule] = c.modules.map({ m => (m.name, m) }).toMap private val childInstances: Seq[(String, Seq[InstanceKey])] = c.modules.map { m => m.name -> InstanceKeyGraph.collectInstances(m) } @@ -37,8 +37,8 @@ class InstanceKeyGraph private(c: ir.Circuit) { circuitTopInstance.OfModule +: internalGraph.reachableFrom(circuitTopInstance).toSeq.map(_.OfModule) private lazy val cachedUnreachableModules: Seq[OfModule] = { - val all = mutable.LinkedHashSet(childInstances.map(c => OfModule(c._1)):_*) - val reachable = mutable.LinkedHashSet(cachedReachableModules:_*) + val all = mutable.LinkedHashSet(childInstances.map(c => OfModule(c._1)): _*) + val reachable = mutable.LinkedHashSet(cachedReachableModules: _*) all.diff(reachable).toSeq } @@ -68,11 +68,11 @@ class InstanceKeyGraph private(c: ir.Circuit) { private lazy val cachedStaticInstanceCount = { val foo = mutable.LinkedHashMap.empty[OfModule, Int] childInstances.foreach { - case (main, _) if main == c.main => foo += main.OfModule -> 1 + case (main, _) if main == c.main => foo += main.OfModule -> 1 case (other, _) => foo += other.OfModule -> 0 } - childInstances.flatMap(_._2).map(_.OfModule).foreach { - mod => foo += mod -> (foo(mod) + 1) + childInstances.flatMap(_._2).map(_.OfModule).foreach { mod => + foo += mod -> (foo(mod) + 1) } foo.toMap } @@ -88,17 +88,18 @@ class InstanceKeyGraph private(c: ir.Circuit) { */ def findInstancesInHierarchy(module: String): Seq[Seq[InstanceKey]] = { val instances = vertices.filter(_.module == module).toSeq - instances.flatMap{ i => cachedFullHierarchy.getOrElse(i, Nil) } + instances.flatMap { i => cachedFullHierarchy.getOrElse(i, Nil) } } /** Given a circuit, returns a map from module name to a map * in turn mapping instances names to corresponding module names */ def getChildInstanceMap: mutable.LinkedHashMap[OfModule, mutable.LinkedHashMap[Instance, OfModule]] = - mutable.LinkedHashMap(childInstances.map { case (k, v) => - val moduleMap: mutable.LinkedHashMap[Instance, OfModule] = mutable.LinkedHashMap(v.map(_.toTokens):_*) - TargetToken.OfModule(k) -> moduleMap - }:_*) + mutable.LinkedHashMap(childInstances.map { + case (k, v) => + val moduleMap: mutable.LinkedHashMap[Instance, OfModule] = mutable.LinkedHashMap(v.map(_.toTokens): _*) + TargetToken.OfModule(k) -> moduleMap + }: _*) /** All modules in the circuit reachable from the top module */ def reachableModules: Seq[OfModule] = cachedReachableModules @@ -110,7 +111,6 @@ class InstanceKeyGraph private(c: ir.Circuit) { def fullHierarchy: mutable.LinkedHashMap[InstanceKey, Seq[Seq[InstanceKey]]] = cachedFullHierarchy } - object InstanceKeyGraph { def apply(c: ir.Circuit): InstanceKeyGraph = new InstanceKeyGraph(c) @@ -126,12 +126,12 @@ object InstanceKeyGraph { /** Finds all instance definitions in a firrtl Module. */ def collectInstances(m: ir.DefModule): Seq[InstanceKey] = m match { - case _ : ir.ExtModule => Seq() + case _: ir.ExtModule => Seq() case ir.Module(_, _, _, body) => { val instances = mutable.ArrayBuffer[InstanceKey]() def onStmt(s: ir.Statement): Unit = s match { case firrtl.WDefInstance(_, name, module, _) => instances += InstanceKey(name, module) - case ir.DefInstance(_, name, module, _) => instances += InstanceKey(name, module) + case ir.DefInstance(_, name, module, _) => instances += InstanceKey(name, module) case _: firrtl.WDefInstanceConnector => firrtl.Utils.throwInternalError("Expecting WDefInstance, found a WDefInstanceConnector!") case other => other.foreachStmt(onStmt) @@ -143,8 +143,10 @@ object InstanceKeyGraph { private def topKey(module: String): InstanceKey = InstanceKey(module, module) - private def buildGraph(childInstances: Seq[(String, Seq[InstanceKey])], roots: Iterable[String]): - DiGraph[InstanceKey] = { + private def buildGraph( + childInstances: Seq[(String, Seq[InstanceKey])], + roots: Iterable[String] + ): DiGraph[InstanceKey] = { val instanceGraph = new MutableDiGraph[InstanceKey] val childInstanceMap = childInstances.toMap diff --git a/src/main/scala/firrtl/analyses/NodeCount.scala b/src/main/scala/firrtl/analyses/NodeCount.scala index 0276f4f5..63571503 100644 --- a/src/main/scala/firrtl/analyses/NodeCount.scala +++ b/src/main/scala/firrtl/analyses/NodeCount.scala @@ -21,18 +21,19 @@ class NodeCount private (node: FirrtlNode) { @tailrec private final def rec(xs: List[Any]): Unit = - if (xs.isEmpty) { } - else { + if (xs.isEmpty) {} else { val node = xs.head - require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], - "Unexpected FirrtlNode that does not implement Product!") + require( + node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], + "Unexpected FirrtlNode that does not implement Product!" + ) val moreToVisit = if (identityMap.containsKey(node)) List.empty else { // Haven't seen yet identityMap.put(node, true) regularSet += node node match { // FirrtlNodes are Products - case p: Product => p.productIterator + case p: Product => p.productIterator case i: Iterable[Any] => i case _ => List.empty } diff --git a/src/main/scala/firrtl/analyses/SymbolTable.scala b/src/main/scala/firrtl/analyses/SymbolTable.scala index 53ad1614..36549160 100644 --- a/src/main/scala/firrtl/analyses/SymbolTable.scala +++ b/src/main/scala/firrtl/analyses/SymbolTable.scala @@ -17,26 +17,27 @@ import scala.collection.mutable * Different implementations of SymbolTable might want to store different * information (e.g., only the names without the types) or build * different indices depending on what information the transform needs. - * */ + */ trait SymbolTable { // methods that need to be implemented by any Symbol table - def declare(name: String, tpe: Type, kind: Kind): Unit + def declare(name: String, tpe: Type, kind: Kind): Unit def declareInstance(name: String, module: String): Unit // convenience methods def declare(d: DefInstance): Unit = declareInstance(d.name, d.module) - def declare(d: DefMemory): Unit = declare(d.name, MemPortUtils.memType(d), firrtl.MemKind) - def declare(d: DefNode): Unit = declare(d.name, d.value.tpe, firrtl.NodeKind) - def declare(d: DefWire): Unit = declare(d.name, d.tpe, firrtl.WireKind) + def declare(d: DefMemory): Unit = declare(d.name, MemPortUtils.memType(d), firrtl.MemKind) + def declare(d: DefNode): Unit = declare(d.name, d.value.tpe, firrtl.NodeKind) + def declare(d: DefWire): Unit = declare(d.name, d.tpe, firrtl.WireKind) def declare(d: DefRegister): Unit = declare(d.name, d.tpe, firrtl.RegKind) - def declare(d: Port): Unit = declare(d.name, d.tpe, firrtl.PortKind) + def declare(d: Port): Unit = declare(d.name, d.tpe, firrtl.PortKind) } /** Trusts the type annotation on DefInstance nodes instead of re-deriving the type from - * the module ports which would require global (cross-module) information. */ + * the module ports which would require global (cross-module) information. + */ private[firrtl] abstract class LocalSymbolTable extends SymbolTable { def declareInstance(name: String, module: String): Unit = declare(name, UnknownType, InstanceKind) - override def declare(d: WDefInstance): Unit = declare(d.name, d.tpe, InstanceKind) + override def declare(d: WDefInstance): Unit = declare(d.name, d.tpe, InstanceKind) } /** Uses a function to derive instance types from module names */ @@ -63,10 +64,10 @@ private[firrtl] trait WithMap extends SymbolTable { } private case class Sym(name: String, tpe: Type, kind: Kind) extends Symbol -private[firrtl] trait Symbol { def name: String; def tpe: Type; def kind: Kind } +private[firrtl] trait Symbol { def name: String; def tpe: Type; def kind: Kind } /** only remembers the names of symbols */ -private[firrtl] class NamespaceTable extends LocalSymbolTable { +private[firrtl] class NamespaceTable extends LocalSymbolTable { private var names = List[String]() override def declare(name: String, tpe: Type, kind: Kind): Unit = names = name :: names def getNames: Seq[String] = names @@ -82,9 +83,9 @@ object SymbolTable { } private def scanStatement(s: Statement)(implicit table: SymbolTable): Unit = s match { case d: DefInstance => table.declare(d) - case d: DefMemory => table.declare(d) - case d: DefNode => table.declare(d) - case d: DefWire => table.declare(d) + case d: DefMemory => table.declare(d) + case d: DefNode => table.declare(d) + case d: DefWire => table.declare(d) case d: DefRegister => table.declare(d) case other => other.foreachStmt(scanStatement) } diff --git a/src/main/scala/firrtl/annotations/Annotation.scala b/src/main/scala/firrtl/annotations/Annotation.scala index a382f685..16f85e67 100644 --- a/src/main/scala/firrtl/annotations/Annotation.scala +++ b/src/main/scala/firrtl/annotations/Annotation.scala @@ -5,7 +5,6 @@ package annotations import firrtl.options.StageUtils - case class AnnotationException(message: String) extends Exception(message) /** Base type of auxiliary information */ @@ -26,8 +25,8 @@ trait Annotation extends Product { */ private def extractComponents(ls: scala.collection.Traversable[_]): Seq[Target] = { ls.collect { - case c: Target => Seq(c) - case o: Product => extractComponents(o.productIterator.toIterable) + case c: Target => Seq(c) + case o: Product => extractComponents(o.productIterator.toIterable) case x: scala.collection.Traversable[_] => extractComponents(x) }.foldRight(Seq.empty[Target])((seq, c) => c ++ seq) } @@ -62,52 +61,54 @@ trait SingleTargetAnnotation[T <: Named] extends Annotation { x.map(newTargets => newTargets.map(t => duplicate(t.asInstanceOf[T]))).getOrElse(List(this)) case from: Named => val ret = renames.get(Target.convertNamed2Target(target)) - ret.map(_.map { newT => - val result = newT match { - case c: InstanceTarget => ModuleName(c.ofModule, CircuitName(c.circuit)) - case c: IsMember => - val local = Target.referringModule(c) - c.setPathTarget(local) - case c: CircuitTarget => c.toNamed - case other => throw Target.NamedException(s"Cannot convert $other to [[Named]]") - } - Target.convertTarget2Named(result) match { - case newTarget: T @unchecked => - try { - duplicate(newTarget) - } - catch { - case _: java.lang.ClassCastException => - val msg = s"${this.getClass.getName} target ${target.getClass.getName} " + - s"cannot be renamed to ${newTarget.getClass}" - throw AnnotationException(msg) - } - } - }).getOrElse(List(this)) + ret + .map(_.map { newT => + val result = newT match { + case c: InstanceTarget => ModuleName(c.ofModule, CircuitName(c.circuit)) + case c: IsMember => + val local = Target.referringModule(c) + c.setPathTarget(local) + case c: CircuitTarget => c.toNamed + case other => throw Target.NamedException(s"Cannot convert $other to [[Named]]") + } + Target.convertTarget2Named(result) match { + case newTarget: T @unchecked => + try { + duplicate(newTarget) + } catch { + case _: java.lang.ClassCastException => + val msg = s"${this.getClass.getName} target ${target.getClass.getName} " + + s"cannot be renamed to ${newTarget.getClass}" + throw AnnotationException(msg) + } + } + }) + .getOrElse(List(this)) } } } /** [[MultiTargetAnnotation]] keeps the renamed targets grouped within a single annotation. */ trait MultiTargetAnnotation extends Annotation { + /** Contains a sequence of [[firrtl.annotations.Target Target]]. * When created, [[targets]] should be assigned by `Seq(Seq(TargetA), Seq(TargetB), Seq(TargetC))` */ val targets: Seq[Seq[Target]] - /** Create another instance of this Annotation*/ + /** Create another instance of this Annotation */ def duplicate(n: Seq[Seq[Target]]): Annotation /** Assume [[RenameMap]] is `Map(TargetA -> Seq(TargetA1, TargetA2, TargetA3), TargetB -> Seq(TargetB1, TargetB2))` * in the update, this Annotation is still one annotation, but the contents are renamed in the below form * Seq(Seq(TargetA1, TargetA2, TargetA3), Seq(TargetB1, TargetB2), Seq(TargetC)) - **/ + */ def update(renames: RenameMap): Seq[Annotation] = Seq(duplicate(targets.map(ts => ts.flatMap(renames(_))))) private def crossJoin[T](list: Seq[Seq[T]]): Seq[Seq[T]] = list match { - case Nil => Nil - case x :: Nil => x map (Seq(_)) + case Nil => Nil + case x :: Nil => x.map(Seq(_)) case x :: xs => val xsJoin = crossJoin(xs) for { @@ -123,7 +124,7 @@ trait MultiTargetAnnotation extends Annotation { * Seq(Seq(TargetA1), Seq(TargetB1), Seq(TargetC)); Seq(Seq(TargetA1), Seq(TargetB2), Seq(TargetC)) * Seq(Seq(TargetA2), Seq(TargetB1), Seq(TargetC)); Seq(Seq(TargetA2), Seq(TargetB2), Seq(TargetC)) * Seq(Seq(TargetA3), Seq(TargetB1), Seq(TargetC)); Seq(Seq(TargetA3), Seq(TargetB2), Seq(TargetC)) - * */ + */ def flat(): AnnotationSeq = crossJoin(targets).map(r => duplicate(r.map(Seq(_)))) } diff --git a/src/main/scala/firrtl/annotations/AnnotationUtils.scala b/src/main/scala/firrtl/annotations/AnnotationUtils.scala index 58cc0097..a1276e0e 100644 --- a/src/main/scala/firrtl/annotations/AnnotationUtils.scala +++ b/src/main/scala/firrtl/annotations/AnnotationUtils.scala @@ -8,14 +8,16 @@ import java.io.File import firrtl.ir._ case class InvalidAnnotationFileException(file: File, cause: FirrtlUserException = null) - extends FirrtlUserException(s"$file", cause) + extends FirrtlUserException(s"$file", cause) case class InvalidAnnotationJSONException(msg: String) extends FirrtlUserException(msg) -case class AnnotationFileNotFoundException(file: File) extends FirrtlUserException( - s"Annotation file $file not found!" -) -case class AnnotationClassNotFoundException(className: String) extends FirrtlUserException( - s"Annotation class $className not found! Please check spelling and classpath" -) +case class AnnotationFileNotFoundException(file: File) + extends FirrtlUserException( + s"Annotation file $file not found!" + ) +case class AnnotationClassNotFoundException(className: String) + extends FirrtlUserException( + s"Annotation class $className not found! Please check spelling and classpath" + ) object AnnotationUtils { @@ -23,33 +25,33 @@ object AnnotationUtils { val SerializedModuleName = """([a-zA-Z_][a-zA-Z_0-9~!@#$%^*\-+=?/]*)""".r def validModuleName(s: String): Boolean = s match { case SerializedModuleName(name) => true - case _ => false + case _ => false } /** Returns true if a valid component/subcomponent name */ val SerializedComponentName = """([a-zA-Z_][a-zA-Z_0-9\[\]\.~!@#$%^*\-+=?/]*)""".r def validComponentName(s: String): Boolean = s match { case SerializedComponentName(name) => true - case _ => false + case _ => false } /** Tokenizes a string with '[', ']', '.' as tokens, e.g.: - * "foo.bar[boo.far]" becomes Seq("foo" "." "bar" "[" "boo" "." "far" "]") - */ + * "foo.bar[boo.far]" becomes Seq("foo" "." "bar" "[" "boo" "." "far" "]") + */ def tokenize(s: String): Seq[String] = s.find(c => "[].".contains(c)) match { case Some(_) => val i = s.indexWhere(c => "[].".contains(c)) s.slice(0, i) match { case "" => s(i).toString +: tokenize(s.drop(i + 1)) - case x => x +: s(i).toString +: tokenize(s.drop(i + 1)) + case x => x +: s(i).toString +: tokenize(s.drop(i + 1)) } case None if s == "" => Nil - case None => Seq(s) + case None => Seq(s) } def toNamed(s: String): Named = s.split("\\.", 3) match { - case Array(n) => CircuitName(n) - case Array(c, m) => ModuleName(m, CircuitName(c)) + case Array(n) => CircuitName(n) + case Array(c, m) => ModuleName(m, CircuitName(c)) case Array(c, m, x) => ComponentName(x, ModuleName(m, CircuitName(c))) } @@ -60,38 +62,39 @@ object AnnotationUtils { def toSubComponents(s: String): Seq[TargetToken] = { import TargetToken._ def exp2subcomp(e: ir.Expression): Seq[TargetToken] = e match { - case ir.Reference(name, _, _, _) => Seq(Ref(name)) + case ir.Reference(name, _, _, _) => Seq(Ref(name)) case ir.SubField(expr, name, _, _) => exp2subcomp(expr) :+ Field(name) case ir.SubIndex(expr, idx, _, _) => exp2subcomp(expr) :+ Index(idx) - case ir.SubAccess(expr, idx, _, _) => Utils.throwInternalError(s"For string $s, cannot convert a subaccess $e into a Target") + case ir.SubAccess(expr, idx, _, _) => + Utils.throwInternalError(s"For string $s, cannot convert a subaccess $e into a Target") } exp2subcomp(toExp(s)) } - /** Given a serialized component/subcomponent reference, subindex, subaccess, - * or subfield, return the corresponding IR expression. - * E.g. "foo.bar" becomes SubField(Reference("foo", UnknownType), "bar", UnknownType) - */ + * or subfield, return the corresponding IR expression. + * E.g. "foo.bar" becomes SubField(Reference("foo", UnknownType), "bar", UnknownType) + */ def toExp(s: String): Expression = { def parse(tokens: Seq[String]): Expression = { val DecPattern = """(\d+)""".r def findClose(tokens: Seq[String], index: Int, nOpen: Int): Seq[String] = { - if(index >= tokens.size) { + if (index >= tokens.size) { Utils.error("Cannot find closing bracket ]") - } else tokens(index) match { - case "[" => findClose(tokens, index + 1, nOpen + 1) - case "]" if nOpen == 1 => tokens.slice(1, index) - case "]" => findClose(tokens, index + 1, nOpen - 1) - case _ => findClose(tokens, index + 1, nOpen) - } + } else + tokens(index) match { + case "[" => findClose(tokens, index + 1, nOpen + 1) + case "]" if nOpen == 1 => tokens.slice(1, index) + case "]" => findClose(tokens, index + 1, nOpen - 1) + case _ => findClose(tokens, index + 1, nOpen) + } } def buildup(e: Expression, tokens: Seq[String]): Expression = tokens match { case "[" :: tail => val indexOrAccess = findClose(tokens, 0, 0) val exp = indexOrAccess.head match { case DecPattern(d) => SubIndex(e, d.toInt, UnknownType) - case _ => SubAccess(e, parse(indexOrAccess), UnknownType) + case _ => SubAccess(e, parse(indexOrAccess), UnknownType) } buildup(exp, tokens.drop(2 + indexOrAccess.size)) case "." :: tail => @@ -101,7 +104,7 @@ object AnnotationUtils { val root = Reference(tokens.head, UnknownType) buildup(root, tokens.tail) } - if(validComponentName(s)) { + if (validComponentName(s)) { parse(tokenize(s)) } else { Utils.error(s"Cannot convert $s into an expression.") diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala index 941bf003..0ef8b020 100644 --- a/src/main/scala/firrtl/annotations/JsonProtocol.scala +++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala @@ -5,7 +5,7 @@ package annotations import firrtl.ir._ -import scala.util.{Try, Failure} +import scala.util.{Failure, Try} import org.json4s._ import org.json4s.native.JsonMethods._ @@ -20,112 +20,189 @@ trait HasSerializationHints { } object JsonProtocol { - class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => ( - { case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] }, - { case x: Class[_] => JString(x.getName) } - )) + class TransformClassSerializer + extends CustomSerializer[Class[_ <: Transform]](format => + ( + { case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] }, + { case x: Class[_] => JString(x.getName) } + ) + ) // TODO Reduce boilerplate? - class NamedSerializer extends CustomSerializer[Named](format => ( - { case JString(s) => AnnotationUtils.toNamed(s) }, - { case named: Named => JString(named.serialize) } - )) - class CircuitNameSerializer extends CustomSerializer[CircuitName](format => ( - { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[CircuitName] }, - { case named: CircuitName => JString(named.serialize) } - )) - class ModuleNameSerializer extends CustomSerializer[ModuleName](format => ( - { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ModuleName] }, - { case named: ModuleName => JString(named.serialize) } - )) - class ComponentNameSerializer extends CustomSerializer[ComponentName](format => ( - { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ComponentName] }, - { case named: ComponentName => JString(named.serialize) } - )) - class TransformSerializer extends CustomSerializer[Transform](format => ( - { case JString(s) => - try { - Class.forName(s).asInstanceOf[Class[_ <: Transform]].newInstance() - } catch { - case e: java.lang.InstantiationException => throw new FirrtlInternalException( - "NoSuchMethodException during construction of serialized Transform. Is your Transform an inner class?", e) - case t: Throwable => throw t - }}, - { case x: Transform => JString(x.getClass.getName) } - )) - class LoadMemoryFileTypeSerializer extends CustomSerializer[MemoryLoadFileType](format => ( - { case JString(s) => MemoryLoadFileType.deserialize(s) }, - { case named: MemoryLoadFileType => JString(named.serialize) } - )) + class NamedSerializer + extends CustomSerializer[Named](format => + ( + { case JString(s) => AnnotationUtils.toNamed(s) }, + { case named: Named => JString(named.serialize) } + ) + ) + class CircuitNameSerializer + extends CustomSerializer[CircuitName](format => + ( + { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[CircuitName] }, + { case named: CircuitName => JString(named.serialize) } + ) + ) + class ModuleNameSerializer + extends CustomSerializer[ModuleName](format => + ( + { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ModuleName] }, + { case named: ModuleName => JString(named.serialize) } + ) + ) + class ComponentNameSerializer + extends CustomSerializer[ComponentName](format => + ( + { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ComponentName] }, + { case named: ComponentName => JString(named.serialize) } + ) + ) + class TransformSerializer + extends CustomSerializer[Transform](format => + ( + { + case JString(s) => + try { + Class.forName(s).asInstanceOf[Class[_ <: Transform]].newInstance() + } catch { + case e: java.lang.InstantiationException => + throw new FirrtlInternalException( + "NoSuchMethodException during construction of serialized Transform. Is your Transform an inner class?", + e + ) + case t: Throwable => throw t + } + }, + { case x: Transform => JString(x.getClass.getName) } + ) + ) + class LoadMemoryFileTypeSerializer + extends CustomSerializer[MemoryLoadFileType](format => + ( + { case JString(s) => MemoryLoadFileType.deserialize(s) }, + { case named: MemoryLoadFileType => JString(named.serialize) } + ) + ) - class TargetSerializer extends CustomSerializer[Target](format => ( - { case JString(s) => Target.deserialize(s) }, - { case named: Target => JString(named.serialize) } - )) - class GenericTargetSerializer extends CustomSerializer[GenericTarget](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[GenericTarget] }, - { case named: GenericTarget => JString(named.serialize) } - )) - class CircuitTargetSerializer extends CustomSerializer[CircuitTarget](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[CircuitTarget] }, - { case named: CircuitTarget => JString(named.serialize) } - )) - class ModuleTargetSerializer extends CustomSerializer[ModuleTarget](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[ModuleTarget] }, - { case named: ModuleTarget => JString(named.serialize) } - )) - class InstanceTargetSerializer extends CustomSerializer[InstanceTarget](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[InstanceTarget] }, - { case named: InstanceTarget => JString(named.serialize) } - )) - class ReferenceTargetSerializer extends CustomSerializer[ReferenceTarget](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[ReferenceTarget] }, - { case named: ReferenceTarget => JString(named.serialize) } - )) - class IsModuleSerializer extends CustomSerializer[IsModule](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[IsModule] }, - { case named: IsModule => JString(named.serialize) } - )) - class IsMemberSerializer extends CustomSerializer[IsMember](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[IsMember] }, - { case named: IsMember => JString(named.serialize) } - )) - class CompleteTargetSerializer extends CustomSerializer[CompleteTarget](format => ( - { case JString(s) => Target.deserialize(s).asInstanceOf[CompleteTarget] }, - { case named: CompleteTarget => JString(named.serialize) } - )) + class TargetSerializer + extends CustomSerializer[Target](format => + ( + { case JString(s) => Target.deserialize(s) }, + { case named: Target => JString(named.serialize) } + ) + ) + class GenericTargetSerializer + extends CustomSerializer[GenericTarget](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[GenericTarget] }, + { case named: GenericTarget => JString(named.serialize) } + ) + ) + class CircuitTargetSerializer + extends CustomSerializer[CircuitTarget](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[CircuitTarget] }, + { case named: CircuitTarget => JString(named.serialize) } + ) + ) + class ModuleTargetSerializer + extends CustomSerializer[ModuleTarget](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[ModuleTarget] }, + { case named: ModuleTarget => JString(named.serialize) } + ) + ) + class InstanceTargetSerializer + extends CustomSerializer[InstanceTarget](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[InstanceTarget] }, + { case named: InstanceTarget => JString(named.serialize) } + ) + ) + class ReferenceTargetSerializer + extends CustomSerializer[ReferenceTarget](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[ReferenceTarget] }, + { case named: ReferenceTarget => JString(named.serialize) } + ) + ) + class IsModuleSerializer + extends CustomSerializer[IsModule](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[IsModule] }, + { case named: IsModule => JString(named.serialize) } + ) + ) + class IsMemberSerializer + extends CustomSerializer[IsMember](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[IsMember] }, + { case named: IsMember => JString(named.serialize) } + ) + ) + class CompleteTargetSerializer + extends CustomSerializer[CompleteTarget](format => + ( + { case JString(s) => Target.deserialize(s).asInstanceOf[CompleteTarget] }, + { case named: CompleteTarget => JString(named.serialize) } + ) + ) // FIRRTL Serializers - class TypeSerializer extends CustomSerializer[Type](format => ( - { case JString(s) => Parser.parseType(s) }, - { case tpe: Type => JString(tpe.serialize) } - )) - class ExpressionSerializer extends CustomSerializer[Expression](format => ( - { case JString(s) => Parser.parseExpression(s) }, - { case expr: Expression => JString(expr.serialize) } - )) - class StatementSerializer extends CustomSerializer[Statement](format => ( - { case JString(s) => Parser.parseStatement(s) }, - { case statement: Statement => JString(statement.serialize) } - )) - class PortSerializer extends CustomSerializer[Port](format => ( - { case JString(s) => Parser.parsePort(s) }, - { case port: Port => JString(port.serialize) } - )) - class DefModuleSerializer extends CustomSerializer[DefModule](format => ( - { case JString(s) => Parser.parseDefModule(s) }, - { case mod: DefModule => JString(mod.serialize) } - )) - class CircuitSerializer extends CustomSerializer[Circuit](format => ( - { case JString(s) => Parser.parse(s) }, - { case cir: Circuit => JString(cir.serialize) } - )) - class InfoSerializer extends CustomSerializer[Info](format => ( - { case JString(s) => Parser.parseInfo(s) }, - { case info: Info => JString(info.serialize) } - )) - class GroundTypeSerializer extends CustomSerializer[GroundType](format => ( - { case JString(s) => Parser.parseType(s).asInstanceOf[GroundType] }, - { case tpe: GroundType => JString(tpe.serialize) } - )) + class TypeSerializer + extends CustomSerializer[Type](format => + ( + { case JString(s) => Parser.parseType(s) }, + { case tpe: Type => JString(tpe.serialize) } + ) + ) + class ExpressionSerializer + extends CustomSerializer[Expression](format => + ( + { case JString(s) => Parser.parseExpression(s) }, + { case expr: Expression => JString(expr.serialize) } + ) + ) + class StatementSerializer + extends CustomSerializer[Statement](format => + ( + { case JString(s) => Parser.parseStatement(s) }, + { case statement: Statement => JString(statement.serialize) } + ) + ) + class PortSerializer + extends CustomSerializer[Port](format => + ( + { case JString(s) => Parser.parsePort(s) }, + { case port: Port => JString(port.serialize) } + ) + ) + class DefModuleSerializer + extends CustomSerializer[DefModule](format => + ( + { case JString(s) => Parser.parseDefModule(s) }, + { case mod: DefModule => JString(mod.serialize) } + ) + ) + class CircuitSerializer + extends CustomSerializer[Circuit](format => + ( + { case JString(s) => Parser.parse(s) }, + { case cir: Circuit => JString(cir.serialize) } + ) + ) + class InfoSerializer + extends CustomSerializer[Info](format => + ( + { case JString(s) => Parser.parseInfo(s) }, + { case info: Info => JString(info.serialize) } + ) + ) + class GroundTypeSerializer + extends CustomSerializer[GroundType](format => + ( + { case JString(s) => Parser.parseType(s).asInstanceOf[GroundType] }, + { case tpe: GroundType => JString(tpe.serialize) } + ) + ) /** Construct Json formatter for annotations */ def jsonFormat(tags: Seq[Class[_]]) = { @@ -133,7 +210,7 @@ object JsonProtocol { new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer + new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer + new GenericTargetSerializer + new CircuitTargetSerializer + new ModuleTargetSerializer + - new InstanceTargetSerializer + new ReferenceTargetSerializer + new TransformSerializer + + new InstanceTargetSerializer + new ReferenceTargetSerializer + new TransformSerializer + new LoadMemoryFileTypeSerializer + new IsModuleSerializer + new IsMemberSerializer + new CompleteTargetSerializer + new TypeSerializer + new ExpressionSerializer + new StatementSerializer + new PortSerializer + new DefModuleSerializer + @@ -144,10 +221,12 @@ object JsonProtocol { def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get def serializeTry(annos: Seq[Annotation]): Try[String] = { - val tags = annos.flatMap({ - case anno: HasSerializationHints => anno.getClass +: anno.typeHints - case anno => Seq(anno.getClass) - }).distinct + val tags = annos + .flatMap({ + case anno: HasSerializationHints => anno.getClass +: anno.typeHints + case anno => Seq(anno.getClass) + }) + .distinct implicit val formats = jsonFormat(tags) Try(writePretty(annos)) @@ -159,20 +238,25 @@ object JsonProtocol { val parsed = parse(in) val annos = parsed match { case JArray(objs) => objs - case x => throw new InvalidAnnotationJSONException( - s"Annotations must be serialized as a JArray, got ${x.getClass.getName} instead!") + case x => + throw new InvalidAnnotationJSONException( + s"Annotations must be serialized as a JArray, got ${x.getClass.getName} instead!" + ) } // Recursively gather typeHints by pulling the "class" field from JObjects // Json4s should emit this as the first field in all serialized classes // Setting requireClassField mandates that all JObjects must provide a typeHint, // this used on the first invocation to check all annotations do so - def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst.flatMap({ - case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2)) - case obj: JObject if requireClassField => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") - case JObject(fields) => findTypeHints(fields.map(_._2)) - case JArray(arr) => findTypeHints(arr) - case oJValue => Seq() - }).distinct + def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst + .flatMap({ + case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2)) + case obj: JObject if requireClassField => + throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj") + case JObject(fields) => findTypeHints(fields.map(_._2)) + case JArray(arr) => findTypeHints(arr) + case oJValue => Seq() + }) + .distinct val classes = findTypeHints(annos, true) val loaded = classes.map(Class.forName(_)) @@ -186,10 +270,11 @@ object JsonProtocol { case e @ (_: org.json4s.ParserUtil.ParseException | _: org.json4s.MappingException) => Failure(new InvalidAnnotationJSONException(e.getMessage)) }.recoverWith { // If the input is a file, wrap in InvalidAnnotationFileException - case e: FirrtlUserException => in match { - case FileInput(file) => - Failure(new InvalidAnnotationFileException(file, e)) - case _ => Failure(e) - } + case e: FirrtlUserException => + in match { + case FileInput(file) => + Failure(new InvalidAnnotationFileException(file, e)) + case _ => Failure(e) + } } } diff --git a/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala b/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala index 64c30bdb..043c1b3b 100644 --- a/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala +++ b/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala @@ -21,7 +21,7 @@ object MemoryLoadFileType { def deserialize(s: String): MemoryLoadFileType = s match { case "h" => MemoryLoadFileType.Hex case "b" => MemoryLoadFileType.Binary - case _ => throw new FirrtlUserException(s"Unrecognized MemoryLoadFileType: $s") + case _ => throw new FirrtlUserException(s"Unrecognized MemoryLoadFileType: $s") } } @@ -31,11 +31,11 @@ object MemoryLoadFileType { * @param hexOrBinary use `\$readmemh` or `\$readmemb` */ case class LoadMemoryAnnotation( - target: ComponentName, - fileName: String, - hexOrBinary: MemoryLoadFileType = MemoryLoadFileType.Hex, - originalMemoryNameOpt: Option[String] = None -) extends SingleTargetAnnotation[Named] { + target: ComponentName, + fileName: String, + hexOrBinary: MemoryLoadFileType = MemoryLoadFileType.Hex, + originalMemoryNameOpt: Option[String] = None) + extends SingleTargetAnnotation[Named] { val (prefix, suffix) = { fileName.split("""\.""").toList match { @@ -57,7 +57,7 @@ case class LoadMemoryAnnotation( def getPrefix: String = prefix + originalMemoryNameOpt.map(n => target.name.drop(n.length)).getOrElse("") - def getSuffix: String = suffix + def getSuffix: String = suffix def getFileName: String = getPrefix + getSuffix def duplicate(newNamed: Named): LoadMemoryAnnotation = { diff --git a/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala b/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala index 44a8e3b5..7cefdef8 100644 --- a/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala +++ b/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala @@ -5,10 +5,10 @@ package firrtl.annotations import firrtl.{MemoryArrayInit, MemoryEmissionOption, MemoryInitValue, MemoryRandomInit, MemoryScalarInit} /** - * Represents the initial value of the annotated memory. - * While not supported on normal ASIC flows, it can be useful for simulation and FPGA flows. - * This annotation is consumed by the verilog emitter. - */ + * Represents the initial value of the annotated memory. + * While not supported on normal ASIC flows, it can be useful for simulation and FPGA flows. + * This annotation is consumed by the verilog emitter. + */ sealed trait MemoryInitAnnotation extends SingleTargetAnnotation[ReferenceTarget] with MemoryEmissionOption { def isRandomInit: Boolean } @@ -16,20 +16,20 @@ sealed trait MemoryInitAnnotation extends SingleTargetAnnotation[ReferenceTarget /** Randomly initialize the `target` memory. This is the same as the default behavior. */ case class MemoryRandomInitAnnotation(target: ReferenceTarget) extends MemoryInitAnnotation { override def duplicate(n: ReferenceTarget): Annotation = copy(n) - override def initValue: MemoryInitValue = MemoryRandomInit + override def initValue: MemoryInitValue = MemoryRandomInit override def isRandomInit: Boolean = true } /** Initialize all entries of the `target` memory with the scalar `value`. */ case class MemoryScalarInitAnnotation(target: ReferenceTarget, value: BigInt) extends MemoryInitAnnotation { override def duplicate(n: ReferenceTarget): Annotation = copy(n) - override def initValue: MemoryInitValue = MemoryScalarInit(value) - override def isRandomInit: Boolean = false + override def initValue: MemoryInitValue = MemoryScalarInit(value) + override def isRandomInit: Boolean = false } /** Initialize the `target` memory with the array of `values` which must be the same size as the memory depth. */ case class MemoryArrayInitAnnotation(target: ReferenceTarget, values: Seq[BigInt]) extends MemoryInitAnnotation { override def duplicate(n: ReferenceTarget): Annotation = copy(n) - override def initValue: MemoryInitValue = MemoryArrayInit(values) - override def isRandomInit: Boolean = false -}
\ No newline at end of file + override def initValue: MemoryInitValue = MemoryArrayInit(values) + override def isRandomInit: Boolean = false +} diff --git a/src/main/scala/firrtl/annotations/PresetAnnotations.scala b/src/main/scala/firrtl/annotations/PresetAnnotations.scala index 727417c1..d6066aa7 100644 --- a/src/main/scala/firrtl/annotations/PresetAnnotations.scala +++ b/src/main/scala/firrtl/annotations/PresetAnnotations.scala @@ -10,11 +10,11 @@ package annotations * @param target ReferenceTarget to an AsyncReset */ case class PresetAnnotation(target: ReferenceTarget) - extends SingleTargetAnnotation[ReferenceTarget] with firrtl.transforms.DontTouchAllTargets { + extends SingleTargetAnnotation[ReferenceTarget] + with firrtl.transforms.DontTouchAllTargets { override def duplicate(n: ReferenceTarget) = this.copy(target = n) } - /** * Transform the targeted asynchronously-reset Reg into a bitstream preset Reg * Used internally to annotate all registers associated to an AsyncReset tree @@ -22,12 +22,10 @@ case class PresetAnnotation(target: ReferenceTarget) * @param target ReferenceTarget to a Reg */ private[firrtl] case class PresetRegAnnotation( - target: ReferenceTarget -) extends SingleTargetAnnotation[ReferenceTarget] with RegisterEmissionOption { + target: ReferenceTarget) + extends SingleTargetAnnotation[ReferenceTarget] + with RegisterEmissionOption { def duplicate(n: ReferenceTarget) = this.copy(target = n) override def useInitAsPreset = true override def disableRandomization = true } - - - diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala index 4d1cdc2f..afde84dc 100644 --- a/src/main/scala/firrtl/annotations/Target.scala +++ b/src/main/scala/firrtl/annotations/Target.scala @@ -4,7 +4,7 @@ package firrtl package annotations import firrtl.ir.{Field => _, _} -import firrtl.Utils.{sub_type, field_type} +import firrtl.Utils.{field_type, sub_type} import AnnotationUtils.{toExp, validComponentName, validModuleName} import TargetToken._ @@ -29,27 +29,29 @@ sealed trait Target extends Named { def tokens: Seq[TargetToken] /** @return Returns a new [[GenericTarget]] with new values */ - def modify(circuitOpt: Option[String] = circuitOpt, - moduleOpt: Option[String] = moduleOpt, - tokens: Seq[TargetToken] = tokens): GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector) + def modify( + circuitOpt: Option[String] = circuitOpt, + moduleOpt: Option[String] = moduleOpt, + tokens: Seq[TargetToken] = tokens + ): GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector) /** @return Human-readable serialization */ def serialize: String = { val circuitString = "~" + circuitOpt.getOrElse("???") val moduleString = "|" + moduleOpt.getOrElse("???") val tokensString = tokens.map { - case Ref(r) => s">$r" - case Instance(i) => s"/$i" - case OfModule(o) => s":$o" + case Ref(r) => s">$r" + case Instance(i) => s"/$i" + case OfModule(o) => s":$o" case TargetToken.Field(f) => s".$f" - case Index(v) => s"[$v]" - case Clock => s"@clock" - case Reset => s"@reset" - case Init => s"@init" + case Index(v) => s"[$v]" + case Clock => s"@clock" + case Reset => s"@reset" + case Init => s"@init" }.mkString("") - if(moduleOpt.isEmpty && tokens.isEmpty) { + if (moduleOpt.isEmpty && tokens.isEmpty) { circuitString - } else if(tokens.isEmpty) { + } else if (tokens.isEmpty) { circuitString + moduleString } else { circuitString + moduleString + tokensString @@ -64,24 +66,23 @@ sealed trait Target extends Named { val moduleString = s"""\n$tab└── module ${moduleOpt.getOrElse("???")}:""" var depth = 4 val tokenString = tokens.map { - case Ref(r) => val rx = s"""\n$tab${" "*depth}└── $r"""; depth += 4; rx - case Instance(i) => val ix = s"""\n$tab${" "*depth}└── inst $i """; ix + case Ref(r) => val rx = s"""\n$tab${" " * depth}└── $r"""; depth += 4; rx + case Instance(i) => val ix = s"""\n$tab${" " * depth}└── inst $i """; ix case OfModule(o) => val ox = s"of $o:"; depth += 4; ox - case Field(f) => s".$f" - case Index(v) => s"[$v]" - case Clock => s"@clock" - case Reset => s"@reset" - case Init => s"@init" + case Field(f) => s".$f" + case Index(v) => s"[$v]" + case Clock => s"@clock" + case Reset => s"@reset" + case Init => s"@init" }.mkString("") (moduleOpt.isEmpty, tokens.isEmpty) match { case (true, true) => circuitString - case (_, true) => circuitString + moduleString - case (_, _) => circuitString + moduleString + tokenString + case (_, true) => circuitString + moduleString + case (_, _) => circuitString + moduleString + tokenString } } - /** @return Converts this [[Target]] into a [[GenericTarget]] */ def toGenericTarget: GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector) @@ -113,13 +114,13 @@ sealed trait Target extends Named { object Target { def asTarget(m: ModuleTarget)(e: Expression): ReferenceTarget = e match { case r: ir.Reference => m.ref(r.name) - case s: ir.SubIndex => asTarget(m)(s.expr).index(s.value) - case s: ir.SubField => asTarget(m)(s.expr).field(s.name) + case s: ir.SubIndex => asTarget(m)(s.expr).index(s.value) + case s: ir.SubField => asTarget(m)(s.expr).field(s.name) case s: ir.SubAccess => asTarget(m)(s.expr).field("@" + s.index.serialize) - case d: DoPrim => m.ref("@" + d.serialize) - case d: Mux => m.ref("@" + d.serialize) - case d: ValidIf => m.ref("@" + d.serialize) - case d: Literal => m.ref("@" + d.serialize) + case d: DoPrim => m.ref("@" + d.serialize) + case d: Mux => m.ref("@" + d.serialize) + case d: ValidIf => m.ref("@" + d.serialize) + case d: Literal => m.ref("@" + d.serialize) case other => sys.error(s"Unsupported: $other") } @@ -131,14 +132,14 @@ object Target { case class NamedException(message: String) extends Exception(message) - implicit def convertCircuitTarget2CircuitName(c: CircuitTarget): CircuitName = c.toNamed - implicit def convertModuleTarget2ModuleName(c: ModuleTarget): ModuleName = c.toNamed - implicit def convertIsComponent2ComponentName(c: IsComponent): ComponentName = c.toNamed - implicit def convertTarget2Named(c: Target): Named = c.toNamed - implicit def convertCircuitName2CircuitTarget(c: CircuitName): CircuitTarget = c.toTarget - implicit def convertModuleName2ModuleTarget(c: ModuleName): ModuleTarget = c.toTarget + implicit def convertCircuitTarget2CircuitName(c: CircuitTarget): CircuitName = c.toNamed + implicit def convertModuleTarget2ModuleName(c: ModuleTarget): ModuleName = c.toNamed + implicit def convertIsComponent2ComponentName(c: IsComponent): ComponentName = c.toNamed + implicit def convertTarget2Named(c: Target): Named = c.toNamed + implicit def convertCircuitName2CircuitTarget(c: CircuitName): CircuitTarget = c.toTarget + implicit def convertModuleName2ModuleTarget(c: ModuleName): ModuleTarget = c.toTarget implicit def convertComponentName2ReferenceTarget(c: ComponentName): ReferenceTarget = c.toTarget - implicit def convertNamed2Target(n: Named): CompleteTarget = n.toTarget + implicit def convertNamed2Target(n: Named): CompleteTarget = n.toTarget /** Converts [[ComponentName]]'s name into TargetTokens * @param name @@ -148,7 +149,7 @@ object Target { val tokens = AnnotationUtils.tokenize(name) val subComps = mutable.ArrayBuffer[TargetToken]() subComps += Ref(tokens.head) - if(tokens.tail.nonEmpty) { + if (tokens.tail.nonEmpty) { tokens.tail.zip(tokens.tail.tail).foreach { case (".", value: String) => subComps += Field(value) case ("[", value: String) => subComps += Index(value.toInt) @@ -163,31 +164,33 @@ object Target { * @param keywords * @return */ - def isOnly(seq: Seq[TargetToken], keywords:String*): Boolean = { - seq.map(_.is(keywords:_*)).foldLeft(false)(_ || _) && keywords.nonEmpty + def isOnly(seq: Seq[TargetToken], keywords: String*): Boolean = { + seq.map(_.is(keywords: _*)).foldLeft(false)(_ || _) && keywords.nonEmpty } /** @return [[Target]] from human-readable serialization */ def deserialize(s: String): Target = { val regex = """(?=[~|>/:.\[@])""" - s.split(regex).foldLeft(GenericTarget(None, None, Vector.empty)) { (t, tokenString) => - val value = tokenString.tail - tokenString(0) match { - case '~' if t.circuitOpt.isEmpty && t.moduleOpt.isEmpty && t.tokens.isEmpty => - if(value == "???") t else t.copy(circuitOpt = Some(value)) - case '|' if t.moduleOpt.isEmpty && t.tokens.isEmpty => - if(value == "???") t else t.copy(moduleOpt = Some(value)) - case '/' => t.add(Instance(value)) - case ':' => t.add(OfModule(value)) - case '>' => t.add(Ref(value)) - case '.' => t.add(Field(value)) - case '[' if value.dropRight(1).toInt >= 0 => t.add(Index(value.dropRight(1).toInt)) - case '@' if value == "clock" => t.add(Clock) - case '@' if value == "init" => t.add(Init) - case '@' if value == "reset" => t.add(Reset) - case other => throw NamedException(s"Cannot deserialize Target: $s") + s.split(regex) + .foldLeft(GenericTarget(None, None, Vector.empty)) { (t, tokenString) => + val value = tokenString.tail + tokenString(0) match { + case '~' if t.circuitOpt.isEmpty && t.moduleOpt.isEmpty && t.tokens.isEmpty => + if (value == "???") t else t.copy(circuitOpt = Some(value)) + case '|' if t.moduleOpt.isEmpty && t.tokens.isEmpty => + if (value == "???") t else t.copy(moduleOpt = Some(value)) + case '/' => t.add(Instance(value)) + case ':' => t.add(OfModule(value)) + case '>' => t.add(Ref(value)) + case '.' => t.add(Field(value)) + case '[' if value.dropRight(1).toInt >= 0 => t.add(Index(value.dropRight(1).toInt)) + case '@' if value == "clock" => t.add(Clock) + case '@' if value == "init" => t.add(Init) + case '@' if value == "reset" => t.add(Reset) + case other => throw NamedException(s"Cannot deserialize Target: $s") + } } - }.tryToComplete + .tryToComplete } /** Returns the module that a [[Target]] "refers" to. @@ -217,14 +220,16 @@ object Target { def getReferenceTarget(t: Target): Target = { (t.toGenericTarget match { case t: GenericTarget if t.isLegal => - val newTokens = t.tokens.reverse.dropWhile({ - case x: Field => true - case x: Index => true - case Clock => true - case Init => true - case Reset => true - case other => false - }).reverse + val newTokens = t.tokens.reverse + .dropWhile({ + case x: Field => true + case x: Index => true + case Clock => true + case Init => true + case Reset => true + case other => false + }) + .reverse GenericTarget(t.circuitOpt, t.moduleOpt, newTokens) case other => sys.error(s"Can't make $other pathless!") }).tryToComplete @@ -236,9 +241,8 @@ object Target { * @param moduleOpt Optional module name * @param tokens [[TargetToken]]s to represent the target in a circuit and module */ -case class GenericTarget(circuitOpt: Option[String], - moduleOpt: Option[String], - tokens: Vector[TargetToken]) extends Target { +case class GenericTarget(circuitOpt: Option[String], moduleOpt: Option[String], tokens: Vector[TargetToken]) + extends Target { override def toGenericTarget: GenericTarget = this @@ -252,11 +256,12 @@ case class GenericTarget(circuitOpt: Option[String], override def toTarget: CompleteTarget = getComplete.get override def getComplete: Option[CompleteTarget] = { - if(!isComplete) None else { + if (!isComplete) None + else { val target = this match { - case GenericTarget(Some(c), None, Vector()) => CircuitTarget(c) - case GenericTarget(Some(c), Some(m), Vector()) => ModuleTarget(c, m) - case GenericTarget(Some(c), Some(m), Ref(r) +: component) => ReferenceTarget(c, m, Nil, r, component) + case GenericTarget(Some(c), None, Vector()) => CircuitTarget(c) + case GenericTarget(Some(c), Some(m), Vector()) => ModuleTarget(c, m) + case GenericTarget(Some(c), Some(m), Ref(r) +: component) => ReferenceTarget(c, m, Nil, r, component) case GenericTarget(Some(c), Some(m), Instance(i) +: OfModule(o) +: Vector()) => InstanceTarget(c, m, Nil, i, o) case GenericTarget(Some(c), Some(m), component) => val path = getPath.getOrElse(Nil) @@ -271,7 +276,7 @@ case class GenericTarget(circuitOpt: Option[String], override def isLocal: Boolean = !(getPath.nonEmpty && getPath.get.nonEmpty) - def path: Vector[(Instance, OfModule)] = if(isComplete){ + def path: Vector[(Instance, OfModule)] = if (isComplete) { tokens.zip(tokens.tail).collect { case (i: Instance, o: OfModule) => (i, o) } @@ -280,9 +285,9 @@ case class GenericTarget(circuitOpt: Option[String], /** If complete, return this [[GenericTarget]]'s path * @return */ - def getPath: Option[Seq[(Instance, OfModule)]] = if(isComplete) { - val allInstOfs = tokens.grouped(2).collect { case Seq(i: Instance, o:OfModule) => (i, o)}.toSeq - if(tokens.nonEmpty && tokens.last.isInstanceOf[OfModule]) Some(allInstOfs.dropRight(1)) else Some(allInstOfs) + def getPath: Option[Seq[(Instance, OfModule)]] = if (isComplete) { + val allInstOfs = tokens.grouped(2).collect { case Seq(i: Instance, o: OfModule) => (i, o) }.toSeq + if (tokens.nonEmpty && tokens.last.isInstanceOf[OfModule]) Some(allInstOfs.dropRight(1)) else Some(allInstOfs) } else { None } @@ -290,7 +295,7 @@ case class GenericTarget(circuitOpt: Option[String], /** If complete and a reference, return the reference and subcomponents * @return */ - def getRef: Option[(String, Seq[TargetToken])] = if(isComplete) { + def getRef: Option[(String, Seq[TargetToken])] = if (isComplete) { val (optRef, comps) = tokens.foldLeft((None: Option[String], Vector.empty[TargetToken])) { case ((None, v), Ref(r)) => (Some(r), v) case ((r: Some[String], comps), c) => (r, comps :+ c) @@ -304,7 +309,7 @@ case class GenericTarget(circuitOpt: Option[String], /** If complete and an instance target, return the instance and ofmodule * @return */ - def getInstanceOf: Option[(String, String)] = if(isComplete) { + def getInstanceOf: Option[(String, String)] = if (isComplete) { tokens.grouped(2).foldLeft(None: Option[(String, String)]) { case (instOf, Seq(i: Instance, o: OfModule)) => Some((i.value, o.value)) case (instOf, _) => None @@ -328,14 +333,14 @@ case class GenericTarget(circuitOpt: Option[String], */ def add(token: TargetToken): GenericTarget = { token match { - case _: Instance => requireLast(true, "inst", "of") - case _: OfModule => requireLast(false, "inst") - case _: Ref => requireLast(true, "inst", "of") - case _: Field => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") - case _: Index => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") - case Init => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") - case Clock => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") - case Reset => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case _: Instance => requireLast(true, "inst", "of") + case _: OfModule => requireLast(false, "inst") + case _: Ref => requireLast(true, "inst", "of") + case _: Field => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case _: Index => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case Init => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case Clock => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") + case Reset => requireLast(true, "ref", "[]", ".", "init", "clock", "reset") } this.copy(tokens = tokens :+ token) } @@ -345,7 +350,7 @@ case class GenericTarget(circuitOpt: Option[String], /** Optionally tries to append token to tokens, fails return is not a legal Target */ def optAdd(token: TargetToken): Option[Target] = { - try{ + try { Some(add(token)) } catch { case _: IllegalArgumentException => None @@ -358,7 +363,7 @@ case class GenericTarget(circuitOpt: Option[String], def isLegal: Boolean = { try { var comp: GenericTarget = this.copy(tokens = Vector.empty) - for(token <- tokens) { + for (token <- tokens) { comp = comp.add(token) } true @@ -374,19 +379,18 @@ case class GenericTarget(circuitOpt: Option[String], def isComplete: Boolean = { isLegal && (isCircuitTarget || isModuleTarget || (isComponentTarget && tokens.tails.forall { case Instance(_) +: OfModule(_) +: tail => true - case Instance(_) +: x +: tail => false - case x +: OfModule(_) +: tail => false - case _ => true - } )) + case Instance(_) +: x +: tail => false + case x +: OfModule(_) +: tail => false + case _ => true + })) } - - def isCircuitTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.isEmpty && tokens.isEmpty - def isModuleTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.isEmpty + def isCircuitTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.isEmpty && tokens.isEmpty + def isModuleTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.isEmpty def isComponentTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.nonEmpty lazy val (parentModule: Option[String], astModule: Option[String]) = path match { - case Seq() => (None, moduleOpt) + case Seq() => (None, moduleOpt) case Seq((i, OfModule(o))) => (moduleOpt, Some(o)) case seq if seq.size > 1 => val reversed = seq.reverse @@ -421,7 +425,6 @@ trait CompleteTarget extends Target { override def toString: String = serialize } - /** A member of a FIRRTL Circuit (e.g. cannot point to a CircuitTarget) * Concrete Subclasses are: [[ModuleTarget]], [[InstanceTarget]], and [[ReferenceTarget]] */ @@ -456,10 +459,12 @@ trait IsMember extends CompleteTarget { /** @return List of local Instance Targets refering to each instance/ofModule in this member's path */ def pathAsTargets: Seq[InstanceTarget] = { - path.foldLeft((module, Vector.empty[InstanceTarget])) { - case ((m, vec), (Instance(i), OfModule(o))) => - (o, vec :+ InstanceTarget(circuit, m, Nil, i, o)) - }._2 + path + .foldLeft((module, Vector.empty[InstanceTarget])) { + case ((m, vec), (Instance(i), OfModule(o))) => + (o, vec :+ InstanceTarget(circuit, m, Nil, i, o)) + } + ._2 } /** Resets this target to have a new path @@ -469,7 +474,7 @@ trait IsMember extends CompleteTarget { def setPathTarget(newPath: IsModule): CompleteTarget /** @return The [[ModuleTarget]] of the module that directly contains this component */ - def encapsulatingModule: String = if(path.isEmpty) module else path.last._2.value + def encapsulatingModule: String = if (path.isEmpty) module else path.last._2.value def encapsulatingModuleTarget: ModuleTarget = ModuleTarget(circuit, encapsulatingModule) @@ -492,6 +497,7 @@ trait IsModule extends IsMember { /** A component of a FIRRTL Module (e.g. cannot point to a CircuitTarget or ModuleTarget) */ trait IsComponent extends IsMember { + /** Removes n levels of instance hierarchy * * Example: n=1, transforms (Top, A)/b:B/c:C -> (Top, B)/c:C @@ -501,13 +507,13 @@ trait IsComponent extends IsMember { def stripHierarchy(n: Int): IsMember override def toNamed: ComponentName = { - if(isLocal){ + if (isLocal) { val mn = ModuleName(module, CircuitName(circuit)) - Seq(tokens:_*) match { + Seq(tokens: _*) match { case Seq(Ref(name)) => ComponentName(name, mn) case Ref(_) :: tail if Target.isOnly(tail, ".", "[]") => - val name = tokens.foldLeft(""){ - case ("", Ref(name)) => name + val name = tokens.foldLeft("") { + case ("", Ref(name)) => name case (string, Field(value)) => s"$string.$value" case (string, Index(value)) => s"$string[$value]" } @@ -524,7 +530,8 @@ trait IsComponent extends IsMember { } override def pathTarget: IsModule = { - if(path.isEmpty) moduleTarget else { + if (path.isEmpty) moduleTarget + else { val (i, o) = path.last InstanceTarget(circuit, module, path.dropRight(1), i.value, o.value) } @@ -535,7 +542,6 @@ trait IsComponent extends IsMember { override def isLocal = path.isEmpty } - /** Target pointing to a FIRRTL [[firrtl.ir.Circuit]] * @param circuit Name of a FIRRTL circuit */ @@ -577,7 +583,8 @@ case class ModuleTarget(circuit: String, module: String) extends IsModule { override def targetParent: CircuitTarget = CircuitTarget(circuit) - override def addHierarchy(root: String, instance: String): InstanceTarget = InstanceTarget(circuit, root, Nil, instance, module) + override def addHierarchy(root: String, instance: String): InstanceTarget = + InstanceTarget(circuit, root, Nil, instance, module) override def ref(value: String): ReferenceTarget = ReferenceTarget(circuit, module, Nil, value, Nil) @@ -613,11 +620,13 @@ case class ModuleTarget(circuit: String, module: String) extends IsModule { * @param ref Name of component * @param component Subcomponent of this reference, e.g. field or index */ -case class ReferenceTarget(circuit: String, - module: String, - override val path: Seq[(Instance, OfModule)], - ref: String, - component: Seq[TargetToken]) extends IsComponent { +case class ReferenceTarget( + circuit: String, + module: String, + override val path: Seq[(Instance, OfModule)], + ref: String, + component: Seq[TargetToken]) + extends IsComponent { /** @param value Index value of this target * @return A new [[ReferenceTarget]] to the specified index of this [[ReferenceTarget]] @@ -648,7 +657,7 @@ case class ReferenceTarget(circuit: String, baseType } else { val headType = tokens.head match { - case Index(idx) => sub_type(baseType) + case Index(idx) => sub_type(baseType) case Field(field) => field_type(baseType, field) case _: Ref => baseType } @@ -662,7 +671,8 @@ case class ReferenceTarget(circuit: String, override def targetParent: CompleteTarget = component match { case Nil => - if(path.isEmpty) moduleTarget else { + if (path.isEmpty) moduleTarget + else { val (i, o) = path.last InstanceTarget(circuit, module, path.dropRight(1), i.value, o.value) } @@ -676,7 +686,8 @@ case class ReferenceTarget(circuit: String, override def stripHierarchy(n: Int): ReferenceTarget = { require(path.size >= n, s"Cannot strip $n levels of hierarchy from $this") - if(n == 0) this else { + if (n == 0) this + else { val newModule = path(n - 1)._2.value ReferenceTarget(circuit, newModule, path.drop(n), ref, component) } @@ -700,15 +711,15 @@ case class ReferenceTarget(circuit: String, def leafSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match { case _: firrtl.ir.GroundType => Vector(this) case firrtl.ir.VectorType(t, size) => (0 until size).flatMap { i => index(i).leafSubTargets(t) } - case firrtl.ir.BundleType(fields) => fields.flatMap { f => field(f.name).leafSubTargets(f.tpe)} - case other => sys.error(s"Error! Unexpected type $other") + case firrtl.ir.BundleType(fields) => fields.flatMap { f => field(f.name).leafSubTargets(f.tpe) } + case other => sys.error(s"Error! Unexpected type $other") } def allSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match { case _: firrtl.ir.GroundType => Vector(this) case firrtl.ir.VectorType(t, size) => this +: (0 until size).flatMap { i => index(i).allSubTargets(t) } - case firrtl.ir.BundleType(fields) => this +: fields.flatMap { f => field(f.name).allSubTargets(f.tpe)} - case other => sys.error(s"Error! Unexpected type $other") + case firrtl.ir.BundleType(fields) => this +: fields.flatMap { f => field(f.name).allSubTargets(f.tpe) } + case other => sys.error(s"Error! Unexpected type $other") } override def leafModule: String = encapsulatingModule @@ -721,11 +732,14 @@ case class ReferenceTarget(circuit: String, * @param instance Name of the instance * @param ofModule Name of the instance's module */ -case class InstanceTarget(circuit: String, - module: String, - override val path: Seq[(Instance, OfModule)], - instance: String, - ofModule: String) extends IsModule with IsComponent { +case class InstanceTarget( + circuit: String, + module: String, + override val path: Seq[(Instance, OfModule)], + instance: String, + ofModule: String) + extends IsModule + with IsComponent { /** @return a [[ReferenceTarget]] referring to this declaration of this instance */ def asReference: ReferenceTarget = ReferenceTarget(circuit, module, path, instance, Nil) @@ -744,7 +758,8 @@ case class InstanceTarget(circuit: String, override def moduleOpt: Option[String] = Some(module) override def targetParent: IsModule = { - if(isLocal) ModuleTarget(circuit, module) else { + if (isLocal) ModuleTarget(circuit, module) + else { val (newInstance, newOfModule) = path.last InstanceTarget(circuit, module, path.dropRight(1), newInstance.value, newOfModule.value) } @@ -759,8 +774,9 @@ case class InstanceTarget(circuit: String, override def stripHierarchy(n: Int): IsModule = { require(path.size + 1 >= n, s"Cannot strip $n levels of hierarchy from $this") - if(n == 0) this else { - if(path.size < n){ + if (n == 0) this + else { + if (path.size < n) { ModuleTarget(circuit, ofModule) } else { val newModule = path(n - 1)._2.value @@ -769,7 +785,7 @@ case class InstanceTarget(circuit: String, } } - override def asPath: Seq[(Instance, OfModule)] = path :+( (Instance(instance), OfModule(ofModule)) ) + override def asPath: Seq[(Instance, OfModule)] = path :+ ((Instance(instance), OfModule(ofModule))) override def pathlessTarget: InstanceTarget = InstanceTarget(circuit, encapsulatingModule, Nil, instance, ofModule) @@ -781,33 +797,32 @@ case class InstanceTarget(circuit: String, override def leafModule: String = ofModule } - /** Named classes associate an annotation with a component in a Firrtl circuit */ sealed trait Named { def serialize: String - def toTarget: CompleteTarget + def toTarget: CompleteTarget } final case class CircuitName(name: String) extends Named { - if(!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name") + if (!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name") def serialize: String = name - def toTarget: CircuitTarget = CircuitTarget(name) + def toTarget: CircuitTarget = CircuitTarget(name) } final case class ModuleName(name: String, circuit: CircuitName) extends Named { - if(!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name") + if (!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name") def serialize: String = circuit.serialize + "." + name - def toTarget: ModuleTarget = ModuleTarget(circuit.name, name) + def toTarget: ModuleTarget = ModuleTarget(circuit.name, name) } final case class ComponentName(name: String, module: ModuleName) extends Named { - if(!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name") - def expr: Expression = toExp(name) + if (!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name") + def expr: Expression = toExp(name) def serialize: String = module.serialize + "." + name def toTarget: ReferenceTarget = { Target.toTargetTokens(name).toList match { case Ref(r) :: components => ReferenceTarget(module.circuit.name, module.name, Nil, r, components) - case other => throw Target.NamedException(s"Cannot convert $this into [[ReferenceTarget]]: $other") + case other => throw Target.NamedException(s"Cannot convert $this into [[ReferenceTarget]]: $other") } } } diff --git a/src/main/scala/firrtl/annotations/TargetToken.scala b/src/main/scala/firrtl/annotations/TargetToken.scala index 765102a6..a4a98eed 100644 --- a/src/main/scala/firrtl/annotations/TargetToken.scala +++ b/src/main/scala/firrtl/annotations/TargetToken.scala @@ -3,12 +3,12 @@ package firrtl.annotations import firrtl._ -import ir.{DefModule, DefInstance} +import ir.{DefInstance, DefModule} /** Building block to represent a [[Target]] of a FIRRTL component */ sealed trait TargetToken { def keyword: String - def value: Any + def value: Any /** Returns whether this token is one of the type of tokens whose keyword is passed as an argument * @param keywords @@ -16,8 +16,10 @@ sealed trait TargetToken { */ def is(keywords: String*): Boolean = { keywords.map { kw => - require(TargetToken.keyword2targettoken.keySet.contains(kw), - s"Keyword $kw must be in set ${TargetToken.keyword2targettoken.keys}") + require( + TargetToken.keyword2targettoken.keySet.contains(kw), + s"Keyword $kw must be in set ${TargetToken.keyword2targettoken.keys}" + ) val lastClass = this.getClass lastClass == TargetToken.keyword2targettoken(kw)("0").getClass }.reduce(_ || _) @@ -26,20 +28,20 @@ sealed trait TargetToken { /** Object containing all [[TargetToken]] subclasses */ case object TargetToken { - case class Instance(value: String) extends TargetToken { override def keyword: String = "inst" } - case class OfModule(value: String) extends TargetToken { override def keyword: String = "of" } - case class Ref(value: String) extends TargetToken { override def keyword: String = "ref" } - case class Index(value: Int) extends TargetToken { override def keyword: String = "[]" } - case class Field(value: String) extends TargetToken { override def keyword: String = "." } - case object Clock extends TargetToken { override def keyword: String = "clock"; val value = "" } - case object Init extends TargetToken { override def keyword: String = "init"; val value = "" } - case object Reset extends TargetToken { override def keyword: String = "reset"; val value = "" } + case class Instance(value: String) extends TargetToken { override def keyword: String = "inst" } + case class OfModule(value: String) extends TargetToken { override def keyword: String = "of" } + case class Ref(value: String) extends TargetToken { override def keyword: String = "ref" } + case class Index(value: Int) extends TargetToken { override def keyword: String = "[]" } + case class Field(value: String) extends TargetToken { override def keyword: String = "." } + case object Clock extends TargetToken { override def keyword: String = "clock"; val value = "" } + case object Init extends TargetToken { override def keyword: String = "init"; val value = "" } + case object Reset extends TargetToken { override def keyword: String = "reset"; val value = "" } implicit class fromStringToTargetToken(s: String) { def Instance: Instance = new TargetToken.Instance(s) def OfModule: OfModule = new TargetToken.OfModule(s) - def Ref: Ref = new TargetToken.Ref(s) - def Field: Field = new TargetToken.Field(s) + def Ref: Ref = new TargetToken.Ref(s) + def Field: Field = new TargetToken.Field(s) } implicit class fromIntToTargetToken(i: Int) { @@ -67,4 +69,3 @@ case object TargetToken { "reset" -> ((value: String) => Reset) ) } - diff --git a/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala index 8f925ee7..31d13139 100644 --- a/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala +++ b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala @@ -88,10 +88,12 @@ case class DuplicationHelper(existingModules: Set[String]) { * @param originalOfModule original module being instantiated in originalModule * @return */ - def getNewOfModule(originalModule: String, - newModule: String, - instance: Instance, - originalOfModule: OfModule): OfModule = { + def getNewOfModule( + originalModule: String, + newModule: String, + instance: Instance, + originalOfModule: OfModule + ): OfModule = { dupMap.get(originalModule) match { case None => // No duplication, can return originalOfModule originalOfModule @@ -129,18 +131,18 @@ case class DuplicationHelper(existingModules: Set[String]) { val newTops = getDuplicates(top) newTops.map { newTop => val newPath = mutable.ArrayBuffer[TargetToken]() - path.foldLeft((top, newTop)) { case ((originalModule, newModule), (instance, ofModule)) => - val newOfModule = getNewOfModule(originalModule, newModule, instance, ofModule) - newPath ++= Seq(instance, newOfModule) - (ofModule.value, newOfModule.value) + path.foldLeft((top, newTop)) { + case ((originalModule, newModule), (instance, ofModule)) => + val newOfModule = getNewOfModule(originalModule, newModule, instance, ofModule) + newPath ++= Seq(instance, newOfModule) + (ofModule.value, newOfModule.value) } - val module = if(newPath.nonEmpty) newPath.last.value.toString else newTop + val module = if (newPath.nonEmpty) newPath.last.value.toString else newTop t.notPath match { - case Seq() => ModuleTarget(t.circuit, module) + case Seq() => ModuleTarget(t.circuit, module) case Instance(i) +: OfModule(m) +: Seq() => ModuleTarget(t.circuit, module) - case Ref(r) +: components => ReferenceTarget(t.circuit, module, Nil, r, components) + case Ref(r) +: components => ReferenceTarget(t.circuit, module, Nil, r, components) } }.toSeq } } - diff --git a/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala b/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala index a4219a03..20304378 100644 --- a/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala +++ b/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala @@ -27,19 +27,25 @@ class CleanupNamedTargets extends Transform with DependencyAPIMigration { override def invalidates(a: Transform) = false - private def onStatement(statement: ir.Statement) - (implicit references: ISet[ReferenceTarget], - renameMap: RenameMap, - module: ModuleTarget): Unit = statement match { + private def onStatement( + statement: ir.Statement + )( + implicit references: ISet[ReferenceTarget], + renameMap: RenameMap, + module: ModuleTarget + ): Unit = statement match { case ir.DefInstance(_, a, b, _) if references(module.instOf(a, b).asReference) => renameMap.record(module.instOf(a, b).asReference, module.instOf(a, b)) case a => statement.foreach(onStatement) } - private def onModule(module: ir.DefModule) - (implicit references: ISet[ReferenceTarget], - renameMap: RenameMap, - circuit: CircuitTarget): Unit = { + private def onModule( + module: ir.DefModule + )( + implicit references: ISet[ReferenceTarget], + renameMap: RenameMap, + circuit: CircuitTarget + ): Unit = { implicit val mTarget = circuit.module(module.name) module.foreach(onStatement) } @@ -49,7 +55,7 @@ class CleanupNamedTargets extends Transform with DependencyAPIMigration { implicit val rTargets: ISet[ReferenceTarget] = state.annotations.flatMap { case a: SingleTargetAnnotation[_] => Some(a.target) case a: MultiTargetAnnotation => a.targets.flatten - case _ => None + case _ => None }.collect { case a: ReferenceTarget => a }.toSet diff --git a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala index d92d3b5e..596a344f 100644 --- a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala +++ b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala @@ -5,7 +5,7 @@ package firrtl.annotations.transforms import firrtl.Mappers._ import firrtl.analyses.InstanceKeyGraph import firrtl.annotations.ModuleTarget -import firrtl.annotations.TargetToken.{Instance, OfModule, fromDefModuleToTargetToken} +import firrtl.annotations.TargetToken.{fromDefModuleToTargetToken, Instance, OfModule} import firrtl.annotations.analysis.DuplicationHelper import firrtl.annotations._ import firrtl.ir._ @@ -15,7 +15,6 @@ import firrtl.transforms.DedupedResult import scala.collection.mutable - /** Group of targets that should become local targets * @param targets */ @@ -36,7 +35,7 @@ case class DupedResult(newModules: Set[IsModule], originalModule: ModuleTarget) override def duplicate(n: Seq[Seq[Target]]): Annotation = { n.toList match { case Seq(newMods) => DupedResult(newMods.collect { case x: IsModule => x }.toSet, originalModule) - case _ => DupedResult(Set.empty, originalModule) + case _ => DupedResult(Set.empty, originalModule) } } } @@ -47,35 +46,35 @@ object EliminateTargetPaths { def renameModules(c: Circuit, toRename: Map[String, String], renameMap: RenameMap): Circuit = { val ct = CircuitTarget(c.main) - val cx = if(toRename.contains(c.main)) { + val cx = if (toRename.contains(c.main)) { renameMap.record(ct, CircuitTarget(toRename(c.main))) c.copy(main = toRename(c.main)) } else { c } def onMod(m: DefModule): DefModule = { - m map onStmt match { + m.map(onStmt) match { case e: ExtModule if toRename.contains(e.name) => renameMap.record(ct.module(e.name), ct.module(toRename(e.name))) e.copy(name = toRename(e.name)) - case e: Module if toRename.contains(e.name) => + case e: Module if toRename.contains(e.name) => renameMap.record(ct.module(e.name), ct.module(toRename(e.name))) e.copy(name = toRename(e.name)) case o => o } } - def onStmt(s: Statement): Statement = s map onStmt match { - case w@DefInstance(info, name, module, _) if toRename.contains(module) => w.copy(module = toRename(module)) - case other => other + def onStmt(s: Statement): Statement = s.map(onStmt) match { + case w @ DefInstance(info, name, module, _) if toRename.contains(module) => w.copy(module = toRename(module)) + case other => other } - cx map onMod + cx.map(onMod) } def reorderModules(c: Circuit, toReorder: Map[String, Double]): Circuit = { val newOrderMap = c.modules.zipWithIndex.map { case (m, _) if toReorder.contains(m.name) => m.name -> toReorder(m.name) - case (m, i) if c.modules.size > 1 => m.name -> i.toDouble / (c.modules.size - 1) - case (m, _) => m.name -> 1.0 + case (m, i) if c.modules.size > 1 => m.name -> i.toDouble / (c.modules.size - 1) + case (m, _) => m.name -> 1.0 }.toMap val newOrder = c.modules.sortBy { m => newOrderMap(m.name) } @@ -83,7 +82,6 @@ object EliminateTargetPaths { c.copy(modules = newOrder) } - } /** For a set of non-local targets, modify the instance/module hierarchy of the circuit such that @@ -116,24 +114,20 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { * @param s * @return */ - private def onStmt(dupMap: DuplicationHelper) - (originalModule: String, newModule: String) - (s: Statement): Statement = s match { - case d@DefInstance(_, name, module, _) => - val ofModule = dupMap.getNewOfModule(originalModule, newModule, Instance(name), OfModule(module)).value - d.copy(module = ofModule) - case other => other map onStmt(dupMap)(originalModule, newModule) - } + private def onStmt(dupMap: DuplicationHelper)(originalModule: String, newModule: String)(s: Statement): Statement = + s match { + case d @ DefInstance(_, name, module, _) => + val ofModule = dupMap.getNewOfModule(originalModule, newModule, Instance(name), OfModule(module)).value + d.copy(module = ofModule) + case other => other.map(onStmt(dupMap)(originalModule, newModule)) + } /** Returns a modified circuit and [[RenameMap]] containing the associated target remapping * @param cir * @param targets * @return */ - def run(cir: Circuit, - targets: Seq[IsMember], - iGraph: InstanceKeyGraph - ): (Circuit, RenameMap, AnnotationSeq) = { + def run(cir: Circuit, targets: Seq[IsMember], iGraph: InstanceKeyGraph): (Circuit, RenameMap, AnnotationSeq) = { val dupMap = DuplicationHelper(cir.modules.map(_.name).toSet) @@ -161,7 +155,7 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { } val finalModuleList = duplicatedModuleList - lazy val finalModuleSet = finalModuleList.map{ case a: DefModule => a.name }.toSet + lazy val finalModuleSet = finalModuleList.map { case a: DefModule => a.name }.toSet // Records how targets have been renamed val renameMap = RenameMap() @@ -203,8 +197,9 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { duplicatedParents.foreach { parent => val paths = iGraph.findInstancesInHierarchy(parent.value) val newTargets = paths.map { path => - path.tail.foldLeft(topMod: IsModule) { case (mod, wDefInst) => - mod.instOf(wDefInst.name, wDefInst.module) + path.tail.foldLeft(topMod: IsModule) { + case (mod, wDefInst) => + mod.instOf(wDefInst.name, wDefInst.module) } } newTargets.foreach(addSelfRecord(_)) @@ -219,13 +214,11 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { val (remainingAnnotations, targetsToEliminate, previouslyDeduped) = state.annotations.foldLeft( - ( Vector.empty[Annotation], - Seq.empty[CompleteTarget], - Map.empty[IsModule, (ModuleTarget, Double)] - ) - ) { case ((remainingAnnos, targets, dedupedResult), anno) => + (Vector.empty[Annotation], Seq.empty[CompleteTarget], Map.empty[IsModule, (ModuleTarget, Double)]) + ) { + case ((remainingAnnos, targets, dedupedResult), anno) => anno match { - case ResolvePaths(ts) => + case ResolvePaths(ts) => (remainingAnnos, ts ++ targets, dedupedResult) case DedupedResult(orig, dups, idx) if dups.nonEmpty => (remainingAnnos, targets, dedupedResult ++ dups.map(_ -> (orig, idx)).toMap) @@ -234,29 +227,29 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { } } - // Collect targets that are not local val targets = targetsToEliminate.collect { case x: IsMember => x } // Check validity of paths in targets val iGraph = InstanceKeyGraph(state.circuit) - val instanceOfModules = iGraph.getChildInstances.map { case(k,v) => k -> v.map(_.toTokens) }.toMap + val instanceOfModules = iGraph.getChildInstances.map { case (k, v) => k -> v.map(_.toTokens) }.toMap val targetsWithInvalidPaths = mutable.ArrayBuffer[IsMember]() targets.foreach { t => val path = t match { - case _: ModuleTarget => Nil - case i: InstanceTarget => i.asPath + case _: ModuleTarget => Nil + case i: InstanceTarget => i.asPath case r: ReferenceTarget => r.path } - path.foldLeft(t.module) { case (module, (inst: Instance, of: OfModule)) => - val childrenOpt = instanceOfModules.get(module) - if(childrenOpt.isEmpty || !childrenOpt.get.contains((inst, of))) { - targetsWithInvalidPaths += t - } - of.value + path.foldLeft(t.module) { + case (module, (inst: Instance, of: OfModule)) => + val childrenOpt = instanceOfModules.get(module) + if (childrenOpt.isEmpty || !childrenOpt.get.contains((inst, of))) { + targetsWithInvalidPaths += t + } + of.value } } - if(targetsWithInvalidPaths.nonEmpty) { + if (targetsWithInvalidPaths.nonEmpty) { val string = targetsWithInvalidPaths.mkString(",") throw NoSuchTargetException(s"""Some targets have illegal paths that cannot be resolved/eliminated: $string""") } @@ -292,7 +285,7 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { } val newTarget = t match { case r: ReferenceTarget => r.setPathTarget(newIsModule) - case i: InstanceTarget => newIsModule + case i: InstanceTarget => newIsModule } firstRenameMap.record(t, Seq(newTarget)) newTarget +: acc @@ -312,10 +305,10 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { } val iGraphx = InstanceKeyGraph(newCircuit) - val newlyUnreachableModules = iGraphx.unreachableModules.toSet diff iGraph.unreachableModules.toSet + val newlyUnreachableModules = iGraphx.unreachableModules.toSet.diff(iGraph.unreachableModules.toSet) val newCircuitGC = { - val modulesx = newCircuit.modules.flatMap{ + val modulesx = newCircuit.modules.flatMap { case dead if newlyUnreachableModules(dead.OfModule) => None case live => val m = CircuitTarget(newCircuit.main).module(live.name) @@ -338,7 +331,8 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration { val renamedCircuit = renameModules(newCircuitGC, newModuleNameMapping, renamedModuleMap) - val reorderedCircuit = reorderModules(renamedCircuit, + val reorderedCircuit = reorderModules( + renamedCircuit, previouslyDeduped.map { case (current: IsModule, (orig: ModuleTarget, idx)) => orig.name -> idx diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala index f7ab9927..66690f56 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala @@ -26,7 +26,8 @@ private class Btor2Serializer private () { private def comment(c: String): Unit = { lines += s"; $c" } private def trailingComment(c: String): Unit = { val lastLine = lines.last - val newLine = if(lastLine.contains(';')) { lastLine + " " + c} else { lastLine + " ; " + c } + val newLine = if (lastLine.contains(';')) { lastLine + " " + c } + else { lastLine + " ; " + c } lines(lines.size - 1) = newLine } @@ -38,54 +39,55 @@ private class Btor2Serializer private () { // bit vector expression serialization private def s(expr: BVExpr): Int = expr match { case BVLiteral(value, width) => lit(value, width) - case BVSymbol(name, _) => symbols(name) - case BVExtend(e, 0, _) => s(e) - case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by") - case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by") + case BVSymbol(name, _) => symbols(name) + case BVExtend(e, 0, _) => s(e) + case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by") + case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by") case BVSlice(e, hi, lo) => - if (lo == 0 && hi == e.width - 1) { s(e) } else { + if (lo == 0 && hi == e.width - 1) { s(e) } + else { line(s"slice ${t(expr.width)} ${s(e)} $hi $lo") } - case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b) - case BVNot(BVNot(e)) => s(e) - case BVNot(e) => unary("not", expr.width, e) - case BVNegate(e) => unary("neg", expr.width, e) - case BVReduceAnd(e) => unary("redand", expr.width, e) - case BVReduceOr(e) => unary("redor", expr.width, e) - case BVReduceXor(e) => unary("redxor", expr.width, e) - case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b) - case BVImplies(a, b) => binary("implies", expr.width, a, b) - case BVEqual(a, b) => binary("eq", expr.width, a, b) - case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}") - case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b) + case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b) + case BVNot(BVNot(e)) => s(e) + case BVNot(e) => unary("not", expr.width, e) + case BVNegate(e) => unary("neg", expr.width, e) + case BVReduceAnd(e) => unary("redand", expr.width, e) + case BVReduceOr(e) => unary("redor", expr.width, e) + case BVReduceXor(e) => unary("redxor", expr.width, e) + case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b) + case BVImplies(a, b) => binary("implies", expr.width, a, b) + case BVEqual(a, b) => binary("eq", expr.width, a, b) + case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}") + case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b) case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b) - case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b) - case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b) - case BVOp(op, a, b) => binary(s(op), expr.width, a, b) - case BVConcat(a, b) => binary("concat", expr.width, a, b) + case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b) + case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b) + case BVOp(op, a, b) => binary(s(op), expr.width, a, b) + case BVConcat(a, b) => binary("concat", expr.width, a, b) case ArrayRead(array, index) => line(s"read ${t(expr.width)} ${s(array)} ${s(index)}") case BVIte(cond, tru, fals) => line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}") - case r : BVRawExpr => + case r: BVRawExpr => throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") } private def s(op: Op.Value): String = op match { - case Op.And => "and" - case Op.Or => "or" - case Op.Xor => "xor" + case Op.And => "and" + case Op.Or => "or" + case Op.Xor => "xor" case Op.ArithmeticShiftRight => "sra" - case Op.ShiftRight => "srl" - case Op.ShiftLeft => "sll" - case Op.Add => "add" - case Op.Mul => "mul" - case Op.Sub => "sub" - case Op.SignedDiv => "sdiv" - case Op.UnsignedDiv => "udiv" - case Op.SignedMod => "smod" - case Op.SignedRem => "srem" - case Op.UnsignedRem => "urem" + case Op.ShiftRight => "srl" + case Op.ShiftLeft => "sll" + case Op.Add => "add" + case Op.Mul => "mul" + case Op.Sub => "sub" + case Op.SignedDiv => "sdiv" + case Op.UnsignedDiv => "udiv" + case Op.SignedMod => "smod" + case Op.SignedRem => "srem" + case Op.UnsignedRem => "urem" } private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}") @@ -123,18 +125,18 @@ private class Btor2Serializer private () { // It is essential to model memories, so any support in the wild should be fairly well tested. line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}") case ArrayConstant(e, _) => s(e) - case r : ArrayRawExpr => + case r: ArrayRawExpr => throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}") } private def s(expr: SMTExpr): Int = expr match { - case b: BVExpr => s(b) + case b: BVExpr => s(b) case a: ArrayExpr => s(a) } // serialize the type of the expression private def t(expr: SMTExpr): Int = expr match { - case b: BVExpr => t(b.width) + case b: BVExpr => t(b.width) case a: ArrayExpr => t(a.indexWidth, a.dataWidth) } @@ -145,7 +147,7 @@ private class Btor2Serializer private () { symbols(name) = id if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name") if (sys.assumes.contains(name)) line(s"constraint $id ; $name") - if (sys.asserts.contains(name)){ + if (sys.asserts.contains(name)) { val invertedId = line(s"not ${t(1)} $id") line(s"bad $invertedId ; $name") } diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala index 0a223840..efa89687 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -9,26 +9,26 @@ import firrtl.passes.CheckWidths.WidthTooBig private trait TranslationContext { def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe)) - def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe)) - def getRandom(width: Int): BVExpr + def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe)) + def getRandom(width: Int): BVExpr } private object FirrtlExpressionSemantics { def getWidth(tpe: ir.Type): Int = tpe match { - case ir.UIntType(ir.IntWidth(w)) => w.toInt - case ir.SIntType(ir.IntWidth(w)) => w.toInt - case ir.ClockType => 1 - case ir.ResetType => 1 + case ir.UIntType(ir.IntWidth(w)) => w.toInt + case ir.SIntType(ir.IntWidth(w)) => w.toInt + case ir.ClockType => 1 + case ir.ResetType => 1 case ir.AnalogType(ir.IntWidth(w)) => w.toInt - case other => throw new RuntimeException(s"Cannot handle type $other") + case other => throw new RuntimeException(s"Cannot handle type $other") } def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = { val eSMT = e match { case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts) - case r : ir.Reference => ctx.getReference(r.serialize, r.tpe) - case r : ir.SubField => ctx.getReference(r.serialize, r.tpe) - case r : ir.SubIndex => ctx.getReference(r.serialize, r.tpe) + case r: ir.Reference => ctx.getReference(r.serialize, r.tpe) + case r: ir.SubField => ctx.getReference(r.serialize, r.tpe) + case r: ir.SubIndex => ctx.getReference(r.serialize, r.tpe) case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt) case ir.Mux(cond, tval, fval, _) => @@ -38,7 +38,10 @@ private object FirrtlExpressionSemantics { val tru = toSMT(value) BVIte(toSMT(cond), tru, ctx.getRandom(tpe)) } - assert(eSMT.width == getWidth(e), "We aim to always produce a SMT expression of the same width as the firrtl expression.") + assert( + eSMT.width == getWidth(e), + "We aim to always produce a SMT expression of the same width as the firrtl expression." + ) eSMT } @@ -47,8 +50,8 @@ private object FirrtlExpressionSemantics { forceWidth(toSMT(e), isSigned(e), width, allowNarrow) private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = { - if(eSMT.width == width) { eSMT } - else if(width < eSMT.width) { + if (eSMT.width == width) { eSMT } + else if (width < eSMT.width) { assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!") BVSlice(eSMT, width - 1, 0) } else { @@ -57,8 +60,13 @@ private object FirrtlExpressionSemantics { } // see "Primitive Operations" section in the Firrtl Specification - private def onPrim(op: ir.PrimOp, args: Seq[ir.Expression], consts: Seq[BigInt])(implicit ctx: TranslationContext): - BVExpr = { + private def onPrim( + op: ir.PrimOp, + args: Seq[ir.Expression], + consts: Seq[BigInt] + )( + implicit ctx: TranslationContext + ): BVExpr = { (op, args, consts) match { case (PrimOps.Add, Seq(e1, e2), _) => val width = args.map(getWidth).max + 1 @@ -70,7 +78,7 @@ private object FirrtlExpressionSemantics { val width = args.map(getWidth).sum BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width)) case (PrimOps.Div, Seq(num, den), _) => - val (width, op) = if(isSigned(num)) { + val (width, op) = if (isSigned(num)) { (getWidth(num) + 1, Op.SignedDiv) } else { (getWidth(num), Op.UnsignedDiv) } // "The result of a division where den is zero is undefined." @@ -83,11 +91,12 @@ private object FirrtlExpressionSemantics { val width = getWidth(num) + 1 BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width)) case (PrimOps.Rem, Seq(num, den), _) => - val op = if(isSigned(num)) Op.SignedRem else Op.UnsignedRem + val op = if (isSigned(num)) Op.SignedRem else Op.UnsignedRem val width = args.map(getWidth).max val resWidth = args.map(getWidth).min val res = BVOp(op, toSMT(num, width), toSMT(den, width)) - if(res.width > resWidth) { BVSlice(res, resWidth - 1, 0) } else { res } + if (res.width > resWidth) { BVSlice(res, resWidth - 1, 0) } + else { res } case (PrimOps.Lt, Seq(e1, e2), _) => val width = args.map(getWidth).max BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1))) @@ -108,25 +117,29 @@ private object FirrtlExpressionSemantics { BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width))) case (PrimOps.Pad, Seq(e), Seq(n)) => val width = getWidth(e) - if(n <= width) { toSMT(e) } else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) } - case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e) ; toSMT(e) - case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e) ; toSMT(e) + if (n <= width) { toSMT(e) } + else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) } + case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e); toSMT(e) + case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e); toSMT(e) case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!") - case (PrimOps.AsClock, Seq(e), _) => toSMT(e) + case (PrimOps.AsClock, Seq(e), _) => toSMT(e) case (PrimOps.AsAsyncReset, Seq(e), _) => checkForClockInCast(PrimOps.AsAsyncReset, e) throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.") - case (PrimOps.Shl, Seq(e), Seq(n)) => if(n == 0) { toSMT(e) } else { - val zeros = BVLiteral(0, n.toInt) - BVConcat(toSMT(e), zeros) - } + case (PrimOps.Shl, Seq(e), Seq(n)) => + if (n == 0) { toSMT(e) } + else { + val zeros = BVLiteral(0, n.toInt) + BVConcat(toSMT(e), zeros) + } case (PrimOps.Shr, Seq(e), Seq(n)) => val width = getWidth(e) // "If n is greater than or equal to the bit-width of e, // the resulting value will be zero for unsigned types // and the sign bit for signed types" - if(n >= width) { - if(isSigned(e)) { BV1BitZero } else { BVSlice(toSMT(e), width - 1, width - 1) } + if (n >= width) { + if (isSigned(e)) { BV1BitZero } + else { BVSlice(toSMT(e), width - 1, width - 1) } } else { BVSlice(toSMT(e), width - 1, n.toInt) } @@ -135,9 +148,11 @@ private object FirrtlExpressionSemantics { BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width)) case (PrimOps.Dshr, Seq(e1, e2), _) => val width = getWidth(e1) - val o = if(isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight + val o = if (isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight BVOp(o, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Cvt, Seq(e), _) => if(isSigned(e)) { toSMT(e) } else { BVConcat(BV1BitZero, toSMT(e)) } + case (PrimOps.Cvt, Seq(e), _) => + if (isSigned(e)) { toSMT(e) } + else { BVConcat(BV1BitZero, toSMT(e)) } case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e))) case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e)) case (PrimOps.And, Seq(e1, e2), _) => @@ -149,10 +164,10 @@ private object FirrtlExpressionSemantics { case (PrimOps.Xor, Seq(e1, e2), _) => val width = args.map(getWidth).max BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width)) - case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e)) - case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e)) - case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e)) - case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2)) + case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e)) + case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e)) + case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e)) + case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2)) case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt) case (PrimOps.Head, Seq(e), Seq(n)) => val width = getWidth(e) @@ -167,7 +182,8 @@ private object FirrtlExpressionSemantics { } /** For now we strictly forbid casting clocks to anything else. - * Eventually this should be replaced by a more sophisticated clock analysis pass. */ + * Eventually this should be replaced by a more sophisticated clock analysis pass. + */ private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = { assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!") } diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index b3a2ff17..0888b062 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -11,7 +11,16 @@ import firrtl.passes.PassException import firrtl.stage.Forms import firrtl.stage.TransformManager.TransformDependency import firrtl.transforms.PropagatePresetAnnotations -import firrtl.{CircuitState, DependencyAPIMigration, MemoryArrayInit, MemoryInitValue, MemoryScalarInit, Transform, Utils, ir} +import firrtl.{ + ir, + CircuitState, + DependencyAPIMigration, + MemoryArrayInit, + MemoryInitValue, + MemoryScalarInit, + Transform, + Utils +} import logger.LazyLogging import scala.collection.mutable @@ -22,15 +31,21 @@ import scala.collection.mutable private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr]) private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) } private case class TransitionSystem( - name: String, inputs: Array[BVSymbol], states: Array[State], signals: Array[Signal], - outputs: Set[String], assumes: Set[String], asserts: Set[String], fair: Set[String], - comments: Map[String, String] = Map(), header: Array[String] = Array()) { + name: String, + inputs: Array[BVSymbol], + states: Array[State], + signals: Array[Signal], + outputs: Set[String], + assumes: Set[String], + asserts: Set[String], + fair: Set[String], + comments: Map[String, String] = Map(), + header: Array[String] = Array()) { def serialize: String = { (Iterator(name) ++ inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++ signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++ - states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}") - ).mkString("\n") + states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n") } } @@ -53,26 +68,30 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { // run the preset pass to extract all preset registers and remove preset reset signals val afterPreset = presetPass.execute(state) val circuit = afterPreset.circuit - val presetRegs = afterPreset.annotations - .collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref }.toSet + val presetRegs = afterPreset.annotations.collect { + case PresetRegAnnotation(target) if target.module == circuit.main => target.ref + }.toSet // collect all non-random memory initialization val memInit = afterPreset.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a } - .filter(_.target.module == circuit.main).map(a => a.target.ref -> a.initValue).toMap + .filter(_.target.module == circuit.main) + .map(a => a.target.ref -> a.initValue) + .toMap // convert the main module val main = circuit.modules.find(_.name == circuit.main).get val sys = main match { case x: ir.ExtModule => throw new ExtModuleException( - "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog.") + "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog." + ) case m: ir.Module => - new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit=memInit) + new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit) } val sortedSys = TopologicalSort.run(sys) val anno = TransitionSystemAnnotation(sortedSys) - state.copy(circuit=circuit, annotations = afterPreset.annotations :+ anno ) + state.copy(circuit = circuit, annotations = afterPreset.annotations :+ anno) } } @@ -94,18 +113,23 @@ private object UnsupportedException { } private class ExtModuleException(s: String) extends PassException(s) -private class AsyncResetException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering) -private class MultiClockException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering) -private class MissingFeatureException(s: String) extends PassException("Unfortunately the SMT backend does not yet support: " + s) +private class AsyncResetException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering) +private class MultiClockException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering) +private class MissingFeatureException(s: String) + extends PassException("Unfortunately the SMT backend does not yet support: " + s) private class ModuleToTransitionSystem extends LazyLogging { - def run(m: ir.Module, presetRegs: Set[String] = Set(), memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + def run( + m: ir.Module, + presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map() + ): TransitionSystem = { // first pass over the module to convert expressions; discover state and I/O val scan = new ModuleScanner(makeRandom) m.foreachPort(scan.onPort) // multi-clock support requires the StutteringClock transform to be run - if(scan.clocks.size > 1) { + if (scan.clocks.size > 1) { throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}") } m.foreachStmt(scan.onStatement) @@ -115,14 +139,16 @@ private class ModuleToTransitionSystem extends LazyLogging { val constraints = scan.assumes.toSet val bad = scan.asserts.toSet val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad - val signals = scan.connects.filter{ case(name, _) => isSignal.contains(name) } - .map { case (name, expr) => Signal(name, expr) } + val signals = scan.connects.filter { case (name, _) => isSignal.contains(name) }.map { + case (name, expr) => Signal(name, expr) + } // turn registers and memories into states val registers = scan.registers.map(r => r._1 -> r).toMap - val regStates = scan.connects.filter(s => registers.contains(s._1)).map { case (name, nextExpr) => - val (_, width, resetExpr, initExpr) = registers(name) - onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) + val regStates = scan.connects.filter(s => registers.contains(s._1)).map { + case (name, nextExpr) => + val (_, width, resetExpr, initExpr) = registers(name) + onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) } // turn memories into state val memoryEncoding = new MemoryEncoding(makeRandom) @@ -135,16 +161,22 @@ private class ModuleToTransitionSystem extends LazyLogging { } else { s } } // filter out any left-over self assignments (this happens when we have a registered read port) - .filter(s => s match { case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false case _ => true }) + .filter(s => + s match { + case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false + case _ => true + } + ) val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1) // generate comments from infos val comments = mutable.HashMap[String, String]() - scan.infos.foreach { case (name, info) => - serializeInfo(info).foreach { infoString => - if(comments.contains(name)) { comments(name) += InfoSeparator + infoString } - else { comments(name) = InfoPrefix + infoString } - } + scan.infos.foreach { + case (name, info) => + serializeInfo(info).foreach { infoString => + if (comments.contains(name)) { comments(name) += InfoSeparator + infoString } + else { comments(name) = InfoPrefix + infoString } + } } // inputs are original module inputs and any "random" signal we need for modelling @@ -154,11 +186,28 @@ private class ModuleToTransitionSystem extends LazyLogging { val header = serializeInfo(m.info).map(InfoPrefix + _).toArray val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints - TransitionSystem(m.name, inputs.toArray, states, signalsWithMem.toArray, outputs, constraints, bad, fair, comments.toMap, header) + TransitionSystem( + m.name, + inputs.toArray, + states, + signalsWithMem.toArray, + outputs, + constraints, + bad, + fair, + comments.toMap, + header + ) } - private def onRegister(name: String, width: Int, resetExpr: BVExpr, initExpr: BVExpr, - nextExpr: BVExpr, presetRegs: Set[String]): State = { + private def onRegister( + name: String, + width: Int, + resetExpr: BVExpr, + initExpr: BVExpr, + nextExpr: BVExpr, + presetRegs: Set[String] + ): State = { assert(initExpr.width == width) assert(nextExpr.width == width) assert(resetExpr.width == 1) @@ -166,9 +215,9 @@ private class ModuleToTransitionSystem extends LazyLogging { val hasReset = initExpr != sym val isPreset = presetRegs.contains(name) assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!") - if(hasReset) { - val init = if(isPreset) Some(initExpr) else None - val next = if(isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr) + if (hasReset) { + val init = if (isPreset) Some(initExpr) else None + val next = if (isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr) State(sym, next = Some(next), init = init) } else { State(sym, next = Some(nextExpr), init = None) @@ -179,10 +228,11 @@ private class ModuleToTransitionSystem extends LazyLogging { private val InfoPrefix = "@ " private def serializeInfo(info: ir.Info): Option[String] = info match { case ir.NoInfo => None - case f : ir.FileInfo => Some(f.escaped) - case m : ir.MultiInfo => + case f: ir.FileInfo => Some(f.escaped) + case m: ir.MultiInfo => val infos = m.flatten - if(infos.isEmpty) { None } else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } + if (infos.isEmpty) { None } + else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } } private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() @@ -190,7 +240,7 @@ private class ModuleToTransitionSystem extends LazyLogging { // TODO: actually ensure that there cannot be any name clashes with other identifiers val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get - val sym = BVSymbol(name, width) + val sym = BVSymbol(name, width) randoms(name) = sym sym } @@ -198,10 +248,16 @@ private class ModuleToTransitionSystem extends LazyLogging { private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLogging { type Connects = Iterable[(String, BVExpr)] - def onMemory(defMem: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (Iterable[State], Connects) = { + def onMemory( + defMem: ir.DefMemory, + connects: Connects, + initValue: Option[MemoryInitValue] + ): (Iterable[State], Connects) = { // we can only work on appropriately lowered memories - assert(defMem.dataType.isInstanceOf[ir.GroundType], - s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!") + assert( + defMem.dataType.isInstanceOf[ir.GroundType], + s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!" + ) assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") // collect all memory meta-data in a custom class @@ -214,17 +270,19 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo val init = initValue.map(getInit(m, _)) // parse and check read and write ports - val writers = defMem.writers.map( w => new WritePort(m, w, inputs)) - val readers = defMem.readers.map( r => new ReadPort(m, r, inputs)) + val writers = defMem.writers.map(w => new WritePort(m, w, inputs)) + val readers = defMem.readers.map(r => new ReadPort(m, r, inputs)) // derive next state from all write ports assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.") - val next: ArrayExpr = if(writers.isEmpty) { m.sym } else { - if(writers.length > 2) { + val next: ArrayExpr = if (writers.isEmpty) { m.sym } + else { + if (writers.length > 2) { throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})") } val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) } - if(writers.length == 1) { validData } else { + if (writers.length == 1) { validData } + else { assert(writers.length == 2) val conflict = writers.head.doesConflict(writers.last) val conflictData = writers.head.makeRandomData("_write_write_collision") @@ -236,13 +294,13 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo // derive data signals from all read ports assert(defMem.readLatency >= 0) - if(defMem.readLatency > 1) { + if (defMem.readLatency > 1) { throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})") } - val readPortSignals = if(defMem.readLatency == 0) { + val readPortSignals = if (defMem.readLatency == 0) { readers.map { r => // combinatorial read - if(defMem.readUnderWrite != ir.ReadUnderWrite.New) { + if (defMem.readUnderWrite != ir.ReadUnderWrite.New) { //logger.warn(s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." + // s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored.") } @@ -251,22 +309,25 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo r.data.name -> data } } else { Seq() } - val readPortStates = if(defMem.readLatency == 1) { + val readPortStates = if (defMem.readLatency == 1) { readers.map { r => // we create a register for the read port data val next = defMem.readUnderWrite match { case ir.ReadUnderWrite.New => - throw new UnsupportedFeatureException(s"registered read ports that return the new value (${m.name}.${r.name})") - // the thing that makes this hard is to properly handle write conflicts + throw new UnsupportedFeatureException( + s"registered read ports that return the new value (${m.name}.${r.name})" + ) + // the thing that makes this hard is to properly handle write conflicts case ir.ReadUnderWrite.Undefined => val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r))) - if(anyWriteToTheSameAddress == False) { r.readOld() } else { + if (anyWriteToTheSameAddress == False) { r.readOld() } + else { val readUnderWriteData = r.makeRandomData("_read_under_write_undefined") BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.readOld()) } case ir.ReadUnderWrite.Old => r.readOld() } - State(r.data, init=None, next=Some(next)) + State(r.data, init = None, next = Some(next)) } } else { Seq() } @@ -276,16 +337,20 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match { case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth) case MemoryArrayInit(values) => - assert(values.length == m.depth, - s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!") + assert( + values.length == m.depth, + s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" + ) // in order to get a more compact encoding try to find the most common values val histogram = mutable.LinkedHashMap[BigInt, Int]() values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) val baseValue = histogram.maxBy(_._2)._1 val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth) - values.zipWithIndex.filterNot(_._1 == baseValue) - .foldLeft[ArrayExpr](base) { case (array, (value, index)) => - ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth)) + values.zipWithIndex + .filterNot(_._1 == baseValue) + .foldLeft[ArrayExpr](base) { + case (array, (value, index)) => + ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth)) } case other => throw new RuntimeException(s"Unsupported memory init option: $other") } @@ -295,19 +360,20 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo val depth = m.depth // derrive the type of the memory from the dataType and depth val dataWidth = getWidth(m.dataType) - val indexWidth = Utils.getUIntWidth(m.depth - 1) max 1 + val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) val sym = ArraySymbol(m.name, indexWidth, dataWidth) val prefix = m.name + "." val fullAddressRange = (BigInt(1) << indexWidth) == m.depth lazy val depthBV = BVLiteral(m.depth, indexWidth) def isValidAddress(addr: BVExpr): BVExpr = { - if(fullAddressRange) { True } else { + if (fullAddressRange) { True } + else { BVComparison(Compare.Greater, depthBV, addr, signed = false) } } } private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) { - val en: BVSymbol = makeField("en", 1) + val en: BVSymbol = makeField("en", 1) val data: BVSymbol = makeField("data", memory.dataWidth) val addr: BVSymbol = makeField("addr", memory.indexWidth) protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width) @@ -321,11 +387,11 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo val canBeOutOfRange = !memory.fullAddressRange val canBeDisabled = !enIsTrue val data = ArrayRead(memory.sym, addr) - val dataWithRangeCheck = if(canBeOutOfRange) { + val dataWithRangeCheck = if (canBeOutOfRange) { val outOfRangeData = makeRandomData("_addr_out_of_range") BVIte(memory.isValidAddress(addr), data, outOfRangeData) } else { data } - val dataWithEnabledCheck = if(canBeDisabled) { + val dataWithEnabledCheck = if (canBeDisabled) { val disabledData = makeRandomData("_not_enabled") BVIte(en, dataWithRangeCheck, disabledData) } else { dataWithRangeCheck } @@ -333,48 +399,49 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo } } private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) { + extends MemPort(memory, name, inputs) { assert(inputs(data.name).width == data.width) val mask: BVSymbol = makeField("mask", 1) assert(inputs(mask.name).width == mask.width) val maskIsTrue: Boolean = inputs(mask.name) == True val doWrite: BVExpr = (enIsTrue, maskIsTrue) match { - case (true, true) => True - case (true, false) => mask - case (false, true) => en + case (true, true) => True + case (true, false) => mask + case (false, true) => en case (false, false) => and(en, mask) } def doesConflict(r: ReadPort): BVExpr = { val sameAddress = BVEqual(r.addr, addr) - if(doWrite == True) { sameAddress } else { and(doWrite, sameAddress) } + if (doWrite == True) { sameAddress } + else { and(doWrite, sameAddress) } } def doesConflict(w: WritePort): BVExpr = { val bothWrite = and(doWrite, w.doWrite) val sameAddress = BVEqual(addr, w.addr) - if(bothWrite == True) { sameAddress } else { and(doWrite, sameAddress) } + if (bothWrite == True) { sameAddress } + else { and(doWrite, sameAddress) } } def writeTo(array: ArrayExpr): ArrayExpr = { - val doUpdate = if(memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr)) - val update = ArrayStore(array, index=addr, data=data) - if(doUpdate == True) update else ArrayIte(doUpdate, update, array) + val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr)) + val update = ArrayStore(array, index = addr, data = data) + if (doUpdate == True) update else ArrayIte(doUpdate, update, array) } } private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) { - } + extends MemPort(memory, name, inputs) {} - private def and(a: BVExpr, b: BVExpr): BVExpr = (a,b) match { + private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match { case (True, True) => True - case (True, x) => x - case (x, True) => x - case _ => BVOp(Op.And, a, b) + case (True, x) => x + case (x, True) => x + case _ => BVOp(Op.And, a, b) } private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b) private val True = BVLiteral(1, 1) private val False = BVLiteral(0, 1) - private def all(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) False else b.reduce((a,b) => and(a,b)) - private def any(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) True else b.reduce((a,b) => or(a,b)) + private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b)) + private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b)) } // performas a first pass over the module collecting all connections, wires, registers, input and outputs @@ -399,13 +466,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]() private[firrtl] def onPort(p: ir.Port): Unit = { - if(isAsyncReset(p.tpe)) { + if (isAsyncReset(p.tpe)) { throw new AsyncResetException(s"Found AsyncReset ${p.name}.") } infos.append(p.name -> p.info) p.direction match { case ir.Input => - if(isClock(p.tpe)) { + if (isClock(p.tpe)) { clocks.add(p.name) } else { inputs.append(BVSymbol(p.name, getWidth(p.tpe))) @@ -416,12 +483,12 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog private[firrtl] def onStatement(s: ir.Statement): Unit = s match { case ir.DefWire(info, name, tpe) => - if(!isClock(tpe)) { + if (!isClock(tpe)) { infos.append(name -> info) wires.append(name) } case ir.DefNode(info, name, expr) => - if(!isClock(expr.tpe)) { + if (!isClock(expr.tpe)) { insertDummyAssignsForMemoryOutputs(expr) infos.append(name -> info) val e = onExpression(expr, name) @@ -436,7 +503,7 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog val resetExpr = onExpression(reset, 1, name + "_reset") val initExpr = onExpression(init, width, name + "_init") registers.append((name, width, resetExpr, initExpr)) - case m : ir.DefMemory => + case m: ir.DefMemory => infos.append(m.name -> m.info) val outputs = getMemOutputs(m) (getMemInputs(m) ++ outputs).foreach(memSignals.append(_)) @@ -444,37 +511,39 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog outputs.foreach(name => unusedMemOutputs(name) = dataWidth) memories.append(m) case ir.Connect(info, loc, expr) => - if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") + if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") val name = loc.serialize insertDummyAssignsForMemoryOutputs(expr) infos.append(name -> info) connects.append((name, onExpression(expr, getWidth(loc.tpe), name))) case ir.IsInvalid(info, loc) => - if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") + if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") val name = loc.serialize infos.append(name -> info) connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe)))) case ir.DefInstance(info, name, module, tpe) => - if(!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") + if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}") // we treat all instances as blackboxes - logger.warn(s"WARN: treating instance $name of $module as blackbox. " + - "Please flatten your hierarchy if you want to include submodules in the formal model.") + logger.warn( + s"WARN: treating instance $name of $module as blackbox. " + + "Please flatten your hierarchy if you want to include submodules in the formal model." + ) val ports = tpe.asInstanceOf[ir.BundleType].fields // skip clock and async reset ports - ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe) ).foreach { p => - if(!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") + ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe)).foreach { p => + if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p") val isOutput = p.flip == ir.Default val pName = name + "." + p.name infos.append(pName -> info) // outputs of the submodule become inputs to our module - if(isOutput) { + if (isOutput) { inputs.append(BVSymbol(pName, getWidth(p.tpe))) } else { outputs.append(pName) } } case s @ ir.Verification(op, info, _, pred, en, msg) => - if(op == ir.Formal.Cover) { + if (op == ir.Formal.Cover) { logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}") } else { val name = msgToName(op.toString, msg.string) @@ -483,22 +552,22 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog val e = BVImplies(enabled, predicate) infos.append(name -> info) connects.append(name -> e) - if(op == ir.Formal.Assert) { + if (op == ir.Formal.Assert) { asserts.append(name) } else { assumes.append(name) } } - case s : ir.Conditionally => + case s: ir.Conditionally => error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}") - case s : ir.PartialConnect => + case s: ir.PartialConnect => error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}") - case s : ir.Attach => + case s: ir.Attach => error(s"Analog wires are not supported in the SMT backend: ${s.serialize}") - case s : ir.Stop => + case s: ir.Stop => // we could wire up the stop condition as output for debug reasons logger.warn(s"WARN: Stop statements are currently not supported. Ignoring: ${s.serialize}") - case s : ir.Print => + case s: ir.Print => logger.warn(s"WARN: Print statements are not supported. Ignoring: ${s.serialize}") case other => other.foreachStmt(onStatement) } @@ -520,21 +589,22 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog // example: // m.r.data <= m.r.data ; this is the dummy assign // test <= m.r.data ; this is the first use of m.r.data - private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if(unusedMemOutputs.nonEmpty) { + private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if (unusedMemOutputs.nonEmpty) { implicit val uses = mutable.ArrayBuffer[String]() findUnusedMemoryOutputUse(next) - if(uses.nonEmpty) { + if (uses.nonEmpty) { val useSet = uses.toSet - unusedMemOutputs.foreach { case (name, width) => - if(useSet.contains(name)) connects.append(name -> BVSymbol(name, width)) + unusedMemOutputs.foreach { + case (name, width) => + if (useSet.contains(name)) connects.append(name -> BVSymbol(name, width)) } useSet.foreach(name => unusedMemOutputs.remove(name)) } } private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match { - case s : ir.SubField => + case s: ir.SubField => val name = s.serialize - if(unusedMemOutputs.contains(name)) uses.append(name) + if (unusedMemOutputs.contains(name)) uses.append(name) case other => other.foreachExpr(findUnusedMemoryOutputUse) } @@ -555,17 +625,18 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog // TODO: ensure that we can generate unique names prefix + "_" + msg.replace(" ", "_").replace("|", "") } - private def error(msg: String): Unit = throw new RuntimeException(msg) + private def error(msg: String): Unit = throw new RuntimeException(msg) private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType] - private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType + private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType } private object TopologicalSort { + /** Ensures that all signals in the resulting system are topologically sorted. * This is necessary because [[firrtl.transforms.RemoveWires]] does * not sort assignments to outputs, submodule inputs nor memory ports. - * */ + */ def run(sys: TransitionSystem): TransitionSystem = { val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name) val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates) @@ -583,23 +654,24 @@ private object TopologicalSort { val known = new mutable.HashSet[String]() ++ globalSignals var needsReordering = false val digraph = new MutableDiGraph[String] - signals.foreach { case (name, expr) => - digraph.addVertex(name) - val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr) - uniqueDependencies.foreach { d => - if(!known.contains(d)) { needsReordering = true } - digraph.addPairWithEdge(name, d) - } - known.add(name) + signals.foreach { + case (name, expr) => + digraph.addVertex(name) + val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr) + uniqueDependencies.foreach { d => + if (!known.contains(d)) { needsReordering = true } + digraph.addPairWithEdge(name, d) + } + known.add(name) } - if(needsReordering) { + if (needsReordering) { Some(digraph.linearize.reverse) } else { None } } private def findDependencies(expr: SMTExpr): List[String] = expr match { - case BVSymbol(name, _) => List(name) + case BVSymbol(name, _) => List(name) case ArraySymbol(name, _, _) => List(name) - case other => other.children.flatMap(findDependencies) + case other => other.children.flatMap(findDependencies) } -}
\ No newline at end of file +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala index 322b8961..1c7ea42f 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala @@ -11,8 +11,10 @@ import firrtl.options.Viewer.view import firrtl.options.{CustomFileEmission, Dependency} import firrtl.stage.FirrtlOptions - -private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform with Emitter with DependencyAPIMigration { +private[firrtl] abstract class SMTEmitter private[firrtl] () + extends Transform + with Emitter + with DependencyAPIMigration { override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem)) override def invalidates(a: Transform): Boolean = false @@ -30,16 +32,16 @@ private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform w override protected def execute(state: CircuitState): CircuitState = { val emitCircuit = state.annotations.exists { - case EmitCircuitAnnotation(a) if this.getClass == a => true + case EmitCircuitAnnotation(a) if this.getClass == a => true case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!") - case _ => false + case _ => false } - if(!emitCircuit) { return state } + if (!emitCircuit) { return state } logger.warn(BleedingEdgeWarning) - val sys = state.annotations.collectFirst{ case TransitionSystemAnnotation(sys) => sys }.getOrElse { + val sys = state.annotations.collectFirst { case TransitionSystemAnnotation(sys) => sys }.getOrElse { error("Could not find the transition system!") } state.copy(annotations = state.annotations :+ serialize(sys)) @@ -52,11 +54,12 @@ private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform w } case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String) - extends NoTargetAnnotation with CustomFileEmission { + extends NoTargetAnnotation + with CustomFileEmission { override protected def baseFileName(annotations: AnnotationSeq): String = view[FirrtlOptions](annotations).outputFileName.getOrElse(name) override protected def suffix: Option[String] = Some(outputSuffix) - override def getBytes: Iterable[Byte] = src.getBytes + override def getBytes: Iterable[Byte] = src.getBytes } private[firrtl] class Btor2Emitter extends SMTEmitter { @@ -72,14 +75,14 @@ private[firrtl] class SMTLibEmitter extends SMTEmitter { override protected def serialize(sys: TransitionSystem): Annotation = { val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr]) val logic = SMTLibSerializer.setLogic(hasMemory) + "\n" - val header = if(hasMemory) { + val header = if (hasMemory) { "; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" + - "; see https://github.com/Z3Prover/z3/issues/1803\n" + - "; for CVC4 you probably want to include the logic\n" + - ";" + logic + "; see https://github.com/Z3Prover/z3/issues/1803\n" + + "; for CVC4 you probably want to include the logic\n" + + ";" + logic } else { logic } val smt = generatedHeader("SMT-LIBv2", sys.name) + header + SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n" EmittedSMTModelAnnotation(sys.name, smt, outputSuffix) } -}
\ No newline at end of file +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala index 10a89e8d..ebb9e309 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala @@ -9,7 +9,7 @@ private sealed trait SMTExpr { def children: List[SMTExpr] } private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String } private object SMTSymbol { def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match { - case b: BVExpr => BVSymbol(name, b.width) + case b: BVExpr => BVSymbol(name, b.width) case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth) } } @@ -19,19 +19,19 @@ private sealed trait SMTNullaryExpr extends SMTExpr { private sealed trait BVExpr extends SMTExpr { def width: Int } private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr { - private def minWidth = value.bitLength + (if(value <= 0) 1 else 0) + private def minWidth = value.bitLength + (if (value <= 0) 1 else 0) assert(width > 0, "Zero or negative width literals are not allowed!") assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.") - override def toString: String = if(width <= 8) { + override def toString: String = if (width <= 8) { width.toString + "'b" + value.toString(2) } else { width.toString + "'x" + value.toString(16) } } private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol { - assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") + assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") assert(!name.contains("\\"), s"Invalid id $name contains `\\`") assert(width > 0, "Zero width bit vectors are not supported!") override def toString: String = name - def toStringWithType: String = name + " : " + SMTExpr.serializeType(this) + def toStringWithType: String = name + " : " + SMTExpr.serializeType(this) } private sealed trait BVUnaryExpr extends BVExpr { @@ -41,34 +41,35 @@ private sealed trait BVUnaryExpr extends BVExpr { private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr { assert(by >= 0, "Extension must be non-negative!") override val width: Int = e.width + by - override def toString: String = if(signed) { s"sext($e, $by)" } else { s"zext($e, $by)" } + override def toString: String = if (signed) { s"sext($e, $by)" } + else { s"zext($e, $by)" } } // also known as bit extract operation private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr { assert(lo >= 0, s"lo (lsb) must be non-negative!") assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo") assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi") - override def width: Int = hi - lo + 1 - override def toString: String = if(hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]" + override def width: Int = hi - lo + 1 + override def toString: String = if (hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]" } private case class BVNot(e: BVExpr) extends BVUnaryExpr { - override val width: Int = e.width + override val width: Int = e.width override def toString: String = s"not($e)" } private case class BVNegate(e: BVExpr) extends BVUnaryExpr { - override val width: Int = e.width + override val width: Int = e.width override def toString: String = s"neg($e)" } private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 + override def width: Int = 1 override def toString: String = s"redor($e)" } private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 + override def width: Int = 1 override def toString: String = s"redand($e)" } private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr { - override def width: Int = 1 + override def width: Int = 1 override def toString: String = s"redxor($e)" } @@ -79,12 +80,12 @@ private sealed trait BVBinaryExpr extends BVExpr { } private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr { assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!") - override def width: Int = 1 + override def width: Int = 1 override def toString: String = s"impl($a, $b)" } private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr { assert(a.width == b.width, s"Both argument need to be the same width!") - override def width: Int = 1 + override def width: Int = 1 override def toString: String = s"eq($a, $b)" } private object Compare extends Enumeration { @@ -94,8 +95,8 @@ private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed: assert(a.width == b.width, s"Both argument need to be the same width!") override def width: Int = 1 override def toString: String = op match { - case Compare.Greater => (if(signed) "sgt" else "ugt") + s"($a, $b)" - case Compare.GreaterEqual => (if(signed) "sgeq" else "ugeq") + s"($a, $b)" + case Compare.Greater => (if (signed) "sgt" else "ugt") + s"($a, $b)" + case Compare.GreaterEqual => (if (signed) "sgeq" else "ugeq") + s"($a, $b)" } } private object Op extends Enumeration { @@ -116,81 +117,87 @@ private object Op extends Enumeration { } private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr { assert(a.width == b.width, s"Both argument need to be the same width!") - override val width: Int = a.width + override val width: Int = a.width override def toString: String = s"$op($a, $b)" } private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr { - override val width: Int = a.width + b.width + override val width: Int = a.width + b.width override def toString: String = s"concat($a, $b)" } private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr { assert(array.indexWidth == index.width, "Index with does not match expected array index width!") - override val width: Int = array.dataWidth + override val width: Int = array.dataWidth override def toString: String = s"$array[$index]" override def children: List[SMTExpr] = List(array, index) } private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr { assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}") - override val width: Int = tru.width + override val width: Int = tru.width override def toString: String = s"ite($cond, $tru, $fals)" override def children: List[BVExpr] = List(cond, tru, fals) } private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int } private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol { - assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") + assert(!name.contains("|"), s"Invalid id $name contains escape character `|`") assert(!name.contains("\\"), s"Invalid id $name contains `\\`") override def toString: String = name - def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>" + def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>" } private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr { assert(array.indexWidth == index.width, "Index with does not match expected array index width!") assert(array.dataWidth == data.width, "Data with does not match expected array data width!") - override val dataWidth: Int = array.dataWidth + override val dataWidth: Int = array.dataWidth override val indexWidth: Int = array.indexWidth - override def toString: String = s"$array[$index := $data]" - override def children: List[SMTExpr] = List(array, index, data) + override def toString: String = s"$array[$index := $data]" + override def children: List[SMTExpr] = List(array, index, data) } private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr { assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!") - assert(tru.indexWidth == fals.indexWidth, - s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}") - assert(tru.dataWidth == fals.dataWidth, - s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}") - override val dataWidth: Int = tru.dataWidth + assert( + tru.indexWidth == fals.indexWidth, + s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}" + ) + assert( + tru.dataWidth == fals.dataWidth, + s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}" + ) + override val dataWidth: Int = tru.dataWidth override val indexWidth: Int = tru.indexWidth - override def toString: String = s"ite($cond, $tru, $fals)" - override def children: List[SMTExpr] = List(cond, tru, fals) + override def toString: String = s"ite($cond, $tru, $fals)" + override def children: List[SMTExpr] = List(cond, tru, fals) } private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr { assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!") assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!") - override def width: Int = 1 + override def width: Int = 1 override def toString: String = s"eq($a, $b)" override def children: List[SMTExpr] = List(a, b) } private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr { override val dataWidth: Int = e.width - override def toString: String = s"([$e] x ${ (BigInt(1) << indexWidth) })" - override def children: List[SMTExpr] = List(e) + override def toString: String = s"([$e] x ${(BigInt(1) << indexWidth)})" + override def children: List[SMTExpr] = List(e) } private object SMTEqual { - def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a,b) match { - case (ab : BVExpr, bb : BVExpr) => BVEqual(ab, bb) - case (aa : ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba) + def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a, b) match { + case (ab: BVExpr, bb: BVExpr) => BVEqual(ab, bb) + case (aa: ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba) case _ => throw new RuntimeException(s"Cannot compare $a and $b") } } private object SMTExpr { def serializeType(e: SMTExpr): String = e match { - case b: BVExpr => s"bv<${b.width}>" + case b: BVExpr => s"bv<${b.width}>" case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>" } } // Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]] private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr -private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTNullaryExpr
\ No newline at end of file +private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) + extends ArrayExpr + with SMTNullaryExpr diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala index 14e73253..defc787c 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala @@ -9,7 +9,7 @@ private object SMTExprVisitor { type BVFun = BVExpr => BVExpr def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match { - case b: BVExpr => map(b, bv, ar).asInstanceOf[T] + case b: BVExpr => map(b, bv, ar).asInstanceOf[T] case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T] } def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T = @@ -17,57 +17,56 @@ private object SMTExprVisitor { private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match { // nullary - case old : BVLiteral => bv(old) - case old : BVSymbol => bv(old) - case old : BVRawExpr => bv(old) + case old: BVLiteral => bv(old) + case old: BVSymbol => bv(old) + case old: BVRawExpr => bv(old) // unary - case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVExtend(n, by, signed)) - case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVSlice(n, hi, lo)) - case old @ BVNot(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNot(n)) - case old @ BVNegate(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNegate(n)) - case old @ BVReduceAnd(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceAnd(n)) - case old @ BVReduceOr(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceOr(n)) - case old @ BVReduceXor(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceXor(n)) + case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVExtend(n, by, signed)) + case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVSlice(n, hi, lo)) + case old @ BVNot(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNot(n)) + case old @ BVNegate(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNegate(n)) + case old @ BVReduceAnd(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceAnd(n)) + case old @ BVReduceOr(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceOr(n)) + case old @ BVReduceXor(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceXor(n)) // binary case old @ BVImplies(a, b) => val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB)) + bv(if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB)) case old @ BVEqual(a, b) => val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB)) + bv(if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB)) case old @ ArrayEqual(a, b) => val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB)) + bv(if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB)) case old @ BVComparison(op, a, b, signed) => val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed)) + bv(if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed)) case old @ BVOp(op, a, b) => val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB)) + bv(if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB)) case old @ BVConcat(a, b) => val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB)) + bv(if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB)) case old @ ArrayRead(a, b) => val (nA, nB) = (map(a, bv, ar), map(b, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB)) + bv(if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB)) // ternary case old @ BVIte(a, b, c) => val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) - bv(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)) + bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC)) } - private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match { - case old : ArrayRawExpr => ar(old) - case old : ArraySymbol => ar(old) + case old: ArrayRawExpr => ar(old) + case old: ArraySymbol => ar(old) case old @ ArrayConstant(e, indexWidth) => - val n = map(e, bv, ar) ; ar(if(n.eq(e)) old else ArrayConstant(n, indexWidth)) + val n = map(e, bv, ar); ar(if (n.eq(e)) old else ArrayConstant(n, indexWidth)) case old @ ArrayStore(a, b, c) => val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) - ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC)) + ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC)) case old @ ArrayIte(a, b, c) => val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar)) - ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC)) + ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC)) } } diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala index 1993da87..bd5e4d8c 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala @@ -6,83 +6,87 @@ package firrtl.backends.experimental.smt import scala.util.matching.Regex /** Converts STM Expressions to a SMTLib compatible string representation. - * See http://smtlib.cs.uiowa.edu/ - * Assumes well typed expression, so it is advisable to run the TypeChecker - * before serializing! - * Automatically converts 1-bit vectors to bool. - */ + * See http://smtlib.cs.uiowa.edu/ + * Assumes well typed expression, so it is advisable to run the TypeChecker + * before serializing! + * Automatically converts 1-bit vectors to bool. + */ private object SMTLibSerializer { - def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if(hasMem) "A" else "") + "UFBV)" + def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if (hasMem) "A" else "") + "UFBV)" def serialize(e: SMTExpr): String = e match { - case b : BVExpr => serialize(b) - case a : ArrayExpr => serialize(a) + case b: BVExpr => serialize(b) + case a: ArrayExpr => serialize(a) } def serializeType(e: SMTExpr): String = e match { - case b : BVExpr => serializeBitVectorType(b.width) - case a : ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) + case b: BVExpr => serializeBitVectorType(b.width) + case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth) } private def serialize(e: BVExpr): String = e match { case BVLiteral(value, width) => - val mask = (BigInt(1) << width) - 1 - val twosComplement = if(value < 0) { ((~(-value)) & mask) + 1 } else value - if(width == 1) { - if(twosComplement == 1) "true" else "false" + val mask = (BigInt(1) << width) - 1 + val twosComplement = if (value < 0) { ((~(-value)) & mask) + 1 } + else value + if (width == 1) { + if (twosComplement == 1) "true" else "false" } else { s"(_ bv$twosComplement $width)" } - case BVSymbol(name, _) => escapeIdentifier(name) - case BVExtend(e, 0, _) => serialize(e) + case BVSymbol(name, _) => escapeIdentifier(name) + case BVExtend(e, 0, _) => serialize(e) case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by)) case BVExtend(e, by, signed) => - val foo = if(signed) "sign_extend" else "zero_extend" + val foo = if (signed) "sign_extend" else "zero_extend" s"((_ $foo $by) ${asBitVector(e)})" case BVSlice(e, hi, lo) => - if(lo == 0 && hi == e.width - 1) { serialize(e) - } else { + if (lo == 0 && hi == e.width - 1) { serialize(e) } + else { val bits = s"((_ extract $hi $lo) ${asBitVector(e)})" // 1-bit extracts need to be turned into a boolean - if(lo == hi) { toBool(bits) } else { bits } + if (lo == hi) { toBool(bits) } + else { bits } } case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})" - case BVNot(BVNot(e)) => serialize(e) - case BVNot(e) => if(e.width == 1) { s"(not ${serialize(e)})" } else { s"(bvnot ${serialize(e)})" } + case BVNot(BVNot(e)) => serialize(e) + case BVNot(e) => + if (e.width == 1) { s"(not ${serialize(e)})" } + else { s"(bvnot ${serialize(e)})" } case BVNegate(e) => s"(bvneg ${asBitVector(e)})" case r: BVReduceAnd => serialize(Expander.expand(r)) - case r: BVReduceOr => serialize(Expander.expand(r)) + case r: BVReduceOr => serialize(Expander.expand(r)) case r: BVReduceXor => serialize(Expander.expand(r)) - case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b) - case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})" - case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" - case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" - case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})" + case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b) + case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})" + case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" + case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})" + case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})" case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})" - case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})" - case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})" + case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})" // boolean operations get a special treatment for 1-bit vectors aka bools case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})" - case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})" + case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})" case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})" - case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})") - case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})" - case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" - case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" - case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case BVRawExpr(serialized, _) => serialized + case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})") + case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})" + case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})" + case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})" + case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case BVRawExpr(serialized, _) => serialized } def serialize(e: ArrayExpr): String = e match { - case ArraySymbol(name, _, _) => escapeIdentifier(name) + case ArraySymbol(name, _, _) => escapeIdentifier(name) case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})" - case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" - case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" + case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})" + case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})" case ArrayRawExpr(serialized, _, _) => serialized } def serialize(c: SMTCommand): String = c match { - case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n") + case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n") case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)" case DefineFunction(name, args, e) => val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ") @@ -95,23 +99,24 @@ private object SMTLibSerializer { private def serializeArrayType(indexWidth: Int, dataWidth: Int): String = s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})" private def serializeBitVectorType(width: Int): String = - if(width == 1) { "Bool" } else { assert(width > 1) ; s"(_ BitVec $width)" } + if (width == 1) { "Bool" } + else { assert(width > 1); s"(_ BitVec $width)" } private def serialize(op: Op.Value): String = op match { - case Op.And => "bvand" - case Op.Or => "bvor" - case Op.Xor => "bvxor" + case Op.And => "bvand" + case Op.Or => "bvor" + case Op.Xor => "bvxor" case Op.ArithmeticShiftRight => "bvashr" - case Op.ShiftRight => "bvlshr" - case Op.ShiftLeft => "bvshl" - case Op.Add => "bvadd" - case Op.Mul => "bvmul" - case Op.Sub => "bvsub" - case Op.SignedDiv => "bvsdiv" - case Op.UnsignedDiv => "bvudiv" - case Op.SignedMod => "bvsmod" - case Op.SignedRem => "bvsrem" - case Op.UnsignedRem => "bvurem" + case Op.ShiftRight => "bvlshr" + case Op.ShiftLeft => "bvshl" + case Op.Add => "bvadd" + case Op.Mul => "bvmul" + case Op.Sub => "bvsub" + case Op.SignedDiv => "bvsdiv" + case Op.UnsignedDiv => "bvudiv" + case Op.SignedMod => "bvsmod" + case Op.SignedRem => "bvsrem" + case Op.UnsignedRem => "bvurem" } private def toBool(e: String): String = s"(= $e (_ bv1 1))" @@ -119,33 +124,37 @@ private object SMTLibSerializer { private val bvZero = "(_ bv0 1)" private val bvOne = "(_ bv1 1)" private def asBitVector(e: BVExpr): String = - if(e.width > 1) { serialize(e) } else { s"(ite ${serialize(e)} $bvOne $bvZero)" } + if (e.width > 1) { serialize(e) } + else { s"(ite ${serialize(e)} $bvOne $bvZero)" } // See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.\$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.\$$_~&\^<>@]*".r def escapeIdentifier(name: String): String = name match { case simple() => name - case _ => if(name.startsWith("|") && name.endsWith("|")) name else s"|$name|" + case _ => if (name.startsWith("|") && name.endsWith("|")) name else s"|$name|" } } /** Expands expressions that are not natively supported by SMTLib */ private object Expander { def expand(r: BVReduceAnd): BVExpr = { - if(r.e.width == 1) { r.e } else { + if (r.e.width == 1) { r.e } + else { val allOnes = (BigInt(1) << r.e.width) - 1 BVEqual(r.e, BVLiteral(allOnes, r.e.width)) } } def expand(r: BVReduceOr): BVExpr = { - if(r.e.width == 1) { r.e } else { + if (r.e.width == 1) { r.e } + else { BVNot(BVEqual(r.e, BVLiteral(0, r.e.width))) } } def expand(r: BVReduceXor): BVExpr = { - if(r.e.width == 1) { r.e } else { + if (r.e.width == 1) { r.e } + else { val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii)) - bits.reduce[BVExpr]((a,b) => BVOp(Op.Xor, a, b)) + bits.reduce[BVExpr]((a, b) => BVOp(Op.Xor, a, b)) } } } diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala index e9acc05b..4c60a1b0 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala @@ -10,7 +10,7 @@ import scala.collection.mutable * It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state * prevents this encoding from working with boolector. * For simplicity reasons, we do not support hierarchical designs (no `_h` function). - * */ + */ private object SMTTransitionSystemEncoder { def encode(sys: TransitionSystem): Iterable[SMTCommand] = { @@ -38,10 +38,10 @@ private object SMTTransitionSystemEncoder { cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e)) } sys.signals.foreach { signal => - val kind = if(sys.outputs.contains(signal.name)) { "output" - } else if(sys.assumes.contains(signal.name)) { "assume" - } else if(sys.asserts.contains(signal.name)) { "assert" - } else { "wire" } + val kind = if (sys.outputs.contains(signal.name)) { "output" } + else if (sys.assumes.contains(signal.name)) { "assume" } + else if (sys.asserts.contains(signal.name)) { "assert" } + else { "wire" } val sym = SMTSymbol.fromExpr(signal.name, signal.e) cmds ++= toDescription(sym, kind, sys.comments.get) define(sym, signal.e) @@ -105,18 +105,18 @@ private object SMTTransitionSystemEncoder { } private def andReduce(e: Iterable[BVExpr]): BVExpr = - if(e.isEmpty) BVLiteral(1, 1) else e.reduce((a,b) => BVOp(Op.And, a, b)) + if (e.isEmpty) BVLiteral(1, 1) else e.reduce((a, b) => BVOp(Op.And, a, b)) // All signals are modelled with functions that need to be called with the state as argument, // this replaces all Symbols with function applications to the state. private def replaceSymbols(e: SMTExpr): SMTExpr = { SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e) } - private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr] + private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr] private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match { - case BVSymbol(name, width) => BVRawExpr(s"(${id(name+suffix)} $arg)", width) - case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name+suffix)} $arg)", indexWidth, dataWidth) - case other => other + case BVSymbol(name, width) => BVRawExpr(s"(${id(name + suffix)} $arg)", width) + case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name + suffix)} $arg)", indexWidth, dataWidth) + case other => other } } diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala index d8e203f8..95db95ef 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala @@ -3,7 +3,7 @@ package firrtl.backends.experimental.smt -import firrtl.{CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils, ir} +import firrtl.{ir, CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils} import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation} import firrtl.ir.EmptyStmt import firrtl.options.Dependency @@ -32,16 +32,17 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // since this pass only runs on the main module, inlining needs to happen before override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) - override protected def execute(state: CircuitState): CircuitState = { - if(state.circuit.modules.size > 1) { - logger.warn("WARN: StutteringClockTransform currently only supports running on a single module.\n" + - s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want.") + if (state.circuit.modules.size > 1) { + logger.warn( + "WARN: StutteringClockTransform currently only supports running on a single module.\n" + + s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want." + ) } // get main module val main = state.circuit.modules.find(_.name == state.circuit.main).get match { - case m: ir.Module => m + case m: ir.Module => m case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e") } mainName = main.name @@ -64,19 +65,21 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // replace all other clocks with enable signals, unless they are the global clock val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name) - val clockToEnable = clocks.map{c => + val clockToEnable = clocks.map { c => c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow) }.toMap val portsWithEnableSignals = portsWithGlobalClock.map { p => - if(clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } else { p } + if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } + else { p } } // replace async reset with synchronous reset (since everything will we synchronous with the global clock) // unless it is a preset reset val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name) - val isPresetReset = state.annotations.collect{ case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet + val isPresetReset = state.annotations.collect { case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet val resetsToChange = asyncResets.filterNot(isPresetReset).toSet val portsWithSyncReset = portsWithEnableSignals.map { p => - if(resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } else { p } + if (resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } + else { p } } // discover clock and reset connections @@ -85,8 +88,9 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // rename clocks to clock enable signals val mRef = CircuitTarget(state.circuit.main).module(main.name) val renameMap = RenameMap() - scan.clockToEnable.foreach { case (clk, en) => - renameMap.record(mRef.ref(clk), mRef.ref(en.name)) + scan.clockToEnable.foreach { + case (clk, en) => + renameMap.record(mRef.ref(clk), mRef.ref(en.name)) } // make changes @@ -103,51 +107,58 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { s match { // memory field connects case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _) - if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) => + if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) => // replace clock with the global clock - if(field == "clk") { + if (field == "clk") { c.copy(expr = ctx.globalClock) - } else if(field == "en") { + } else if (field == "en") { val m = ctx.memInfo(mem) val isWritePort = m.writers.contains(port) assert(isWritePort || m.readers.contains(port)) // for write ports we guard the write enable with the clock enable signal, similar to registers - if(isWritePort) { + if (isWritePort) { val clockEn = ctx.memPortToClockEnable(mem + "." + port) val guardedEnable = and(clockEn, c.expr) c.copy(expr = guardedEnable) } else { c } - } else { c} + } else { c } // register field connects - case c @ ir.Connect(_, r : ir.Reference, next) if ctx.registerToEnable.contains(r.name) => + case c @ ir.Connect(_, r: ir.Reference, next) if ctx.registerToEnable.contains(r.name) => val clockEnable = ctx.registerToEnable(r.name) val guardedNext = mux(clockEnable, next, r) c.copy(expr = guardedNext) // remove other clock wires and nodes case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt - case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt - case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt + case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt + case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt // change async reset to synchronous reset - case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) => - ir.Connect(info, loc.copy(tpe=Bool), expr.copy(tpe=Bool)) - case d @ ir.DefNode(_, name, value: ir.Reference) if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) => - d.copy(value = value.copy(tpe=Bool)) - case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe=Bool) + case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) + if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) => + ir.Connect(info, loc.copy(tpe = Bool), expr.copy(tpe = Bool)) + case d @ ir.DefNode(_, name, value: ir.Reference) + if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) => + d.copy(value = value.copy(tpe = Bool)) + case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe = Bool) // change memory clock and synchronize reset case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) => val clockEnable = ctx.registerToEnable(name) val newReset = reset match { - case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe=Bool) - case other => other + case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe = Bool) + case other => other } - val synchronizedReset = if(reset.tpe == ir.AsyncResetType) { newReset } else { and(newReset, clockEnable) } + val synchronizedReset = if (reset.tpe == ir.AsyncResetType) { newReset } + else { and(newReset, clockEnable) } ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init) case other => other.mapStmt(onStatement) } } - private def scanClocks(m: ir.Module, initialClockToEnable: Map[String, ir.Reference], resetsToChange: Set[String]): ScanCtx = { + private def scanClocks( + m: ir.Module, + initialClockToEnable: Map[String, ir.Reference], + resetsToChange: Set[String] + ): ScanCtx = { implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange) m.foreachStmt(scanClocksAndResets) ctx @@ -162,9 +173,9 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { ctx.clockToEnable.get(expr.serialize).foreach { clockEn => ctx.clockToEnable(locName) = clockEn // keep track of memory clocks - if(loc.isInstanceOf[ir.SubField]) { + if (loc.isInstanceOf[ir.SubField]) { val parts = locName.split('.') - if(ctx.mems.contains(parts.head)) { + if (ctx.mems.contains(parts.head)) { assert(parts.length == 3 && parts.last == "clk") ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn) } @@ -182,11 +193,11 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { ctx.clockToEnable.get(clock.serialize).foreach { clockEnable => ctx.registerToEnable.append(name -> clockEnable) } - case m : ir.DefMemory => + case m: ir.DefMemory => assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!") assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!") assert(m.writeLatency == 1, "Only write-latency 1 is supported!") - if(m.readers.nonEmpty && m.readLatency == 1) { + if (m.readers.nonEmpty && m.readLatency == 1) { unsupportedError("Registers memory read ports are not properly implemented yet :(") } ctx.mems(m.name) = m @@ -233,8 +244,8 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { // memory enables which need to be guarded with clock enables val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap // keep track of memory names - val isMem: String => Boolean = scanResults.mems.contains - val memInfo: String => ir.DefMemory = scanResults.mems + val isMem: String => Boolean = scanResults.mems.contains + val memInfo: String => ir.DefMemory = scanResults.mems val isResetToChange: String => Boolean = scanResults.resetsToChange.contains } @@ -250,4 +261,4 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration { private val Bool = ir.UIntType(ir.IntWidth(1)) } -private class UnsupportedFeatureException(s: String) extends PassException(s)
\ No newline at end of file +private class UnsupportedFeatureException(s: String) extends PassException(s) diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala index 06bd5cba..a17e3e7b 100644 --- a/src/main/scala/firrtl/checks/CheckResets.scala +++ b/src/main/scala/firrtl/checks/CheckResets.scala @@ -14,8 +14,8 @@ import scala.collection.mutable import scala.annotation.tailrec object CheckResets { - class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException( - s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'") + class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) + extends PassException(s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'") // Map of Initialization Expression to check private type RegCheckList = mutable.ListBuffer[(Expression, DefRegister)] @@ -31,9 +31,11 @@ object CheckResets { class CheckResets extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), - Dependency(firrtl.transforms.RemoveReset) ) ++ firrtl.stage.Forms.MidForm + Seq( + Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(firrtl.transforms.RemoveReset) + ) ++ firrtl.stage.Forms.MidForm override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.CheckCombLoops]) @@ -45,10 +47,10 @@ class CheckResets extends Transform with DependencyAPIMigration { private def onStmt(regCheck: RegCheckList, drivers: DirectDriverMap)(stmt: Statement): Unit = { stmt match { - case DefNode(_, name, expr) => drivers += we(WRef(name)) -> expr - case Connect(_, lhs, rhs) => drivers += we(lhs) -> rhs - case reg @ DefRegister(_, name, _,_,_, init) if weq(WRef(name), init) => // Self-reset, allowed! - case reg @ DefRegister(_,_,_,_, reset, init) if reset.tpe == AsyncResetType => + case DefNode(_, name, expr) => drivers += we(WRef(name)) -> expr + case Connect(_, lhs, rhs) => drivers += we(lhs) -> rhs + case reg @ DefRegister(_, name, _, _, _, init) if weq(WRef(name), init) => // Self-reset, allowed! + case reg @ DefRegister(_, _, _, _, reset, init) if reset.tpe == AsyncResetType => regCheck += init -> reg case _ => // Do nothing } @@ -60,11 +62,12 @@ class CheckResets extends Transform with DependencyAPIMigration { @tailrec private def findDriver(drivers: DirectDriverMap)(expr: Expression): Expression = expr match { case lit: Literal => lit - case DoPrim(op, args, _,_) if isCast(op) => findDriver(drivers)(args.head) - case other => drivers.get(we(other)) match { - case Some(e) if wireOrNode(Utils.kind(other)) => findDriver(drivers)(e) - case _ => other - } + case DoPrim(op, args, _, _) if isCast(op) => findDriver(drivers)(args.head) + case other => + drivers.get(we(other)) match { + case Some(e) if wireOrNode(Utils.kind(other)) => findDriver(drivers)(e) + case _ => other + } } private def onMod(errors: Errors)(mod: DefModule): Unit = { diff --git a/src/main/scala/firrtl/constraint/Constraint.scala b/src/main/scala/firrtl/constraint/Constraint.scala index 247593ee..1a3bc21a 100644 --- a/src/main/scala/firrtl/constraint/Constraint.scala +++ b/src/main/scala/firrtl/constraint/Constraint.scala @@ -12,7 +12,7 @@ trait Constraint { /** Trait for constraints with more than one argument */ trait MultiAry extends Constraint { - def op(a: IsKnown, b: IsKnown): IsKnown + def op(a: IsKnown, b: IsKnown): IsKnown def merge(b1: Option[IsKnown], b2: Option[IsKnown]): Option[IsKnown] = (b1, b2) match { case (Some(x), Some(y)) => Some(op(x, y)) case (_, y: Some[_]) => y diff --git a/src/main/scala/firrtl/constraint/ConstraintSolver.scala b/src/main/scala/firrtl/constraint/ConstraintSolver.scala index a421ae17..64271ae1 100644 --- a/src/main/scala/firrtl/constraint/ConstraintSolver.scala +++ b/src/main/scala/firrtl/constraint/ConstraintSolver.scala @@ -24,7 +24,6 @@ class ConstraintSolver { type ConstraintMap = mutable.HashMap[String, (Constraint, Boolean)] private val solvedConstraintMap = new ConstraintMap() - /** Clear all previously recorded/solved constraints */ def clear(): Unit = { constraints.clear() @@ -78,7 +77,7 @@ class ConstraintSolver { def get(b: Constraint): Option[IsKnown] = { val name = b match { case IsVar(name) => name - case x => "" + case x => "" } solvedConstraintMap.get(name) match { case None => None @@ -94,7 +93,7 @@ class ConstraintSolver { def get(b: Width): Option[IsKnown] = { val name = b match { case IsVar(name) => name - case x => "" + case x => "" } solvedConstraintMap.get(name) match { case None => None @@ -103,10 +102,8 @@ class ConstraintSolver { } } - private def add(c: Inequality) = constraints += c - /** Creates an Inequality given a variable name, constraint, and whether its >= or <= * @param left * @param right @@ -114,7 +111,7 @@ class ConstraintSolver { * @return */ private def genConst(left: String, right: Constraint, geq: Boolean): Inequality = geq match { - case true => GreaterOrEqual(left, right) + case true => GreaterOrEqual(left, right) case false => LesserOrEqual(left, right) } @@ -122,14 +119,13 @@ class ConstraintSolver { def serializeConstraints: String = constraints.mkString("\n") /** For debugging, can serialize the solved constraints */ - def serializeSolutions: String = solvedConstraintMap.map{ + def serializeSolutions: String = solvedConstraintMap.map { case (k, (v, true)) => s"$k >= ${v.serialize}" case (k, (v, false)) => s"$k <= ${v.serialize}" }.mkString("\n") - - - /************* Constraint Solver Engine ****************/ + /** *********** Constraint Solver Engine *************** + */ /** Merges constraints on the same variable * @@ -148,17 +144,16 @@ class ConstraintSolver { private def mergeConstraints(constraints: Seq[Inequality]): Seq[Inequality] = { val mergedMap = mutable.HashMap[String, Inequality]() constraints.foreach { - case c if c.geq && mergedMap.contains(c.left) => - mergedMap(c.left) = genConst(c.left, IsMax(mergedMap(c.left).right, c.right), true) - case c if !c.geq && mergedMap.contains(c.left) => - mergedMap(c.left) = genConst(c.left, IsMin(mergedMap(c.left).right, c.right), false) - case c => - mergedMap(c.left) = c + case c if c.geq && mergedMap.contains(c.left) => + mergedMap(c.left) = genConst(c.left, IsMax(mergedMap(c.left).right, c.right), true) + case c if !c.geq && mergedMap.contains(c.left) => + mergedMap(c.left) = genConst(c.left, IsMin(mergedMap(c.left).right, c.right), false) + case c => + mergedMap(c.left) = c } mergedMap.values.toList } - /** Attempts to substitute variables with their corresponding forward-solved constraints * If no corresponding constraint has been visited yet, keep variable as is * @@ -167,15 +162,16 @@ class ConstraintSolver { * @return Forward solved constraint */ private def forwardSubstitution(forwardSolved: ConstraintMap)(constraint: Constraint): Constraint = { - val x = constraint map forwardSubstitution(forwardSolved) + val x = constraint.map(forwardSubstitution(forwardSolved)) x match { - case isVar: IsVar => forwardSolved get isVar.name match { - case None => isVar.asInstanceOf[Constraint] - case Some((p, geq)) => - val newT = forwardSubstitution(forwardSolved)(p) - forwardSolved(isVar.name) = (newT, geq) - newT - } + case isVar: IsVar => + forwardSolved.get(isVar.name) match { + case None => isVar.asInstanceOf[Constraint] + case Some((p, geq)) => + val newT = forwardSubstitution(forwardSolved)(p) + forwardSolved(isVar.name) = (newT, geq) + newT + } case other => other } } @@ -190,11 +186,12 @@ class ConstraintSolver { */ private def backwardSubstitution(backwardSolved: ConstraintMap)(constraint: Constraint): Constraint = { constraint match { - case isVar: IsVar => backwardSolved.get(isVar.name) match { - case Some((p, geq)) => p - case _ => isVar - } - case other => other map backwardSubstitution(backwardSolved) + case isVar: IsVar => + backwardSolved.get(isVar.name) match { + case Some((p, geq)) => p + case _ => isVar + } + case other => other.map(backwardSubstitution(backwardSolved)) } } @@ -211,7 +208,7 @@ class ConstraintSolver { * @return */ private def removeCycle(name: String, geq: Boolean)(constraint: Constraint): Constraint = - if(geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint) + if (geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint) /** Removes solvable cycles of <= inequalities * @param name Name of the variable on left side of inequality @@ -220,7 +217,7 @@ class ConstraintSolver { */ private def removeLeqCycle(name: String)(constraint: Constraint): Constraint = constraint match { case x if greaterEqThan(name)(x) => VarCon(name) - case isMin: IsMin => IsMin(isMin.children.filter{ c => !greaterEqThan(name)(c)}) + case isMin: IsMin => IsMin(isMin.children.filter { c => !greaterEqThan(name)(c) }) case x => x } @@ -231,43 +228,48 @@ class ConstraintSolver { */ private def removeGeqCycle(name: String)(constraint: Constraint): Constraint = constraint match { case x if lessEqThan(name)(x) => VarCon(name) - case isMax: IsMax => IsMax(isMax.children.filter{c => !lessEqThan(name)(c)}) + case isMax: IsMax => IsMax(isMax.children.filter { c => !lessEqThan(name)(c) }) case x => x } private def greaterEqThan(name: String)(constraint: Constraint): Boolean = constraint match { case isMin: IsMin => isMin.children.map(greaterEqThan(name)).reduce(_ && _) - case isAdd: IsAdd => isAdd.children match { - case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true - case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true - case _ => false - } - case isMul: IsMul => isMul.children match { - case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true - case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true - case _ => false - } + case isAdd: IsAdd => + isAdd.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true + case _ => false + } + case isMul: IsMul => + isMul.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true + case _ => false + } case isVar: IsVar if isVar.name == name => true case _ => false } private def lessEqThan(name: String)(constraint: Constraint): Boolean = constraint match { case isMax: IsMax => isMax.children.map(lessEqThan(name)).reduce(_ && _) - case isAdd: IsAdd => isAdd.children match { - case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true - case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true - case _ => false - } - case isMul: IsMul => isMul.children match { - case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true - case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true - case _ => false - } + case isAdd: IsAdd => + isAdd.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true + case _ => false + } + case isMul: IsMul => + isMul.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true + case _ => false + } case isVar: IsVar if isVar.name == name => true - case isNeg: IsNeg => isNeg.child match { - case isVar: IsVar if isVar.name == name => true - case _ => false - } + case isNeg: IsNeg => + isNeg.child match { + case isVar: IsVar if isVar.name == name => true + case _ => false + } case _ => false } @@ -283,7 +285,7 @@ class ConstraintSolver { case isVar: IsVar if isVar.name == name => has = true case _ => } - constraint map rec + constraint.map(rec) } rec(constraint) has @@ -300,7 +302,7 @@ class ConstraintSolver { checkMap(c.left) = c seq ++ Nil case Some(x) if x.geq != c.geq => seq ++ Seq(x, c) - case Some(x) => seq ++ Nil + case Some(x) => seq ++ Nil } } } diff --git a/src/main/scala/firrtl/constraint/Inequality.scala b/src/main/scala/firrtl/constraint/Inequality.scala index 0fa1d2eb..a01b7c85 100644 --- a/src/main/scala/firrtl/constraint/Inequality.scala +++ b/src/main/scala/firrtl/constraint/Inequality.scala @@ -6,9 +6,9 @@ package firrtl.constraint * Is passed to the constraint solver to resolve */ trait Inequality { - def left: String + def left: String def right: Constraint - def geq: Boolean + def geq: Boolean } case class GreaterOrEqual(left: String, right: Constraint) extends Inequality { @@ -20,5 +20,3 @@ case class LesserOrEqual(left: String, right: Constraint) extends Inequality { val geq = false override def toString: String = s"$left <= ${right.serialize}" } - - diff --git a/src/main/scala/firrtl/constraint/IsAdd.scala b/src/main/scala/firrtl/constraint/IsAdd.scala index e177a8b9..9305db89 100644 --- a/src/main/scala/firrtl/constraint/IsAdd.scala +++ b/src/main/scala/firrtl/constraint/IsAdd.scala @@ -1,39 +1,38 @@ // See LICENSE for license details. - package firrtl.constraint // Is case class because writing tests is easier due to equality is not object equality -case class IsAdd private (known: Option[IsKnown], - maxs: Vector[IsMax], - mins: Vector[IsMin], - others: Vector[Constraint]) extends Constraint with MultiAry { +case class IsAdd private (known: Option[IsKnown], maxs: Vector[IsMax], mins: Vector[IsMin], others: Vector[Constraint]) + extends Constraint + with MultiAry { def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 + b2 lazy val children: Vector[Constraint] = { - if(known.nonEmpty) known.get +: (maxs ++ mins ++ others) else maxs ++ mins ++ others + if (known.nonEmpty) known.get +: (maxs ++ mins ++ others) else maxs ++ mins ++ others } def addChild(x: Constraint): IsAdd = x match { - case k: IsKnown => new IsAdd(merge(Some(k), known), maxs, mins, others) - case add: IsAdd => new IsAdd(merge(known, add.known), maxs ++ add.maxs, mins ++ add.mins, others ++ add.others) - case max: IsMax => new IsAdd(known, maxs :+ max, mins, others) - case min: IsMin => new IsAdd(known, maxs, mins :+ min, others) - case other => new IsAdd(known, maxs, mins, others :+ other) + case k: IsKnown => new IsAdd(merge(Some(k), known), maxs, mins, others) + case add: IsAdd => new IsAdd(merge(known, add.known), maxs ++ add.maxs, mins ++ add.mins, others ++ add.others) + case max: IsMax => new IsAdd(known, maxs :+ max, mins, others) + case min: IsMin => new IsAdd(known, maxs, mins :+ min, others) + case other => new IsAdd(known, maxs, mins, others :+ other) } override def serialize: String = "(" + children.map(_.serialize).mkString(" + ") + ")" - override def map(f: Constraint=>Constraint): Constraint = IsAdd(children.map(f)) + override def map(f: Constraint => Constraint): Constraint = IsAdd(children.map(f)) def reduce(): Constraint = { - if(children.size == 1) children.head else { + if (children.size == 1) children.head + else { (known, maxs, mins, others) match { case (Some(k), _, _, _) if k.value == 0 => new IsAdd(None, maxs, mins, others).reduce() case (Some(k), Vector(max), Vector(), Vector()) => max.map { o => IsAdd(k, o) }.reduce() case (Some(k), Vector(), Vector(min), Vector()) => min.map { o => IsAdd(k, o) }.reduce() - case _ => this + case _ => this } } } @@ -45,8 +44,10 @@ object IsAdd { case _ => apply(Seq(left, right)) } def apply(children: Seq[Constraint]): Constraint = { - children.foldLeft(new IsAdd(None, Vector(), Vector(), Vector())) { (add, c) => - add.addChild(c) - }.reduce() + children + .foldLeft(new IsAdd(None, Vector(), Vector(), Vector())) { (add, c) => + add.addChild(c) + } + .reduce() } -}
\ No newline at end of file +} diff --git a/src/main/scala/firrtl/constraint/IsFloor.scala b/src/main/scala/firrtl/constraint/IsFloor.scala index 5de4697e..60f049bb 100644 --- a/src/main/scala/firrtl/constraint/IsFloor.scala +++ b/src/main/scala/firrtl/constraint/IsFloor.scala @@ -10,13 +10,13 @@ case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint override def reduce(): Constraint = child match { case k: IsKnown => k.floor - case x: IsAdd => this - case x: IsMul => this - case x: IsNeg => this - case x: IsPow => this + case x: IsAdd => this + case x: IsMul => this + case x: IsNeg => this + case x: IsPow => this // floor(max(a, b)) -> max(floor(a), floor(b)) - case x: IsMax => IsMax(x.children.map {b => IsFloor(b)}) - case x: IsMin => IsMin(x.children.map {b => IsFloor(b)}) + case x: IsMax => IsMax(x.children.map { b => IsFloor(b) }) + case x: IsMin => IsMin(x.children.map { b => IsFloor(b) }) case x: IsVar => this // floor(floor(x)) -> floor(x) case x: IsFloor => x @@ -24,9 +24,7 @@ case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint } val children = Vector(child) - override def map(f: Constraint=>Constraint): Constraint = IsFloor(f(child)) + override def map(f: Constraint => Constraint): Constraint = IsFloor(f(child)) override def serialize: String = "floor(" + child.serialize + ")" } - - diff --git a/src/main/scala/firrtl/constraint/IsKnown.scala b/src/main/scala/firrtl/constraint/IsKnown.scala index 5bd25f92..07e0531c 100644 --- a/src/main/scala/firrtl/constraint/IsKnown.scala +++ b/src/main/scala/firrtl/constraint/IsKnown.scala @@ -34,11 +34,9 @@ trait IsKnown extends Constraint { /** Floor */ def floor: IsKnown - override def map(f: Constraint=>Constraint): Constraint = this + override def map(f: Constraint => Constraint): Constraint = this val children: Vector[Constraint] = Vector.empty[Constraint] def reduce(): IsKnown = this } - - diff --git a/src/main/scala/firrtl/constraint/IsMax.scala b/src/main/scala/firrtl/constraint/IsMax.scala index 3f24b7c0..0ba20c08 100644 --- a/src/main/scala/firrtl/constraint/IsMax.scala +++ b/src/main/scala/firrtl/constraint/IsMax.scala @@ -4,7 +4,7 @@ package firrtl.constraint object IsMax { def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { - case (l: IsKnown, r: IsKnown) => l max r + case (l: IsKnown, r: IsKnown) => l.max(r) case _ => apply(Seq(left, right)) } def apply(children: Seq[Constraint]): Constraint = { @@ -15,33 +15,32 @@ object IsMax { } } -case class IsMax private[constraint](known: Option[IsKnown], - mins: Vector[IsMin], - others: Vector[Constraint] - ) extends MultiAry { +case class IsMax private[constraint] (known: Option[IsKnown], mins: Vector[IsMin], others: Vector[Constraint]) + extends MultiAry { - def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 max b2 + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1.max(b2) override def serialize: String = "max(" + children.map(_.serialize).mkString(", ") + ")" - override def map(f: Constraint=>Constraint): Constraint = IsMax(children.map(f)) + override def map(f: Constraint => Constraint): Constraint = IsMax(children.map(f)) lazy val children: Vector[Constraint] = { - if(known.nonEmpty) known.get +: (mins ++ others) else mins ++ others + if (known.nonEmpty) known.get +: (mins ++ others) else mins ++ others } def reduce(): Constraint = { - if(children.size == 1) children.head else { + if (children.size == 1) children.head + else { (known, mins, others) match { case (Some(IsKnown(a)), _, _) => // Eliminate minimums who have a known minimum value which is smaller than known maximum value val filteredMins = mins.filter { case IsMin(Some(IsKnown(i)), _, _) if i <= a => false - case other => true + case other => true } // If a successful filter, rerun reduce val newMax = new IsMax(known, filteredMins, others) - if(filteredMins.size != mins.size) { + if (filteredMins.size != mins.size) { newMax.reduce() } else newMax case _ => this @@ -50,10 +49,9 @@ case class IsMax private[constraint](known: Option[IsKnown], } def addChild(x: Constraint): IsMax = x match { - case k: IsKnown => new IsMax(known = merge(Some(k), known), mins, others) - case max: IsMax => new IsMax(known = merge(known, max.known), max.mins ++ mins, others ++ max.others) - case min: IsMin => new IsMax(known, mins :+ min, others) - case other => new IsMax(known, mins, others :+ other) + case k: IsKnown => new IsMax(known = merge(Some(k), known), mins, others) + case max: IsMax => new IsMax(known = merge(known, max.known), max.mins ++ mins, others ++ max.others) + case min: IsMin => new IsMax(known, mins :+ min, others) + case other => new IsMax(known, mins, others :+ other) } } - diff --git a/src/main/scala/firrtl/constraint/IsMin.scala b/src/main/scala/firrtl/constraint/IsMin.scala index ee97e298..2c5db14d 100644 --- a/src/main/scala/firrtl/constraint/IsMin.scala +++ b/src/main/scala/firrtl/constraint/IsMin.scala @@ -4,43 +4,44 @@ package firrtl.constraint object IsMin { def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { - case (l: IsKnown, r: IsKnown) => l min r + case (l: IsKnown, r: IsKnown) => l.min(r) case _ => apply(Seq(left, right)) } def apply(children: Seq[Constraint]): Constraint = { - children.foldLeft(new IsMin(None, Vector(), Vector())) { (add, c) => - add.addChild(c) - }.reduce() + children + .foldLeft(new IsMin(None, Vector(), Vector())) { (add, c) => + add.addChild(c) + } + .reduce() } } -case class IsMin private[constraint](known: Option[IsKnown], - maxs: Vector[IsMax], - others: Vector[Constraint] - ) extends MultiAry { +case class IsMin private[constraint] (known: Option[IsKnown], maxs: Vector[IsMax], others: Vector[Constraint]) + extends MultiAry { - def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 min b2 + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1.min(b2) override def serialize: String = "min(" + children.map(_.serialize).mkString(", ") + ")" - override def map(f: Constraint=>Constraint): Constraint = IsMin(children.map(f)) + override def map(f: Constraint => Constraint): Constraint = IsMin(children.map(f)) lazy val children: Vector[Constraint] = { - if(known.nonEmpty) known.get +: (maxs ++ others) else maxs ++ others + if (known.nonEmpty) known.get +: (maxs ++ others) else maxs ++ others } def reduce(): Constraint = { - if(children.size == 1) children.head else { + if (children.size == 1) children.head + else { (known, maxs, others) match { case (Some(IsKnown(i)), _, _) => // Eliminate maximums who have a known maximum value which is larger than known minimum value val filteredMaxs = maxs.filter { case IsMax(Some(IsKnown(a)), _, _) if a >= i => false - case other => true + case other => true } // If a successful filter, rerun reduce val newMin = new IsMin(known, filteredMaxs, others) - if(filteredMaxs.size != maxs.size) { + if (filteredMaxs.size != maxs.size) { newMin.reduce() } else newMin case _ => this @@ -49,9 +50,9 @@ case class IsMin private[constraint](known: Option[IsKnown], } def addChild(x: Constraint): IsMin = x match { - case k: IsKnown => new IsMin(merge(Some(k), known), maxs, others) - case max: IsMax => new IsMin(known, maxs :+ max, others) - case min: IsMin => new IsMin(merge(min.known, known), maxs ++ min.maxs, others ++ min.others) - case other => new IsMin(known, maxs, others :+ other) + case k: IsKnown => new IsMin(merge(Some(k), known), maxs, others) + case max: IsMax => new IsMin(known, maxs :+ max, others) + case min: IsMin => new IsMin(merge(min.known, known), maxs ++ min.maxs, others ++ min.others) + case other => new IsMin(known, maxs, others :+ other) } } diff --git a/src/main/scala/firrtl/constraint/IsMul.scala b/src/main/scala/firrtl/constraint/IsMul.scala index 3f637d75..a4acd74c 100644 --- a/src/main/scala/firrtl/constraint/IsMul.scala +++ b/src/main/scala/firrtl/constraint/IsMul.scala @@ -10,9 +10,11 @@ object IsMul { case _ => apply(Seq(left, right)) } def apply(children: Seq[Constraint]): Constraint = { - children.foldLeft(new IsMul(None, Vector())) { (add, c) => - add.addChild(c) - }.reduce() + children + .foldLeft(new IsMul(None, Vector())) { (add, c) => + add.addChild(c) + } + .reduce() } } @@ -20,19 +22,20 @@ case class IsMul private (known: Option[IsKnown], others: Vector[Constraint]) ex def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 * b2 - lazy val children: Vector[Constraint] = if(known.nonEmpty) known.get +: others else others + lazy val children: Vector[Constraint] = if (known.nonEmpty) known.get +: others else others def addChild(x: Constraint): IsMul = x match { - case k: IsKnown => new IsMul(known = merge(Some(k), known), others) - case mul: IsMul => new IsMul(merge(known, mul.known), others ++ mul.others) - case other => new IsMul(known, others :+ other) + case k: IsKnown => new IsMul(known = merge(Some(k), known), others) + case mul: IsMul => new IsMul(merge(known, mul.known), others ++ mul.others) + case other => new IsMul(known, others :+ other) } override def reduce(): Constraint = { - if(children.size == 1) children.head else { + if (children.size == 1) children.head + else { (known, others) match { - case (Some(Closed(x)), _) if x == BigDecimal(1) => new IsMul(None, others).reduce() - case (Some(Closed(x)), _) if x == BigDecimal(0) => Closed(0) + case (Some(Closed(x)), _) if x == BigDecimal(1) => new IsMul(None, others).reduce() + case (Some(Closed(x)), _) if x == BigDecimal(0) => Closed(0) case (Some(Closed(x)), Vector(m: IsMax)) if x > 0 => IsMax(m.children.map { c => IsMul(Closed(x), c) }) case (Some(Closed(x)), Vector(m: IsMax)) if x < 0 => @@ -46,7 +49,7 @@ case class IsMul private (known: Option[IsKnown], others: Vector[Constraint]) ex } } - override def map(f: Constraint=>Constraint): Constraint = IsMul(children.map(f)) + override def map(f: Constraint => Constraint): Constraint = IsMul(children.map(f)) override def serialize: String = "(" + children.map(_.serialize).mkString(" * ") + ")" } diff --git a/src/main/scala/firrtl/constraint/IsNeg.scala b/src/main/scala/firrtl/constraint/IsNeg.scala index 46f739c6..574cfd47 100644 --- a/src/main/scala/firrtl/constraint/IsNeg.scala +++ b/src/main/scala/firrtl/constraint/IsNeg.scala @@ -11,10 +11,10 @@ object IsNeg { case class IsNeg private (child: Constraint, dummyArg: Int) extends Constraint { override def reduce(): Constraint = child match { case k: IsKnown => k.neg - case x: IsAdd => IsAdd(x.children.map { b => IsNeg(b) }) - case x: IsMul => IsMul(Seq(IsNeg(x.children.head)) ++ x.children.tail) - case x: IsNeg => x.child - case x: IsPow => this + case x: IsAdd => IsAdd(x.children.map { b => IsNeg(b) }) + case x: IsMul => IsMul(Seq(IsNeg(x.children.head)) ++ x.children.tail) + case x: IsNeg => x.child + case x: IsPow => this // -[max(a, b)] -> min[-a, -b] case x: IsMax => IsMin(x.children.map { b => IsNeg(b) }) case x: IsMin => IsMax(x.children.map { b => IsNeg(b) }) @@ -24,9 +24,7 @@ case class IsNeg private (child: Constraint, dummyArg: Int) extends Constraint { lazy val children = Vector(child) - override def map(f: Constraint=>Constraint): Constraint = IsNeg(f(child)) + override def map(f: Constraint => Constraint): Constraint = IsNeg(f(child)) override def serialize: String = "(-" + child.serialize + ")" } - - diff --git a/src/main/scala/firrtl/constraint/IsPow.scala b/src/main/scala/firrtl/constraint/IsPow.scala index 54a06bf8..2a1fb14a 100644 --- a/src/main/scala/firrtl/constraint/IsPow.scala +++ b/src/main/scala/firrtl/constraint/IsPow.scala @@ -12,22 +12,20 @@ case class IsPow private (child: Constraint, dummyArg: Int) extends Constraint { override def reduce(): Constraint = child match { case k: IsKnown => k.pow // 2^(a + b) -> 2^a * 2^b - case x: IsAdd => IsMul(x.children.map { b => IsPow(b)}) + case x: IsAdd => IsMul(x.children.map { b => IsPow(b) }) case x: IsMul => this case x: IsNeg => this case x: IsPow => this // 2^(max(a, b)) -> max(2^a, 2^b) since two is always positive, so a, b control magnitude - case x: IsMax => IsMax(x.children.map {b => IsPow(b)}) - case x: IsMin => IsMin(x.children.map {b => IsPow(b)}) + case x: IsMax => IsMax(x.children.map { b => IsPow(b) }) + case x: IsMin => IsMin(x.children.map { b => IsPow(b) }) case x: IsVar => this case _ => this } val children = Vector(child) - override def map(f: Constraint=>Constraint): Constraint = IsPow(f(child)) + override def map(f: Constraint => Constraint): Constraint = IsPow(f(child)) override def serialize: String = "(2^" + child.serialize + ")" } - - diff --git a/src/main/scala/firrtl/constraint/IsVar.scala b/src/main/scala/firrtl/constraint/IsVar.scala index 98396fa0..18fb53b2 100644 --- a/src/main/scala/firrtl/constraint/IsVar.scala +++ b/src/main/scala/firrtl/constraint/IsVar.scala @@ -16,7 +16,7 @@ trait IsVar extends Constraint { override def serialize: String = name - override def map(f: Constraint=>Constraint): Constraint = this + override def map(f: Constraint => Constraint): Constraint = this override def reduce() = this @@ -24,4 +24,3 @@ trait IsVar extends Constraint { } case class VarCon(name: String) extends IsVar - diff --git a/src/main/scala/firrtl/features/LetterCaseTransform.scala b/src/main/scala/firrtl/features/LetterCaseTransform.scala index a6cd270a..8610d7b1 100644 --- a/src/main/scala/firrtl/features/LetterCaseTransform.scala +++ b/src/main/scala/firrtl/features/LetterCaseTransform.scala @@ -8,14 +8,15 @@ import firrtl.transforms.ManipulateNames import scala.reflect.ClassTag /** Parent of transforms that do change the letter case of names in a FIRRTL circuit */ -abstract class LetterCaseTransform[A <: ManipulateNames[_] : ClassTag] extends ManipulateNames[A] { +abstract class LetterCaseTransform[A <: ManipulateNames[_]: ClassTag] extends ManipulateNames[A] { protected def newName: String => String - final def manipulate = (a: String, ns: Namespace) => newName(a) match { - case `a` => None - case b => Some(ns.newName(b)) - } + final def manipulate = (a: String, ns: Namespace) => + newName(a) match { + case `a` => None + case b => Some(ns.newName(b)) + } } /** Convert all FIRRTL names to lowercase */ diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala index 32bcac5f..7720028c 100644 --- a/src/main/scala/firrtl/graph/DiGraph.scala +++ b/src/main/scala/firrtl/graph/DiGraph.scala @@ -2,7 +2,7 @@ package firrtl.graph -import scala.collection.{Map, Set, mutable} +import scala.collection.{mutable, Map, Set} import scala.collection.mutable.{LinkedHashMap, LinkedHashSet} /** An exception that is raised when an assumed DAG has a cycle */ @@ -13,6 +13,7 @@ class PathNotFoundException extends Exception("Unreachable node") /** A companion to create DiGraphs from mutable data */ object DiGraph { + /** Create a DiGraph from a MutableDigraph, representing the same graph */ def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = mdg @@ -33,7 +34,8 @@ object DiGraph { } /** Represents common behavior of all directed graphs */ -class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { +class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) { + /** Check whether the graph contains vertex v */ def contains(v: T): Boolean = edges.contains(v) @@ -74,8 +76,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) try { foundPath = path(vertex, node, blacklist = Set.empty) true - } - catch { + } catch { case _: PathNotFoundException => foundPath = Seq.empty[T] false @@ -138,7 +139,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) * @return a Map[T,T] from each visited node to its predecessor in the * traversal */ - def BFS(root: T): Map[T,T] = BFS(root, Set.empty[T]) + def BFS(root: T): Map[T, T] = BFS(root, Set.empty[T]) /** Performs breadth-first search on the directed graph, with a blacklist of nodes * @@ -147,8 +148,8 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) * @return a Map[T,T] from each visited node to its predecessor in the * traversal */ - def BFS(root: T, blacklist: Set[T]): Map[T,T] = { - val prev = new mutable.LinkedHashMap[T,T] + def BFS(root: T, blacklist: Set[T]): Map[T, T] = { + val prev = new mutable.LinkedHashMap[T, T] val queue = new mutable.Queue[T] queue.enqueue(root) while (queue.nonEmpty) { @@ -181,7 +182,9 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) * @param blacklist list of nodes to stop searching, if encountered * @return a Set[T] of nodes reachable from `root` */ - def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({ case (k, v) => k }) + def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({ + case (k, v) => k + }) /** Finds a path (if one exists) from one node to another * @@ -238,7 +241,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) val callStack = new mutable.Stack[StrongConnectFrame[T]] for (node <- getVertices) { - callStack.push(new StrongConnectFrame(node,getEdges(node).iterator)) + callStack.push(new StrongConnectFrame(node, getEdges(node).iterator)) while (!callStack.isEmpty) { val frame = callStack.top val v = frame.v @@ -257,7 +260,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) val w = frame.edgeIter.next if (!indices.contains(w)) { frame.childCall = Some(w) - callStack.push(new StrongConnectFrame(w,getEdges(w).iterator)) + callStack.push(new StrongConnectFrame(w, getEdges(w).iterator)) } else if (onstack.contains(w)) { lowlinks(v) = lowlinks(v).min(indices(w)) } @@ -269,8 +272,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) val w = stack.pop onstack -= w scc += w - } - while (scc.last != v); + } while (scc.last != v); sccs.append(scc.toSeq) } callStack.pop @@ -291,7 +293,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) * @param start the node to start at * @return a Map[T,Seq[Seq[T]]] where the value associated with v is the Seq of all paths from start to v */ - def pathsInDAG(start: T): LinkedHashMap[T,Seq[Seq[T]]] = { + def pathsInDAG(start: T): LinkedHashMap[T, Seq[Seq[T]]] = { // paths(v) holds the set of paths from start to v val paths = new LinkedHashMap[T, mutable.Set[Seq[T]]] val queue = new mutable.Queue[T] @@ -299,7 +301,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) def addBinding(n: T, p: Seq[T]): Unit = { paths.getOrElseUpdate(n, new LinkedHashSet[Seq[T]]) += p } - addBinding(start,Seq(start)) + addBinding(start, Seq(start)) queue += start queue ++= linearize.filter(reachable.contains(_)) while (!queue.isEmpty) { @@ -310,22 +312,25 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) } } } - paths.map({ case (k,v) => (k,v.toSeq) }) + paths.map({ case (k, v) => (k, v.toSeq) }) } /** Returns a graph with all edges reversed */ def reverse: DiGraph[T] = { val mdg = new MutableDiGraph[T] edges.foreach({ case (u, edges) => mdg.addVertex(u) }) - edges.foreach({ case (u, edges) => - edges.foreach(v => mdg.addEdge(v,u)) + edges.foreach({ + case (u, edges) => + edges.foreach(v => mdg.addEdge(v, u)) }) DiGraph(mdg) } private def filterEdges(vprime: Set[T]): LinkedHashMap[T, LinkedHashSet[T]] = { - def filterNodeSet(s: LinkedHashSet[T]): LinkedHashSet[T] = s.filter({ case (k) => vprime.contains(k) }) - def filterAdjacencyLists(m: LinkedHashMap[T, LinkedHashSet[T]]): LinkedHashMap[T, LinkedHashSet[T]] = m.map({ case (k, v) => (k, filterNodeSet(v)) }) + def filterNodeSet(s: LinkedHashSet[T]): LinkedHashSet[T] = s.filter({ case (k) => vprime.contains(k) }) + def filterAdjacencyLists(m: LinkedHashMap[T, LinkedHashSet[T]]): LinkedHashMap[T, LinkedHashSet[T]] = m.map({ + case (k, v) => (k, filterNodeSet(v)) + }) val eprime: LinkedHashMap[T, LinkedHashSet[T]] = edges.filter({ case (k, v) => vprime.contains(k) }) filterAdjacencyLists(eprime) } @@ -354,7 +359,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) */ def simplify(vprime: Set[T]): DiGraph[T] = { require(vprime.subsetOf(edges.keySet)) - val pathEdges = vprime.map(v => (v, reachableFrom(v) & (vprime-v)) ) + val pathEdges = vprime.map(v => (v, reachableFrom(v) & (vprime - v))) new DiGraph(new LinkedHashMap[T, LinkedHashSet[T]] ++ pathEdges) } @@ -384,6 +389,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) } class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) { + /** Add vertex v to the graph * @return v, the added vertex */ diff --git a/src/main/scala/firrtl/graph/EdgeData.scala b/src/main/scala/firrtl/graph/EdgeData.scala index 16990de0..6a63c3b9 100644 --- a/src/main/scala/firrtl/graph/EdgeData.scala +++ b/src/main/scala/firrtl/graph/EdgeData.scala @@ -6,11 +6,10 @@ import scala.collection.mutable /** * An exception that indicates that an edge cannot be found in a graph with edge data. - * + * * @note the vertex type is not captured as a type parameter, as it would be erased. */ -class EdgeNotFoundException(u: Any, v: Any) - extends IllegalArgumentException(s"Edge (${u}, ${v}) does not exist!") +class EdgeNotFoundException(u: Any, v: Any) extends IllegalArgumentException(s"Edge (${u}, ${v}) does not exist!") /** * Mixing this trait into a DiGraph indicates that each edge may be associated with an optional diff --git a/src/main/scala/firrtl/graph/EulerTour.scala b/src/main/scala/firrtl/graph/EulerTour.scala index 2d8a17e2..5e075ae2 100644 --- a/src/main/scala/firrtl/graph/EulerTour.scala +++ b/src/main/scala/firrtl/graph/EulerTour.scala @@ -6,6 +6,7 @@ import scala.collection.mutable /** Euler Tour companion object */ object EulerTour { + /** Create an Euler Tour of a `DiGraph[T]` */ def apply[T](diGraph: DiGraph[T], start: T): EulerTour[Seq[T]] = { val r = mutable.Map[Seq[T], Int]() @@ -66,8 +67,8 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { * the index of that minimum in each block, b. */ private lazy val blocks = (h ++ (1 to (m - n % m))).grouped(m).toArray - private lazy val a = blocks map (_.min) - private lazy val b = blocks map (b => b.indexOf(b.min)) + private lazy val a = blocks.map(_.min) + private lazy val b = blocks.map(b => b.indexOf(b.min)) /** Construct a Sparse Table (ST) representation for the minimum index * of a sequence of integers. Data in the returned array is indexed @@ -75,7 +76,10 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { */ private def constructSparseTable(x: Seq[Int]): Array[Array[Int]] = { val tmp = Array.ofDim[Int](x.size + 1, math.ceil(lg(x.size)).toInt) - for (i <- 0 to x.size - 1; j <- 0 to math.ceil(lg(x.size)).toInt - 1) { + for { + i <- 0 to x.size - 1 + j <- 0 to math.ceil(lg(x.size)).toInt - 1 + } { tmp(i)(j) = -1 } @@ -86,11 +90,11 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { } else { val (a, b, c) = (base, base + (1 << (size - 1)), size - 1) - val l = if (tmp(a)(c) != -1) { tmp(a)(c) } - else { tableRecursive(a, c) } + val l = if (tmp(a)(c) != -1) { tmp(a)(c) } + else { tableRecursive(a, c) } - val r = if (tmp(b)(c) != -1) { tmp(b)(c) } - else { tableRecursive(b, c) } + val r = if (tmp(b)(c) != -1) { tmp(b)(c) } + else { tableRecursive(b, c) } val min = if (x(l) < x(r)) l else r tmp(base)(size) = min @@ -99,9 +103,11 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { } } - for (i <- (0 to x.size - 1); - j <- (0 to math.ceil(lg(x.size)).toInt - 1); - if i + (1 << j) - 1 < x.size) { + for { + i <- (0 to x.size - 1) + j <- (0 to math.ceil(lg(x.size)).toInt - 1) + if i + (1 << j) - 1 < x.size + } { tableRecursive(i, j) } tmp @@ -117,16 +123,26 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { } val size = m - 1 - val out = Seq.fill(size)(Seq(-1, 1)) - .flatten.combinations(m - 1).flatMap(_.permutations).toList + val out = Seq + .fill(size)(Seq(-1, 1)) + .flatten + .combinations(m - 1) + .flatMap(_.permutations) + .toList .sortWith(sortSeqSeq) .map(_.foldLeft(Seq(0))((h, pm) => (h.head + pm) +: h).reverse) - .map{ a => + .map { a => val tmp = Array.ofDim[Int](m, m) - for (i <- 0 to size; j <- i to size) yield { + for { + i <- 0 to size + j <- i to size + } yield { val window = a.slice(i, j + 1) - tmp(i)(j) = window.indexOf(window.min) + i } - tmp }.toArray + tmp(i)(j) = window.indexOf(window.min) + i + } + tmp + } + .toArray out } private lazy val tables = constructTableLookups(m) @@ -167,7 +183,7 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { // Compute block and word indices val (block_i, block_j) = (i / m, j / m) - val (word_i, word_j) = (i % m, j % m) + val (word_i, word_j) = (i % m, j % m) /** Up to four possible minimum indices are then computed based on the * following conditions: @@ -187,12 +203,12 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(word_j) Seq(min_i) case (bi, bj) if (block_i == block_j - 1) => - val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1) - val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j) + val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(m - 1) + val min_j = block_j * m + tables(tableIdx(block_j))(0)(word_j) Seq(min_i, min_j) case _ => - val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1) - val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j) + val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(m - 1) + val min_j = block_j * m + tables(tableIdx(block_j))(0)(word_j) val (min_between_l, min_between_r) = { val range = math.floor(lg(block_j - block_i - 1)).toInt val base_0 = block_i + 1 @@ -200,7 +216,8 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) { val (idx_0, idx_1) = (st(base_0)(range), st(base_1)(range)) val (min_0, min_1) = (b(idx_0) + idx_0 * m, b(idx_1) + idx_1 * m) - (min_0, min_1) } + (min_0, min_1) + } Seq(min_i, min_between_l, min_between_r, min_j) } diff --git a/src/main/scala/firrtl/graph/RenderDiGraph.scala b/src/main/scala/firrtl/graph/RenderDiGraph.scala index b3c1373c..45be3a8f 100644 --- a/src/main/scala/firrtl/graph/RenderDiGraph.scala +++ b/src/main/scala/firrtl/graph/RenderDiGraph.scala @@ -16,7 +16,6 @@ import scala.collection.mutable */ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankDir: String = "LR") { - /** * override this to change the default way a node is displayed. Default is toString surrounded by double quotes * This example changes the double quotes to brackets @@ -38,8 +37,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD try { diGraph.linearize - } - catch { + } catch { case cyclicException: CyclicException => val node = cyclicException.node.asInstanceOf[T] path = diGraph.findLoopAtNode(node) @@ -61,31 +59,29 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD val loop = findOneLoop - if(loop.nonEmpty) { + if (loop.nonEmpty) { // Find all the children of the nodes in the loop val childrenFound = diGraph.getEdgeMap.flatMap { case (node, children) if loop.contains(node) => children - case _ => Seq.empty + case _ => Seq.empty }.toSet // Create a new DiGraph containing only loop and direct children or parents val edgeData = diGraph.getEdgeMap - val newEdgeData = edgeData.flatMap { case (node, children) => - if(loop.contains(node)) { - Some(node -> children) - } - else if(childrenFound.contains(node)) { - Some(node -> children.intersect(loop)) - } - else { - val newChildren = children.intersect(loop) - if(newChildren.nonEmpty) { - Some(node -> newChildren) - } - else { - None - } + val newEdgeData = edgeData.flatMap { + case (node, children) => + if (loop.contains(node)) { + Some(node -> children) + } else if (childrenFound.contains(node)) { + Some(node -> children.intersect(loop)) + } else { + val newChildren = children.intersect(loop) + if (newChildren.nonEmpty) { + Some(node -> newChildren) + } else { + None + } } } @@ -96,8 +92,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD } } newRenderer.toDotWithLoops(loop, getRankedNodes) - } - else { + } else { "" } } @@ -114,10 +109,11 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD val edges = diGraph.getEdgeMap - edges.foreach { case (parent, children) => - children.foreach { child => - s.append(s""" ${renderNode(parent)} -> ${renderNode(child)};""" + "\n") - } + edges.foreach { + case (parent, children) => + children.foreach { child => + s.append(s""" ${renderNode(parent)} -> ${renderNode(child)};""" + "\n") + } } s.append("}\n") s.toString @@ -137,24 +133,25 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD val edges = diGraph.getEdgeMap - edges.foreach { case (parent, children) => - allNodes += parent - allNodes ++= children + edges.foreach { + case (parent, children) => + allNodes += parent + allNodes ++= children - children.foreach { child => - val highlight = if(loopedNodes.contains(parent) && loopedNodes.contains(child)) { - "[color=red,penwidth=3.0]" - } - else { - "" + children.foreach { child => + val highlight = if (loopedNodes.contains(parent) && loopedNodes.contains(child)) { + "[color=red,penwidth=3.0]" + } else { + "" + } + s.append(s""" ${renderNode(parent)} -> ${renderNode(child)}$highlight;""" + "\n") } - s.append(s""" ${renderNode(parent)} -> ${renderNode(child)}$highlight;""" + "\n") - } } val paredRankedNodes = rankedNodes.flatMap { nodes => val newNodes = nodes.filter(allNodes.contains) - if(newNodes.nonEmpty) { Some(newNodes) } else { None } + if (newNodes.nonEmpty) { Some(newNodes) } + else { None } } paredRankedNodes.foreach { nodesAtRank => @@ -183,7 +180,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD diGraph.getEdges(node) }.filterNot(alreadyVisited.contains).distinct - if(nextNodes.nonEmpty) { + if (nextNodes.nonEmpty) { walkByRank(nextNodes, rankNumber + 1) } } @@ -191,6 +188,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD walkByRank(diGraph.findSources.toSeq) rankNodes } + /** * Convert this graph into input for the graphviz dot program. * It tries to align nodes in columns based @@ -216,7 +214,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD children }.filterNot(alreadyVisited.contains).distinct - if(nextNodes.nonEmpty) { + if (nextNodes.nonEmpty) { walkByRank(nextNodes, rankNumber + 1) } } diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 5263d9c0..2536a77e 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -39,41 +39,50 @@ case class FileInfo(escaped: String) extends Info { object FileInfo { @deprecated("Use FileInfo.fromUnEscaped instead. FileInfo.apply will be removed in FIRRTL 1.5.", "FIRRTL 1.4") - def apply(info: StringLit): FileInfo = new FileInfo(escape(info.string)) - def fromEscaped(s: String): FileInfo = new FileInfo(s) - def fromUnescaped(s: String): FileInfo = new FileInfo(escape(s)) + def apply(info: StringLit): FileInfo = new FileInfo(escape(info.string)) + def fromEscaped(s: String): FileInfo = new FileInfo(s) + def fromUnescaped(s: String): FileInfo = new FileInfo(escape(s)) + /** prepends a `\` to: `\`, `\n`, `\t` and `]` */ def escape(s: String): String = EscapeFirrtl.translate(s) + /** removes the `\` in front of `\`, `\n`, `\t` and `]` */ def unescape(s: String): String = UnescapeFirrtl.translate(s) + /** take an already escaped String and do the additional escaping needed for Verilog comment */ def escapedToVerilog(s: String) = EscapedToVerilog.translate(s) // custom `CharSequenceTranslator` for FIRRTL Info String escaping type CharMap = (CharSequence, CharSequence) - private val EscapeFirrtl = new LookupTranslator(Seq[CharMap]( - "\\" -> "\\\\", - "\n" -> "\\n", - "\t" -> "\\t", - "]" -> "\\]" - ).toMap.asJava) - private val UnescapeFirrtl = new LookupTranslator(Seq[CharMap]( - "\\\\" -> "\\", - "\\n" -> "\n", - "\\t" -> "\t", - "\\]" -> "]" - ).toMap.asJava) + private val EscapeFirrtl = new LookupTranslator( + Seq[CharMap]( + "\\" -> "\\\\", + "\n" -> "\\n", + "\t" -> "\\t", + "]" -> "\\]" + ).toMap.asJava + ) + private val UnescapeFirrtl = new LookupTranslator( + Seq[CharMap]( + "\\\\" -> "\\", + "\\n" -> "\n", + "\\t" -> "\t", + "\\]" -> "]" + ).toMap.asJava + ) // EscapeFirrtl + EscapedToVerilog essentially does the same thing as running StringEscapeUtils.unescapeJava private val EscapedToVerilog = new AggregateTranslator( - new LookupTranslator(Seq[CharMap]( - // ] is the one character that firrtl needs to be escaped that does not need to be escaped in - "\\]" -> "]", - "\"" -> "\\\"", - // \n and \t are already escaped - "\b" -> "\\b", - "\f" -> "\\f", - "\r" -> "\\r" - ).toMap.asJava), + new LookupTranslator( + Seq[CharMap]( + // ] is the one character that firrtl needs to be escaped that does not need to be escaped in + "\\]" -> "]", + "\"" -> "\\\"", + // \n and \t are already escaped + "\b" -> "\\b", + "\f" -> "\\f", + "\r" -> "\\r" + ).toMap.asJava + ), JavaUnicodeEscaper.outsideOf(32, 0x7f) ) @@ -81,9 +90,9 @@ object FileInfo { case class MultiInfo(infos: Seq[Info]) extends Info { private def collectStrings(info: Info): Seq[String] = info match { - case f : FileInfo => Seq(f.escaped) - case MultiInfo(seq) => seq flatMap collectStrings - case NoInfo => Seq.empty + case f: FileInfo => Seq(f.escaped) + case MultiInfo(seq) => seq.flatMap(collectStrings) + case NoInfo => Seq.empty } override def toString: String = { val parts = collectStrings(this) @@ -107,12 +116,12 @@ object MultiInfo { // TODO should this be made into an API? private[firrtl] def demux(info: Info): (Info, Info, Info) = info match { case MultiInfo(infos) if infos.lengthCompare(3) == 0 => (infos(0), infos(1), infos(2)) - case other => (other, NoInfo, NoInfo) // if not exactly 3, we don't know what to do + case other => (other, NoInfo, NoInfo) // if not exactly 3, we don't know what to do } - + private def flattenInfo(infos: Seq[Info]): Seq[FileInfo] = infos.flatMap { case NoInfo => Seq() - case f : FileInfo => Seq(f) + case f: FileInfo => Seq(f) case MultiInfo(infos) => flattenInfo(infos) } } @@ -127,6 +136,7 @@ trait IsDeclaration extends HasName with HasInfo case class StringLit(string: String) extends FirrtlNode { import org.apache.commons.text.StringEscapeUtils + /** Returns an escaped and quoted String */ def escape: String = { "\"" + serialize + "\"" @@ -137,26 +147,28 @@ case class StringLit(string: String) extends FirrtlNode { def verilogFormat: StringLit = { StringLit(string.replaceAll("%x", "%h")) } + /** Returns an escaped and quoted String */ def verilogEscape: String = { // normalize to turn things like ö into o import java.text.Normalizer val normalized = Normalizer.normalize(string, Normalizer.Form.NFD) - val ascii = normalized flatMap StringLit.toASCII + val ascii = normalized.flatMap(StringLit.toASCII) ascii.mkString("\"", "", "\"") } } object StringLit { import org.apache.commons.text.StringEscapeUtils + /** Maps characters to ASCII for Verilog emission */ private def toASCII(char: Char): List[Char] = char match { case nonASCII if !nonASCII.isValidByte => List('?') - case '"' => List('\\', '"') - case '\\' => List('\\', '\\') - case c if c >= ' ' && c <= '~' => List(c) - case '\n' => List('\\', 'n') - case '\t' => List('\\', 't') - case _ => List('?') + case '"' => List('\\', '"') + case '\\' => List('\\', '\\') + case c if c >= ' ' && c <= '~' => List(c) + case '\n' => List('\\', 'n') + case '\t' => List('\\', 't') + case _ => List('?') } /** Create a StringLit from a raw parsed String */ @@ -175,8 +187,8 @@ abstract class PrimOp extends FirrtlNode { def apply(args: Any*): DoPrim = { val groups = args.groupBy { case x: Expression => "exp" - case x: BigInt => "int" - case x: Int => "int" + case x: BigInt => "int" + case x: Int => "int" case other => "other" } val exprs = groups.getOrElse("exp", Nil).collect { @@ -185,11 +197,11 @@ abstract class PrimOp extends FirrtlNode { val consts = groups.getOrElse("int", Nil).map { _ match { case i: BigInt => i - case i: Int => BigInt(i) + case i: Int => BigInt(i) } } groups.get("other") match { - case None => + case None => case Some(x) => sys.error(s"Shouldn't be here: $x") } DoPrim(this, exprs, consts, UnknownType) @@ -198,12 +210,12 @@ abstract class PrimOp extends FirrtlNode { abstract class Expression extends FirrtlNode { def tpe: Type - def mapExpr(f: Expression => Expression): Expression - def mapType(f: Type => Type): Expression - def mapWidth(f: Width => Width): Expression - def foreachExpr(f: Expression => Unit): Unit - def foreachType(f: Type => Unit): Unit - def foreachWidth(f: Width => Unit): Unit + def mapExpr(f: Expression => Expression): Expression + def mapType(f: Type => Type): Expression + def mapWidth(f: Width => Width): Expression + def foreachExpr(f: Expression => Unit): Unit + def foreachType(f: Type => Unit): Unit + def foreachWidth(f: Width => Unit): Unit } /** Represents reference-like expression nodes: SubField, SubIndex, SubAccess and Reference @@ -215,75 +227,92 @@ abstract class Expression extends FirrtlNode { sealed trait RefLikeExpression extends Expression { def flow: Flow } object Reference { + /** Creates a Reference from a Wire */ def apply(wire: DefWire): Reference = Reference(wire.name, wire.tpe, WireKind, UnknownFlow) + /** Creates a Reference from a Register */ def apply(reg: DefRegister): Reference = Reference(reg.name, reg.tpe, RegKind, UnknownFlow) + /** Creates a Reference from a Node */ def apply(node: DefNode): Reference = Reference(node.name, node.value.tpe, NodeKind, SourceFlow) + /** Creates a Reference from a Port */ def apply(port: Port): Reference = Reference(port.name, port.tpe, PortKind, UnknownFlow) + /** Creates a Reference from a DefInstance */ def apply(i: DefInstance): Reference = Reference(i.name, i.tpe, InstanceKind, UnknownFlow) + /** Creates a Reference from a DefMemory */ def apply(mem: DefMemory): Reference = Reference(mem.name, passes.MemPortUtils.memType(mem), MemKind, UnknownFlow) } case class Reference(name: String, tpe: Type = UnknownType, kind: Kind = UnknownKind, flow: Flow = UnknownFlow) - extends Expression with HasName with UseSerializer with RefLikeExpression { - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachWidth(f: Width => Unit): Unit = () + extends Expression + with HasName + with UseSerializer + with RefLikeExpression { + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = () } case class SubField(expr: Expression, name: String, tpe: Type = UnknownType, flow: Flow = UnknownFlow) - extends Expression with HasName with UseSerializer with RefLikeExpression { - def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) - def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = f(expr) - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachWidth(f: Width => Unit): Unit = () + extends Expression + with HasName + with UseSerializer + with RefLikeExpression { + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = () } case class SubIndex(expr: Expression, value: Int, tpe: Type, flow: Flow = UnknownFlow) - extends Expression with UseSerializer with RefLikeExpression { - def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) - def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = f(expr) - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachWidth(f: Width => Unit): Unit = () + extends Expression + with UseSerializer + with RefLikeExpression { + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = () } case class SubAccess(expr: Expression, index: Expression, tpe: Type, flow: Flow = UnknownFlow) - extends Expression with UseSerializer with RefLikeExpression { - def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr), index = f(index)) - def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) - def mapWidth(f: Width => Width): Expression = this + extends Expression + with UseSerializer + with RefLikeExpression { + def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr), index = f(index)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this def foreachExpr(f: Expression => Unit): Unit = { f(expr); f(index) } - def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachType(f: Type => Unit): Unit = f(tpe) def foreachWidth(f: Width => Unit): Unit = () } case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type = UnknownType) - extends Expression with UseSerializer { - def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) - def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) - def mapWidth(f: Width => Width): Expression = this + extends Expression + with UseSerializer { + def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(tval); f(fval) } - def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachType(f: Type => Unit): Unit = f(tpe) def foreachWidth(f: Width => Unit): Unit = () } case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression with UseSerializer { - def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe) - def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) - def mapWidth(f: Width => Width): Expression = this + def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(value) } - def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachType(f: Type => Unit): Unit = f(tpe) def foreachWidth(f: Width => Unit): Unit = () } abstract class Literal extends Expression { @@ -292,16 +321,16 @@ abstract class Literal extends Expression { } case class UIntLiteral(value: BigInt, width: Width) extends Literal with UseSerializer { def tpe = UIntType(width) - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width)) - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = f(width) + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width)) + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = f(width) } object UIntLiteral { def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1)) - def apply(value: BigInt): UIntLiteral = new UIntLiteral(value, minWidth(value)) + def apply(value: BigInt): UIntLiteral = new UIntLiteral(value, minWidth(value)) /** Utility to construct UIntLiterals masked by the width * @@ -314,78 +343,82 @@ object UIntLiteral { } case class SIntLiteral(value: BigInt, width: Width) extends Literal with UseSerializer { def tpe = SIntType(width) - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width)) - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = f(width) + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width)) + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = f(width) } object SIntLiteral { def minWidth(value: BigInt): Width = IntWidth(value.bitLength + 1) - def apply(value: BigInt): SIntLiteral = new SIntLiteral(value, minWidth(value)) + def apply(value: BigInt): SIntLiteral = new SIntLiteral(value, minWidth(value)) } case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Literal with UseSerializer { def tpe = FixedType(width, point) - def mapExpr(f: Expression => Expression): Expression = this - def mapType(f: Type => Type): Expression = this - def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point)) - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point)) + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) - extends Expression with UseSerializer { - def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f) - def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) - def mapWidth(f: Width => Width): Expression = this - def foreachExpr(f: Expression => Unit): Unit = args.foreach(f) - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachWidth(f: Width => Unit): Unit = () + extends Expression + with UseSerializer { + def mapExpr(f: Expression => Expression): Expression = this.copy(args = args.map(f)) + def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) + def mapWidth(f: Width => Width): Expression = this + def foreachExpr(f: Expression => Unit): Unit = args.foreach(f) + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachWidth(f: Width => Unit): Unit = () } abstract class Statement extends FirrtlNode { - def mapStmt(f: Statement => Statement): Statement - def mapExpr(f: Expression => Expression): Statement - def mapType(f: Type => Type): Statement - def mapString(f: String => String): Statement - def mapInfo(f: Info => Info): Statement - def foreachStmt(f: Statement => Unit): Unit - def foreachExpr(f: Expression => Unit): Unit - def foreachType(f: Type => Unit): Unit - def foreachString(f: String => Unit): Unit - def foreachInfo(f: Info => Unit): Unit + def mapStmt(f: Statement => Statement): Statement + def mapExpr(f: Expression => Expression): Statement + def mapType(f: Type => Type): Statement + def mapString(f: String => String): Statement + def mapInfo(f: Info => Info): Statement + def foreachStmt(f: Statement => Unit): Unit + def foreachExpr(f: Expression => Unit): Unit + def foreachType(f: Type => Unit): Unit + def foreachString(f: String => Unit): Unit + def foreachInfo(f: Info => Unit): Unit } case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = this - def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) - def mapString(f: String => String): Statement = DefWire(info, f(name), tpe) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe)) + def mapString(f: String => String): Statement = DefWire(info, f(name), tpe) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class DefRegister( - info: Info, - name: String, - tpe: Type, - clock: Expression, - reset: Expression, - init: Expression) extends Statement with IsDeclaration with UseSerializer { + info: Info, + name: String, + tpe: Type, + clock: Expression, + reset: Expression, + init: Expression) + extends Statement + with IsDeclaration + with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = DefRegister(info, name, tpe, f(clock), f(reset), f(init)) - def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(clock); f(reset); f(init) } - def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachType(f: Type => Unit): Unit = f(tpe) def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + def foreachInfo(f: Info => Unit): Unit = f(info) } object DefInstance { @@ -393,17 +426,19 @@ object DefInstance { } case class DefInstance(info: Info, name: String, module: String, tpe: Type = UnknownType) - extends Statement with IsDeclaration with UseSerializer { - def mapExpr(f: Expression => Expression): Statement = this - def mapStmt(f: Statement => Statement): Statement = this - def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def mapInfo(f: Info => Info): Statement = this.copy(f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(tpe) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + extends Statement + with IsDeclaration + with UseSerializer { + def mapExpr(f: Expression => Expression): Statement = this + def mapStmt(f: Statement => Statement): Statement = this + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } object ReadUnderWrite extends Enumeration { @@ -413,56 +448,64 @@ object ReadUnderWrite extends Enumeration { } case class DefMemory( - info: Info, - name: String, - dataType: Type, - depth: BigInt, - writeLatency: Int, - readLatency: Int, - readers: Seq[String], - writers: Seq[String], - readwriters: Seq[String], - // TODO: handle read-under-write - readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) - extends Statement with IsDeclaration with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = this - def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(dataType) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) -} -case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value)) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = DefNode(info, f(name), value) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = f(value) - def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + info: Info, + name: String, + dataType: Type, + depth: BigInt, + writeLatency: Int, + readLatency: Int, + readers: Seq[String], + writers: Seq[String], + readwriters: Seq[String], + // TODO: handle read-under-write + readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) + extends Statement + with IsDeclaration + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) +} +case class DefNode(info: Info, name: String, value: Expression) + extends Statement + with IsDeclaration + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = DefNode(info, f(name), value) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = f(value) + def foreachType(f: Type => Unit): Unit = () + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Conditionally( - info: Info, - pred: Expression, - conseq: Statement, - alt: Statement) extends Statement with HasInfo with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = Conditionally(info, pred, f(conseq), f(alt)) - def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + info: Info, + pred: Expression, + conseq: Statement, + alt: Statement) + extends Statement + with HasInfo + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = Conditionally(info, pred, f(conseq), f(alt)) + def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = { f(conseq); f(alt) } - def foreachExpr(f: Expression => Unit): Unit = f(pred) - def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) + def foreachExpr(f: Expression => Unit): Unit = f(pred) + def foreachType(f: Type => Unit): Unit = () + def foreachString(f: String => Unit): Unit = () + def foreachInfo(f: Info => Unit): Unit = f(info) } object Block { @@ -489,94 +532,101 @@ case class Block(stmts: Seq[Statement]) extends Statement with UseSerializer { } Block(res.toSeq) } - def mapExpr(f: Expression => Expression): Statement = this - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this - def foreachStmt(f: Statement => Unit): Unit = stmts.foreach(f) - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = () + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this + def foreachStmt(f: Statement => Unit): Unit = stmts.foreach(f) + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachString(f: String => Unit): Unit = () + def foreachInfo(f: Info => Unit): Unit = () } case class PartialConnect(info: Info, loc: Expression, expr: Expression) - extends Statement with HasInfo with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr)) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () + extends Statement + with HasInfo + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) } - def foreachType(f: Type => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) -} -case class Connect(info: Info, loc: Expression, expr: Expression) - extends Statement with HasInfo with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr)) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () + def foreachInfo(f: Info => Unit): Unit = f(info) +} +case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) } - def foreachType(f: Type => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr)) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = f(expr) - def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = f(expr) + def foreachType(f: Type => Unit): Unit = () + def foreachString(f: String => Unit): Unit = () + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with HasInfo with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = Attach(info, exprs map f) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(f) - def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) -} -case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en)) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Attach(info, exprs.map(f)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(f) + def foreachType(f: Type => Unit): Unit = () + def foreachString(f: String => Unit): Unit = () + def foreachInfo(f: Info => Unit): Unit = f(info) +} +case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) + extends Statement + with HasInfo + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(en) } - def foreachType(f: Type => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Print( - info: Info, - string: StringLit, - args: Seq[Expression], - clk: Expression, - en: Expression) extends Statement with HasInfo with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = Print(info, string, args map f, f(clk), f(en)) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () + info: Info, + string: StringLit, + args: Seq[Expression], + clk: Expression, + en: Expression) + extends Statement + with HasInfo + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = Print(info, string, args.map(f), f(clk), f(en)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { args.foreach(f); f(clk); f(en) } - def foreachType(f: Type => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) + def foreachInfo(f: Info => Unit): Unit = f(info) } // formal @@ -587,38 +637,40 @@ object Formal extends Enumeration { } case class Verification( - op: Formal.Value, + op: Formal.Value, info: Info, - clk: Expression, + clk: Expression, pred: Expression, - en: Expression, - msg: StringLit -) extends Statement with HasInfo with UseSerializer { + en: Expression, + msg: StringLit) + extends Statement + with HasInfo + with UseSerializer { def mapStmt(f: Statement => Statement): Statement = this def mapExpr(f: Expression => Expression): Statement = copy(clk = f(clk), pred = f(pred), en = f(en)) - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = copy(info = f(info)) + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = copy(info = f(info)) def foreachStmt(f: Statement => Unit): Unit = () def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(pred); f(en); } - def foreachType(f: Type => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = f(info) + def foreachInfo(f: Info => Unit): Unit = f(info) } // end formal case object EmptyStmt extends Statement with UseSerializer { - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = this - def mapType(f: Type => Type): Statement = this - def mapString(f: String => String): Statement = this - def mapInfo(f: Info => Info): Statement = this - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = () - def foreachString(f: String => Unit): Unit = () - def foreachInfo(f: Info => Unit): Unit = () + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this + def mapString(f: String => String): Statement = this + def mapInfo(f: Info => Info): Statement = this + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = () + def foreachString(f: String => Unit): Unit = () + def foreachInfo(f: Info => Unit): Unit = () } abstract class Width extends FirrtlNode { @@ -631,14 +683,15 @@ abstract class Width extends FirrtlNode { case _ => UnknownWidth } def max(x: Width): Width = (this, x) match { - case (a: IntWidth, b: IntWidth) => IntWidth(a.width max b.width) + case (a: IntWidth, b: IntWidth) => IntWidth(a.width.max(b.width)) case _ => UnknownWidth } def min(x: Width): Width = (this, x) match { - case (a: IntWidth, b: IntWidth) => IntWidth(a.width min b.width) + case (a: IntWidth, b: IntWidth) => IntWidth(a.width.min(b.width)) case _ => UnknownWidth } } + /** Positive Integer Bit Width of a [[GroundType]] */ object IntWidth { private val maxCached = 1024 @@ -665,7 +718,7 @@ class IntWidth(val width: BigInt) extends Width with Product with UseSerializer override def hashCode = width.toInt override def productPrefix = "IntWidth" override def toString = s"$productPrefix($width)" - def copy(width: BigInt = width) = IntWidth(width) + def copy(width: BigInt = width) = IntWidth(width) def canEqual(that: Any) = that.isInstanceOf[Width] def productArity = 1 def productElement(int: Int) = int match { @@ -693,19 +746,18 @@ case object Flip extends Orientation { /** Field of [[BundleType]] */ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode with HasName with UseSerializer - /** Bounds of [[IntervalType]] */ trait Bound extends Constraint case object UnknownBound extends Bound { def serialize: String = Serializer.serialize(this) - def map(f: Constraint=>Constraint): Constraint = this + def map(f: Constraint => Constraint): Constraint = this override def reduce(): Constraint = this val children = Vector() } case class CalcBound(arg: Constraint) extends Bound { def serialize: String = Serializer.serialize(this) - def map(f: Constraint=>Constraint): Constraint = f(arg) + def map(f: Constraint => Constraint): Constraint = f(arg) override def reduce(): Constraint = arg val children = Vector(arg) } @@ -727,58 +779,60 @@ case class Open(value: BigDecimal) extends IsKnown with Bound { def +(that: IsKnown): IsKnown = Open(value + that.value) def *(that: IsKnown): IsKnown = that match { case Closed(x) if x == 0 => Closed(x) - case _ => Open(value * that.value) + case _ => Open(value * that.value) } - def min(that: IsKnown): IsKnown = if(value < that.value) this else that - def max(that: IsKnown): IsKnown = if(value > that.value) this else that - def neg: IsKnown = Open(-value) - def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR)) - def pow: IsKnown = if(value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") + def min(that: IsKnown): IsKnown = if (value < that.value) this else that + def max(that: IsKnown): IsKnown = if (value > that.value) this else that + def neg: IsKnown = Open(-value) + def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR)) + def pow: IsKnown = + if (value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") } case class Closed(value: BigDecimal) extends IsKnown with Bound { def serialize: String = Serializer.serialize(this) def +(that: IsKnown): IsKnown = that match { - case Open(x) => Open(value + x) + case Open(x) => Open(value + x) case Closed(x) => Closed(value + x) } def *(that: IsKnown): IsKnown = that match { case IsKnown(x) if value == BigInt(0) => Closed(0) - case Open(x) => Open(value * x) - case Closed(x) => Closed(value * x) + case Open(x) => Open(value * x) + case Closed(x) => Closed(value * x) } - def min(that: IsKnown): IsKnown = if(value <= that.value) this else that - def max(that: IsKnown): IsKnown = if(value >= that.value) this else that - def neg: IsKnown = Closed(-value) + def min(that: IsKnown): IsKnown = if (value <= that.value) this else that + def max(that: IsKnown): IsKnown = if (value >= that.value) this else that + def neg: IsKnown = Closed(-value) def floor: IsKnown = Closed(value.setScale(0, BigDecimal.RoundingMode.FLOOR)) - def pow: IsKnown = if(value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") + def pow: IsKnown = + if (value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") } /** Types of [[FirrtlNode]] */ abstract class Type extends FirrtlNode { - def mapType(f: Type => Type): Type - def mapWidth(f: Width => Width): Type - def foreachType(f: Type => Unit): Unit - def foreachWidth(f: Width => Unit): Unit + def mapType(f: Type => Type): Type + def mapWidth(f: Width => Width): Type + def foreachType(f: Type => Unit): Unit + def foreachWidth(f: Width => Unit): Unit } abstract class GroundType extends Type { val width: Width - def mapType(f: Type => Type): Type = this + def mapType(f: Type => Type): Type = this def foreachType(f: Type => Unit): Unit = () } object GroundType { def unapply(ground: GroundType): Option[Width] = Some(ground.width) } abstract class AggregateType extends Type { - def mapWidth(f: Width => Width): Type = this - def foreachWidth(f: Width => Unit): Unit = () + def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = () } case class UIntType(width: Width) extends GroundType with UseSerializer { - def mapWidth(f: Width => Width): Type = UIntType(f(width)) - def foreachWidth(f: Width => Unit): Unit = f(width) + def mapWidth(f: Width => Width): Type = UIntType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case class SIntType(width: Width) extends GroundType with UseSerializer { - def mapWidth(f: Width => Width): Type = SIntType(f(width)) - def foreachWidth(f: Width => Unit): Unit = f(width) + def mapWidth(f: Width => Width): Type = SIntType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case class FixedType(width: Width, point: Width) extends GroundType with UseSerializer { def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point)) @@ -790,21 +844,21 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground case Open(l) => s"(${dec2string(l)}, " case Closed(l) => s"[${dec2string(l)}, " case UnknownBound => s"[?, " - case _ => s"[?, " + case _ => s"[?, " } val upperString = upper match { case Open(u) => s"${dec2string(u)})" case Closed(u) => s"${dec2string(u)}]" case UnknownBound => s"?]" - case _ => s"?]" + case _ => s"?]" } val bounds = (lower, upper) match { case (k1: IsKnown, k2: IsKnown) => lowerString + upperString case _ => "" } val pointString = point match { - case IntWidth(i) => "." + i.toString - case _ => "" + case IntWidth(i) => "." + i.toString + case _ => "" } "Interval" + bounds + pointString } @@ -813,35 +867,43 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground private def precision: Option[BigDecimal] = point match { case IntWidth(width) => val bp = width.toInt - if(bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp)) + if (bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp)) case other => None } def min: Option[BigDecimal] = (lower, precision) match { - case (Open(a), Some(prec)) => a / prec match { - case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0 - case x => Some(x.setScale(0, CEILING) * prec) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l - } + case (Open(a), Some(prec)) => + a / prec match { + case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0 + case x => + Some( + x.setScale(0, CEILING) * prec + ) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l + } case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, CEILING) * prec) - case other => None + case other => None } def max: Option[BigDecimal] = (upper, precision) match { - case (Open(a), Some(prec)) => a / prec match { - case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound - case x => Some(x.setScale(0, FLOOR) * prec) - } + case (Open(a), Some(prec)) => + a / prec match { + case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound + case x => Some(x.setScale(0, FLOOR) * prec) + } case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, FLOOR) * prec) } def minAdjusted: Option[BigInt] = min.map(_ * BigDecimal(BigInt(1) << bp) match { case x if trim(x).isWhole | x.doubleValue == 0.0 => x.toBigInt - case x => sys.error(s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}.") + case x => + sys.error( + s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}." + ) }) def maxAdjusted: Option[BigInt] = max.map(_ * BigDecimal(BigInt(1) << bp) match { case x if trim(x).isWhole => x.toBigInt - case x => sys.error(s"MaxAdjusted should be a whole number: $x") + case x => sys.error(s"MaxAdjusted should be a whole number: $x") }) /** If bounds are known, calculates the width, otherwise returns UnknownWidth */ @@ -854,48 +916,48 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground /** If bounds are known, returns a sequence of all possible values inside this interval */ lazy val range: Option[Seq[BigDecimal]] = (lower, upper, point) match { case (l: IsKnown, u: IsKnown, p: IntWidth) => - if(min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get)) + if (min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get)) case _ => None } - override def mapWidth(f: Width => Width): Type = this.copy(point = f(point)) - override def foreachWidth(f: Width => Unit): Unit = f(point) + override def mapWidth(f: Width => Width): Type = this.copy(point = f(point)) + override def foreachWidth(f: Width => Unit): Unit = f(point) } case class BundleType(fields: Seq[Field]) extends AggregateType with UseSerializer { def mapType(f: Type => Type): Type = - BundleType(fields map (x => x.copy(tpe = f(x.tpe)))) - def foreachType(f: Type => Unit): Unit = fields.foreach{ x => f(x.tpe) } + BundleType(fields.map(x => x.copy(tpe = f(x.tpe)))) + def foreachType(f: Type => Unit): Unit = fields.foreach { x => f(x.tpe) } } case class VectorType(tpe: Type, size: Int) extends AggregateType with UseSerializer { - def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe)) + def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe)) def foreachType(f: Type => Unit): Unit = f(tpe) } case object ClockType extends GroundType with UseSerializer { val width = IntWidth(1) - def mapWidth(f: Width => Width): Type = this - def foreachWidth(f: Width => Unit): Unit = () + def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = () } /* Abstract reset, will be inferred to UInt<1> or AsyncReset */ case object ResetType extends GroundType with UseSerializer { val width = IntWidth(1) - def mapWidth(f: Width => Width): Type = this - def foreachWidth(f: Width => Unit): Unit = () + def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = () } case object AsyncResetType extends GroundType with UseSerializer { val width = IntWidth(1) - def mapWidth(f: Width => Width): Type = this - def foreachWidth(f: Width => Unit): Unit = () + def mapWidth(f: Width => Width): Type = this + def foreachWidth(f: Width => Unit): Unit = () } case class AnalogType(width: Width) extends GroundType with UseSerializer { - def mapWidth(f: Width => Width): Type = AnalogType(f(width)) - def foreachWidth(f: Width => Unit): Unit = f(width) + def mapWidth(f: Width => Width): Type = AnalogType(f(width)) + def foreachWidth(f: Width => Unit): Unit = f(width) } case object UnknownType extends Type with UseSerializer { - def mapType(f: Type => Type): Type = this - def mapWidth(f: Width => Width): Type = this - def foreachType(f: Type => Unit): Unit = () - def foreachWidth(f: Width => Unit): Unit = () + def mapType(f: Type => Type): Type = this + def mapWidth(f: Width => Width): Type = this + def foreachType(f: Type => Unit): Unit = () + def foreachWidth(f: Width => Unit): Unit = () } /** [[Port]] Direction */ @@ -909,11 +971,14 @@ case object Output extends Direction { /** [[DefModule]] Port */ case class Port( - info: Info, - name: String, - direction: Direction, - tpe: Type) extends FirrtlNode with IsDeclaration with UseSerializer { - def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe)) + info: Info, + name: String, + direction: Direction, + tpe: Type) + extends FirrtlNode + with IsDeclaration + with UseSerializer { + def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe)) def mapString(f: String => String): Port = Port(info, f(name), direction, tpe) } @@ -921,12 +986,16 @@ case class Port( sealed abstract class Param extends FirrtlNode { def name: String } + /** Integer (of any width) Parameter */ case class IntParam(name: String, value: BigInt) extends Param with UseSerializer + /** IEEE Double Precision Parameter (for Verilog real) */ case class DoubleParam(name: String, value: Double) extends Param with UseSerializer + /** String Parameter */ case class StringParam(name: String, value: StringLit) extends Param with UseSerializer + /** Raw String Parameter * Useful for Verilog type parameters * @note Firrtl doesn't guarantee anything about this String being legal in any backend @@ -935,59 +1004,65 @@ case class RawStringParam(name: String, value: String) extends Param with UseSer /** Base class for modules */ abstract class DefModule extends FirrtlNode with IsDeclaration { - val info : Info - val name : String - val ports : Seq[Port] - def mapStmt(f: Statement => Statement): DefModule - def mapPort(f: Port => Port): DefModule - def mapString(f: String => String): DefModule - def mapInfo(f: Info => Info): DefModule - def foreachStmt(f: Statement => Unit): Unit - def foreachPort(f: Port => Unit): Unit - def foreachString(f: String => Unit): Unit - def foreachInfo(f: Info => Unit): Unit + val info: Info + val name: String + val ports: Seq[Port] + def mapStmt(f: Statement => Statement): DefModule + def mapPort(f: Port => Port): DefModule + def mapString(f: String => String): DefModule + def mapInfo(f: Info => Info): DefModule + def foreachStmt(f: Statement => Unit): Unit + def foreachPort(f: Port => Unit): Unit + def foreachString(f: String => Unit): Unit + def foreachInfo(f: Info => Unit): Unit } + /** Internal Module * * An instantiable hardware block */ case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) extends DefModule with UseSerializer { - def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body)) - def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) - def mapString(f: String => String): DefModule = this.copy(name = f(name)) - def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) - def foreachStmt(f: Statement => Unit): Unit = f(body) - def foreachPort(f: Port => Unit): Unit = ports.foreach(f) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body)) + def mapPort(f: Port => Port): DefModule = this.copy(ports = ports.map(f)) + def mapString(f: String => String): DefModule = this.copy(name = f(name)) + def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = f(body) + def foreachPort(f: Port => Unit): Unit = ports.foreach(f) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } + /** External Module * * Generally used for Verilog black boxes * @param defname Defined name of the external module (ie. the name Firrtl will emit) */ case class ExtModule( - info: Info, - name: String, - ports: Seq[Port], - defname: String, - params: Seq[Param]) extends DefModule with UseSerializer { - def mapStmt(f: Statement => Statement): DefModule = this - def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f) - def mapString(f: String => String): DefModule = this.copy(name = f(name)) - def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachPort(f: Port => Unit): Unit = ports.foreach(f) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + info: Info, + name: String, + ports: Seq[Port], + defname: String, + params: Seq[Param]) + extends DefModule + with UseSerializer { + def mapStmt(f: Statement => Statement): DefModule = this + def mapPort(f: Port => Port): DefModule = this.copy(ports = ports.map(f)) + def mapString(f: String => String): DefModule = this.copy(name = f(name)) + def mapInfo(f: Info => Info): DefModule = this.copy(f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachPort(f: Port => Unit): Unit = ports.foreach(f) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } case class Circuit(info: Info, modules: Seq[DefModule], main: String) - extends FirrtlNode with HasInfo with UseSerializer { - def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules map f) - def mapString(f: String => String): Circuit = this.copy(main = f(main)) - def mapInfo(f: Info => Info): Circuit = this.copy(f(info)) - def foreachModule(f: DefModule => Unit): Unit = modules foreach f - def foreachString(f: String => Unit): Unit = f(main) - def foreachInfo(f: Info => Unit): Unit = f(info) + extends FirrtlNode + with HasInfo + with UseSerializer { + def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules.map(f)) + def mapString(f: String => String): Circuit = this.copy(main = f(main)) + def mapInfo(f: Info => Info): Circuit = this.copy(f(info)) + def foreachModule(f: DefModule => Unit): Unit = modules.foreach(f) + def foreachString(f: String => Unit): Unit = f(main) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala index ea304cf3..bf9a57c1 100644 --- a/src/main/scala/firrtl/ir/Serializer.scala +++ b/src/main/scala/firrtl/ir/Serializer.scala @@ -13,19 +13,19 @@ object Serializer { val builder = new StringBuilder() val indent = 0 node match { - case n : Info => s(n)(builder, indent) - case n : StringLit => s(n)(builder, indent) - case n : Expression => s(n)(builder, indent) - case n : Statement => s(n)(builder, indent) - case n : Width => s(n)(builder, indent) - case n : Orientation => s(n)(builder, indent) - case n : Field => s(n)(builder, indent) - case n : Type => s(n)(builder, indent) - case n : Direction => s(n)(builder, indent) - case n : Port => s(n)(builder, indent) - case n : Param => s(n)(builder, indent) - case n : DefModule => s(n)(builder, indent) - case n : Circuit => s(n)(builder, indent) + case n: Info => s(n)(builder, indent) + case n: StringLit => s(n)(builder, indent) + case n: Expression => s(n)(builder, indent) + case n: Statement => s(n)(builder, indent) + case n: Width => s(n)(builder, indent) + case n: Orientation => s(n)(builder, indent) + case n: Field => s(n)(builder, indent) + case n: Type => s(n)(builder, indent) + case n: Direction => s(n)(builder, indent) + case n: Port => s(n)(builder, indent) + case n: Param => s(n)(builder, indent) + case n: DefModule => s(n)(builder, indent) + case n: Circuit => s(n)(builder, indent) } builder.toString() } @@ -39,16 +39,16 @@ object Serializer { private def flattenInfo(infos: Seq[Info]): Seq[FileInfo] = infos.flatMap { case NoInfo => Seq() - case f : FileInfo => Seq(f) + case f: FileInfo => Seq(f) case MultiInfo(infos) => flattenInfo(infos) } private def s(node: Info)(implicit b: StringBuilder, indent: Int): Unit = node match { - case f : FileInfo => b ++= " @[" ; b ++= f.escaped ; b ++= "]" + case f: FileInfo => b ++= " @["; b ++= f.escaped; b ++= "]" case NoInfo => // empty string - case m : MultiInfo => + case m: MultiInfo => val infos = m.flatten - if(infos.nonEmpty) { + if (infos.nonEmpty) { val lastId = infos.length - 1 b ++= " @[" infos.zipWithIndex.foreach { case (f, i) => b ++= f.escaped; if (i < lastId) b += ' ' } @@ -61,103 +61,113 @@ object Serializer { private def s(node: Expression)(implicit b: StringBuilder, indent: Int): Unit = node match { case Reference(name, _, _, _) => b ++= name case DoPrim(op, args, consts, _) => - b ++= op.toString ; b += '(' ; s(args, ", ", consts.isEmpty) ; s(consts, ", ") ; b += ')' + b ++= op.toString; b += '('; s(args, ", ", consts.isEmpty); s(consts, ", "); b += ')' case UIntLiteral(value, width) => - b ++= "UInt" ; s(width) ; b ++= "(\"h" ; b ++= value.toString(16) ; b ++= "\")" - case SubField(expr, name, _, _) => s(expr) ; b += '.' ; b ++= name - case SubIndex(expr, value, _, _) => s(expr) ; b += '[' ; b ++= value.toString ; b += ']' - case SubAccess(expr, index, _, _) => s(expr) ; b += '[' ; s(index) ; b += ']' + b ++= "UInt"; s(width); b ++= "(\"h"; b ++= value.toString(16); b ++= "\")" + case SubField(expr, name, _, _) => s(expr); b += '.'; b ++= name + case SubIndex(expr, value, _, _) => s(expr); b += '['; b ++= value.toString; b += ']' + case SubAccess(expr, index, _, _) => s(expr); b += '['; s(index); b += ']' case Mux(cond, tval, fval, _) => - b ++= "mux(" ; s(cond) ; b ++= ", " ; s(tval) ; b ++= ", " ; s(fval) ; b += ')' - case ValidIf(cond, value, _) => b ++= "validif(" ; s(cond) ; b ++= ", " ; s(value) ; b += ')' + b ++= "mux("; s(cond); b ++= ", "; s(tval); b ++= ", "; s(fval); b += ')' + case ValidIf(cond, value, _) => b ++= "validif("; s(cond); b ++= ", "; s(value); b += ')' case SIntLiteral(value, width) => - b ++= "SInt" ; s(width) ; b ++= "(\"h" ; b ++= value.toString(16) ; b ++= "\")" + b ++= "SInt"; s(width); b ++= "(\"h"; b ++= value.toString(16); b ++= "\")" case FixedLiteral(value, width, point) => - b ++= "Fixed" ; s(width) ; sPoint(point) - b ++= "(\"h" ; b ++= value.toString(16) ; b ++= "\")" + b ++= "Fixed"; s(width); sPoint(point) + b ++= "(\"h"; b ++= value.toString(16); b ++= "\")" // WIR - case firrtl.WVoid => b ++= "VOID" - case firrtl.WInvalid => b ++= "INVALID" + case firrtl.WVoid => b ++= "VOID" + case firrtl.WInvalid => b ++= "INVALID" case firrtl.EmptyExpression => b ++= "EMPTY" } private def s(node: Statement)(implicit b: StringBuilder, indent: Int): Unit = node match { - case DefNode(info, name, value) => b ++= "node " ; b ++= name ; b ++= " = " ; s(value) ; s(info) - case Connect(info, loc, expr) => s(loc) ; b ++= " <= " ; s(expr) ; s(info) + case DefNode(info, name, value) => b ++= "node "; b ++= name; b ++= " = "; s(value); s(info) + case Connect(info, loc, expr) => s(loc); b ++= " <= "; s(expr); s(info) case Conditionally(info, pred, conseq, alt) => - b ++= "when " ; s(pred) ; b ++= " :" ; s(info) - newLineAndIndent(1) ; s(conseq)(b, indent + 1) - if(alt != EmptyStmt) { - newLineAndIndent() ; b ++= "else :" - newLineAndIndent(1) ; s(alt)(b, indent + 1) + b ++= "when "; s(pred); b ++= " :"; s(info) + newLineAndIndent(1); s(conseq)(b, indent + 1) + if (alt != EmptyStmt) { + newLineAndIndent(); b ++= "else :" + newLineAndIndent(1); s(alt)(b, indent + 1) } - case EmptyStmt => b ++= "skip" + case EmptyStmt => b ++= "skip" case Block(Seq()) => b ++= "skip" case Block(stmts) => val it = stmts.iterator - while(it.hasNext) { + while (it.hasNext) { s(it.next) - if(it.hasNext) newLineAndIndent() + if (it.hasNext) newLineAndIndent() } case Stop(info, ret, clk, en) => - b ++= "stop(" ; s(clk) ; b ++= ", " ; s(en) ; b ++= ", " ; b ++= ret.toString ; b += ')' ; s(info) + b ++= "stop("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= ret.toString; b += ')'; s(info) case Print(info, string, args, clk, en) => - b ++= "printf(" ; s(clk) ; b ++= ", " ; s(en) ; b ++= ", " ; b ++= string.escape - if(args.nonEmpty) b ++= ", " ; s(args, ", ") ; b += ')' ; s(info) - case IsInvalid(info, expr) => s(expr) ; b ++= " is invalid" ; s(info) - case DefWire(info, name, tpe) => b ++= "wire " ; b ++= name ; b ++= " : " ; s(tpe) ; s(info) + b ++= "printf("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= string.escape + if (args.nonEmpty) b ++= ", "; s(args, ", "); b += ')'; s(info) + case IsInvalid(info, expr) => s(expr); b ++= " is invalid"; s(info) + case DefWire(info, name, tpe) => b ++= "wire "; b ++= name; b ++= " : "; s(tpe); s(info) case DefRegister(info, name, tpe, clock, reset, init) => - b ++= "reg " ; b ++= name ; b ++= " : " ; s(tpe) ; b ++= ", " ; s(clock) ; b ++= " with :" ; newLineAndIndent(1) - b ++= "reset => (" ; s(reset) ; b ++= ", " ; s(init) ; b += ')' ; s(info) - case DefInstance(info, name, module, _) => b ++= "inst " ; b ++= name ; b ++= " of " ; b ++= module ; s(info) - case DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, - readwriters, readUnderWrite) => - b ++= "mem " ; b ++= name ; b ++= " :" ; s(info) ; newLineAndIndent(1) - b ++= "data-type => " ; s(dataType) ; newLineAndIndent(1) - b ++= "depth => " ; b ++= depth.toString() ; newLineAndIndent(1) - b ++= "read-latency => " ; b ++= readLatency.toString ; newLineAndIndent(1) - b ++= "write-latency => " ; b ++= writeLatency.toString ; newLineAndIndent(1) - readers.foreach{ r => b ++= "reader => " ; b ++= r ; newLineAndIndent(1) } - writers.foreach{ w => b ++= "writer => " ; b ++= w ; newLineAndIndent(1) } - readwriters.foreach{ r => b ++= "readwriter => " ; b ++= r ; newLineAndIndent(1) } - b ++= "read-under-write => " ; b ++= readUnderWrite.toString - case PartialConnect(info, loc, expr) => s(loc) ; b ++= " <- " ; s(expr) ; s(info) - case Attach(info, exprs) => + b ++= "reg "; b ++= name; b ++= " : "; s(tpe); b ++= ", "; s(clock); b ++= " with :"; newLineAndIndent(1) + b ++= "reset => ("; s(reset); b ++= ", "; s(init); b += ')'; s(info) + case DefInstance(info, name, module, _) => b ++= "inst "; b ++= name; b ++= " of "; b ++= module; s(info) + case DefMemory( + info, + name, + dataType, + depth, + writeLatency, + readLatency, + readers, + writers, + readwriters, + readUnderWrite + ) => + b ++= "mem "; b ++= name; b ++= " :"; s(info); newLineAndIndent(1) + b ++= "data-type => "; s(dataType); newLineAndIndent(1) + b ++= "depth => "; b ++= depth.toString(); newLineAndIndent(1) + b ++= "read-latency => "; b ++= readLatency.toString; newLineAndIndent(1) + b ++= "write-latency => "; b ++= writeLatency.toString; newLineAndIndent(1) + readers.foreach { r => b ++= "reader => "; b ++= r; newLineAndIndent(1) } + writers.foreach { w => b ++= "writer => "; b ++= w; newLineAndIndent(1) } + readwriters.foreach { r => b ++= "readwriter => "; b ++= r; newLineAndIndent(1) } + b ++= "read-under-write => "; b ++= readUnderWrite.toString + case PartialConnect(info, loc, expr) => s(loc); b ++= " <- "; s(expr); s(info) + case Attach(info, exprs) => // exprs should never be empty since the attach statement takes *at least* two signals according to the spec - b ++= "attach (" ; s(exprs, ", ") ; b += ')' ; s(info) + b ++= "attach ("; s(exprs, ", "); b += ')'; s(info) case Verification(op, info, clk, pred, en, msg) => - b ++= op.toString ; b += '(' ; s(List(clk, pred, en), ", ", false) ; b ++= msg.escape - b += ')' ; s(info) + b ++= op.toString; b += '('; s(List(clk, pred, en), ", ", false); b ++= msg.escape + b += ')'; s(info) // WIR case firrtl.CDefMemory(info, name, tpe, size, seq, readUnderWrite) => - if(seq) b ++= "smem " else b ++= "cmem " - b ++= name ; b ++= " : " ; s(tpe) ; b ++= " [" ; b ++= size.toString() ; b += ']' ; s(info) + if (seq) b ++= "smem " else b ++= "cmem " + b ++= name; b ++= " : "; s(tpe); b ++= " ["; b ++= size.toString(); b += ']'; s(info) case firrtl.CDefMPort(info, name, _, mem, exps, direction) => - b ++= direction.serialize ; b ++= " mport " ; b ++= name ; b ++= " = " ; b ++= mem - b += '[' ; s(exps.head) ; b ++= "], " ; s(exps(1)) ; s(info) + b ++= direction.serialize; b ++= " mport "; b ++= name; b ++= " = "; b ++= mem + b += '['; s(exps.head); b ++= "], "; s(exps(1)); s(info) case firrtl.WDefInstanceConnector(info, name, module, tpe, portCons) => - b ++= "inst " ; b ++= name ; b ++= " of " ; b ++= module ; b ++= " with " ; s(tpe) ; b ++= " connected to (" - s(portCons.map(_._2), ", ") ; b += ')' ; s(info) + b ++= "inst "; b ++= name; b ++= " of "; b ++= module; b ++= " with "; s(tpe); b ++= " connected to (" + s(portCons.map(_._2), ", "); b += ')'; s(info) } private def s(node: Width)(implicit b: StringBuilder, indent: Int): Unit = node match { case IntWidth(width) => b += '<'; b ++= width.toString(); b += '>' - case UnknownWidth => // empty string - case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')' - case VarWidth(name) => b += '<'; b ++= name; b += '>' + case UnknownWidth => // empty string + case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')' + case VarWidth(name) => b += '<'; b ++= name; b += '>' } private def sPoint(node: Width)(implicit b: StringBuilder, indent: Int): Unit = node match { case IntWidth(width) => b ++= "<<"; b ++= width.toString(); b ++= ">>" - case UnknownWidth => // empty string - case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')' - case VarWidth(name) => b ++= "<<"; b ++= name; b ++= ">>" + case UnknownWidth => // empty string + case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')' + case VarWidth(name) => b ++= "<<"; b ++= name; b ++= ">>" } private def s(node: Orientation)(implicit b: StringBuilder, indent: Int): Unit = node match { case Default => // empty string - case Flip => b ++= "flip " + case Flip => b ++= "flip " } private def s(node: Field)(implicit b: StringBuilder, indent: Int): Unit = node match { @@ -169,19 +179,19 @@ object Serializer { case UIntType(width: Width) => b ++= "UInt"; s(width) case SIntType(width: Width) => b ++= "SInt"; s(width) case FixedType(width, point) => b ++= "Fixed"; s(width); sPoint(point) - case BundleType(fields) => b ++= "{ "; sField(fields, ", "); b += '}' - case VectorType(tpe, size) => s(tpe); b += '['; b ++= size.toString; b += ']' - case ClockType => b ++= "Clock" - case ResetType => b ++= "Reset" - case AsyncResetType => b ++= "AsyncReset" - case AnalogType(width) => b ++= "Analog"; s(width) - case UnknownType => b += '?' + case BundleType(fields) => b ++= "{ "; sField(fields, ", "); b += '}' + case VectorType(tpe, size) => s(tpe); b += '['; b ++= size.toString; b += ']' + case ClockType => b ++= "Clock" + case ResetType => b ++= "Reset" + case AsyncResetType => b ++= "AsyncReset" + case AnalogType(width) => b ++= "Analog"; s(width) + case UnknownType => b += '?' // the IntervalType has a complicated custom serialization method which does not recurse case i: IntervalType => b ++= i.serialize } private def s(node: Direction)(implicit b: StringBuilder, indent: Int): Unit = node match { - case Input => b ++= "input" + case Input => b ++= "input" case Output => b ++= "output" } @@ -191,50 +201,50 @@ object Serializer { } private def s(node: Param)(implicit b: StringBuilder, indent: Int): Unit = node match { - case IntParam(name, value) => b ++= "parameter " ; b ++= name ; b ++= " = " ; b ++= value.toString - case DoubleParam(name, value) => b ++= "parameter " ; b ++= name ; b ++= " = " ; b ++= value.toString - case StringParam(name, value) => b ++= "parameter " ; b ++= name ; b ++= " = " ; b ++= value.escape + case IntParam(name, value) => b ++= "parameter "; b ++= name; b ++= " = "; b ++= value.toString + case DoubleParam(name, value) => b ++= "parameter "; b ++= name; b ++= " = "; b ++= value.toString + case StringParam(name, value) => b ++= "parameter "; b ++= name; b ++= " = "; b ++= value.escape case RawStringParam(name, value) => - b ++= "parameter " ; b ++= name ; b ++= " = " - b += '\'' ; b ++= value.replace("'", "\\'") ; b += '\'' + b ++= "parameter "; b ++= name; b ++= " = " + b += '\''; b ++= value.replace("'", "\\'"); b += '\'' } private def s(node: DefModule)(implicit b: StringBuilder, indent: Int): Unit = node match { case Module(info, name, ports, body) => - b ++= "module " ; b ++= name ; b ++= " :" ; s(info) - ports.foreach{ p => newLineAndIndent(1) ; s(p) } + b ++= "module "; b ++= name; b ++= " :"; s(info) + ports.foreach { p => newLineAndIndent(1); s(p) } newLineNoIndent() // add a new line between port declaration and body - newLineAndIndent(1) ; s(body)(b, indent + 1) + newLineAndIndent(1); s(body)(b, indent + 1) case ExtModule(info, name, ports, defname, params) => - b ++= "extmodule " ; b ++= name ; b ++= " :" ; s(info) - ports.foreach{ p => newLineAndIndent(1) ; s(p) } - newLineAndIndent(1) ; b ++= "defname = " ; b ++= defname - params.foreach{ p => newLineAndIndent(1) ; s(p) } + b ++= "extmodule "; b ++= name; b ++= " :"; s(info) + ports.foreach { p => newLineAndIndent(1); s(p) } + newLineAndIndent(1); b ++= "defname = "; b ++= defname + params.foreach { p => newLineAndIndent(1); s(p) } } private def s(node: Circuit)(implicit b: StringBuilder, indent: Int): Unit = node match { case Circuit(info, modules, main) => - b ++= "circuit " ; b ++= main ; b ++= " :" ; s(info) - if(modules.nonEmpty) { - newLineAndIndent(1) ; s(modules.head)(b, indent + 1) - modules.drop(1).foreach{m => newLineNoIndent(); newLineAndIndent(1) ; s(m)(b, indent + 1) } + b ++= "circuit "; b ++= main; b ++= " :"; s(info) + if (modules.nonEmpty) { + newLineAndIndent(1); s(modules.head)(b, indent + 1) + modules.drop(1).foreach { m => newLineNoIndent(); newLineAndIndent(1); s(m)(b, indent + 1) } } } // serialize constraints private def s(const: Constraint)(implicit b: StringBuilder): Unit = const match { // Bounds - case UnknownBound => b += '?' - case CalcBound(arg) => b ++= "calcb(" ; s(arg) ; b += ')' + case UnknownBound => b += '?' + case CalcBound(arg) => b ++= "calcb("; s(arg); b += ')' case VarBound(name) => b ++= name - case Open(value) => b ++ "o(" ; b ++= value.toString ; b += ')' - case Closed(value) => b ++ "c(" ; b ++= value.toString ; b += ')' - case other => other.serialize + case Open(value) => b ++ "o("; b ++= value.toString; b += ')' + case Closed(value) => b ++ "c("; b ++= value.toString; b += ')' + case other => other.serialize } /** create a new line with the appropriate indent */ private def newLineAndIndent(inc: Int = 0)(implicit b: StringBuilder, indent: Int): Unit = { - b += NewLine ; doIndent(inc) + b += NewLine; doIndent(inc) } private def newLineNoIndent()(implicit b: StringBuilder): Unit = b += NewLine @@ -245,32 +255,37 @@ object Serializer { } /** serialize firrtl Expression nodes with a custom separator and the option to include the separator at the end */ - private def s(nodes: Iterable[Expression], sep: String, noFinalSep: Boolean = true) - (implicit b: StringBuilder, indent: Int): Unit = { + private def s( + nodes: Iterable[Expression], + sep: String, + noFinalSep: Boolean = true + )( + implicit b: StringBuilder, + indent: Int + ): Unit = { val it = nodes.iterator - while(it.hasNext) { + while (it.hasNext) { s(it.next()) - if(!noFinalSep || it.hasNext) b ++= sep + if (!noFinalSep || it.hasNext) b ++= sep } } /** serialize firrtl Field nodes with a custom separator and the option to include the separator at the end */ @inline - private def sField(nodes: Iterable[Field], sep: String) - (implicit b: StringBuilder, indent: Int): Unit = { + private def sField(nodes: Iterable[Field], sep: String)(implicit b: StringBuilder, indent: Int): Unit = { val it = nodes.iterator - while(it.hasNext) { + while (it.hasNext) { s(it.next()) - if(it.hasNext) b ++= sep + if (it.hasNext) b ++= sep } } /** serialize BigInts with a custom separator */ private def s(consts: Iterable[BigInt], sep: String)(implicit b: StringBuilder): Unit = { val it = consts.iterator - while(it.hasNext) { + while (it.hasNext) { b ++= it.next().toString() - if(it.hasNext) b ++= sep + if (it.hasNext) b ++= sep } } } diff --git a/src/main/scala/firrtl/ir/StructuralHash.scala b/src/main/scala/firrtl/ir/StructuralHash.scala index 1b38dec1..f1ed91f3 100644 --- a/src/main/scala/firrtl/ir/StructuralHash.scala +++ b/src/main/scala/firrtl/ir/StructuralHash.scala @@ -24,7 +24,7 @@ import scala.collection.mutable * of the same circuit and thus all modules referred to in DefInstance are the same. * * @author Kevin Laeufer <laeufer@cs.berkeley.edu> - * */ + */ object StructuralHash { def sha256(node: DefModule, moduleRename: String => String = identity): HashCode = { val m = MessageDigest.getInstance(SHA256) @@ -59,19 +59,19 @@ object StructuralHash { private val SHA256 = "SHA-256" private def hash(node: FirrtlNode, h: Hasher, rename: String => String): Unit = node match { - case n : Expression => new StructuralHash(h, rename).hash(n) - case n : Statement => new StructuralHash(h, rename).hash(n) - case n : Type => new StructuralHash(h, rename).hash(n) - case n : Width => new StructuralHash(h, rename).hash(n) - case n : Orientation => new StructuralHash(h, rename).hash(n) - case n : Field => new StructuralHash(h, rename).hash(n) - case n : Direction => new StructuralHash(h, rename).hash(n) - case n : Port => new StructuralHash(h, rename).hash(n) - case n : Param => new StructuralHash(h, rename).hash(n) - case _ : Info => throw new RuntimeException("The structural hash of Info is meaningless.") - case n : DefModule => new StructuralHash(h, rename).hash(n) - case n : Circuit => hashCircuit(n, h, rename) - case n : StringLit => h.update(n.toString) + case n: Expression => new StructuralHash(h, rename).hash(n) + case n: Statement => new StructuralHash(h, rename).hash(n) + case n: Type => new StructuralHash(h, rename).hash(n) + case n: Width => new StructuralHash(h, rename).hash(n) + case n: Orientation => new StructuralHash(h, rename).hash(n) + case n: Field => new StructuralHash(h, rename).hash(n) + case n: Direction => new StructuralHash(h, rename).hash(n) + case n: Port => new StructuralHash(h, rename).hash(n) + case n: Param => new StructuralHash(h, rename).hash(n) + case _: Info => throw new RuntimeException("The structural hash of Info is meaningless.") + case n: DefModule => new StructuralHash(h, rename).hash(n) + case n: Circuit => hashCircuit(n, h, rename) + case n: StringLit => h.update(n.toString) } private def hashModuleAndPortNames(m: DefModule, h: Hasher, rename: String => String): Unit = { @@ -85,9 +85,9 @@ object StructuralHash { } private def hashPortTypeName(tpe: Type, h: String => Unit): Unit = tpe match { - case BundleType(fields) => fields.foreach{ f => h(f.name) ; hashPortTypeName(f.tpe, h) } - case VectorType(vt, _) => hashPortTypeName(vt, h) - case _ => // ignore ground types since they do not have field names nor sub-types + case BundleType(fields) => fields.foreach { f => h(f.name); hashPortTypeName(f.tpe, h) } + case VectorType(vt, _) => hashPortTypeName(vt, h) + case _ => // ignore ground types since they do not have field names nor sub-types } private def hashCircuit(c: Circuit, h: Hasher, rename: String => String): Unit = { @@ -101,8 +101,8 @@ object StructuralHash { } } - private val primOpToId = PrimOps.builtinPrimOps.zipWithIndex.map{ case (op, i) => op -> (-i -1).toByte }.toMap - assert(primOpToId.values.max == -1, "PrimOp nodes use ids -1 ... -50") + private val primOpToId = PrimOps.builtinPrimOps.zipWithIndex.map { case (op, i) => op -> (-i - 1).toByte }.toMap + assert(primOpToId.values.max == -1, "PrimOp nodes use ids -1 ... -50") assert(primOpToId.values.min >= -50, "PrimOp nodes use ids -1 ... -50") private def primOp(p: PrimOp): Byte = primOpToId(p) @@ -110,7 +110,7 @@ object StructuralHash { private def verificationOp(op: Formal.Value): Byte = op match { case Formal.Assert => 0 case Formal.Assume => 1 - case Formal.Cover => 2 + case Formal.Cover => 2 } } @@ -129,14 +129,14 @@ private class MDHashCode(code: Array[Byte]) extends HashCode { /** Generic hashing interface which allows us to use different backends to trade of speed and collision resistance */ private trait Hasher { - def update(b: Byte): Unit - def update(i: Int): Unit - def update(l: Long): Unit - def update(s: String): Unit + def update(b: Byte): Unit + def update(i: Int): Unit + def update(l: Long): Unit + def update(s: String): Unit def update(b: Array[Byte]): Unit def update(d: Double): Unit = update(java.lang.Double.doubleToRawLongBits(d)) - def update(i: BigInt): Unit = update(i.toByteArray) - def update(b: Boolean): Unit = if(b) update(1.toByte) else update(0.toByte) + def update(i: BigInt): Unit = update(i.toByteArray) + def update(b: Boolean): Unit = if (b) update(1.toByte) else update(0.toByte) def update(i: BigDecimal): Unit = { // this might be broken, tried to borrow some code from BigDecimal.computeHashCode val temp = i.bigDecimal.stripTrailingZeros() @@ -149,14 +149,14 @@ private trait Hasher { private class MessageDigestHasher(m: MessageDigest) extends Hasher { override def update(b: Byte): Unit = m.update(b) override def update(i: Int): Unit = { - m.update(((i >> 0) & 0xff).toByte) - m.update(((i >> 8) & 0xff).toByte) + m.update(((i >> 0) & 0xff).toByte) + m.update(((i >> 8) & 0xff).toByte) m.update(((i >> 16) & 0xff).toByte) m.update(((i >> 24) & 0xff).toByte) } override def update(l: Long): Unit = { - m.update(((l >> 0) & 0xff).toByte) - m.update(((l >> 8) & 0xff).toByte) + m.update(((l >> 0) & 0xff).toByte) + m.update(((l >> 8) & 0xff).toByte) m.update(((l >> 16) & 0xff).toByte) m.update(((l >> 24) & 0xff).toByte) m.update(((l >> 32) & 0xff).toByte) @@ -165,42 +165,47 @@ private class MessageDigestHasher(m: MessageDigest) extends Hasher { m.update(((l >> 56) & 0xff).toByte) } // the encoding of the bytes should not matter as long as we are on the same platform - override def update(s: String): Unit = m.update(s.getBytes()) + override def update(s: String): Unit = m.update(s.getBytes()) override def update(b: Array[Byte]): Unit = m.update(b) } -class StructuralHash private(h: Hasher, renameModule: String => String) { +class StructuralHash private (h: Hasher, renameModule: String => String) { // replace identifiers with incrementing integers private val nameToInt = mutable.HashMap[String, Int]() private var nameCounter: Int = 0 - @inline private def n(name: String): Unit = hash(nameToInt.getOrElseUpdate(name, { - val ii = nameCounter - nameCounter = nameCounter + 1 - ii - })) + @inline private def n(name: String): Unit = hash( + nameToInt.getOrElseUpdate( + name, { + val ii = nameCounter + nameCounter = nameCounter + 1 + ii + } + ) + ) // internal convenience methods - @inline private def id(b: Byte): Unit = h.update(b) - @inline private def hash(i: Int): Unit = h.update(i) - @inline private def hash(b: Boolean): Unit = h.update(b) - @inline private def hash(d: Double): Unit = h.update(d) - @inline private def hash(i: BigInt): Unit = h.update(i) + @inline private def id(b: Byte): Unit = h.update(b) + @inline private def hash(i: Int): Unit = h.update(i) + @inline private def hash(b: Boolean): Unit = h.update(b) + @inline private def hash(d: Double): Unit = h.update(d) + @inline private def hash(i: BigInt): Unit = h.update(i) @inline private def hash(i: BigDecimal): Unit = h.update(i) - @inline private def hash(s: String): Unit = h.update(s) + @inline private def hash(s: String): Unit = h.update(s) private def hash(node: Expression): Unit = node match { - case Reference(name, _, _, _) => id(0) ; n(name) + case Reference(name, _, _, _) => id(0); n(name) case DoPrim(op, args, consts, _) => // no need to hash the number of arguments or constants since that is implied by the op - id(1) ; h.update(StructuralHash.primOp(op)) ; args.foreach(hash) ; consts.foreach(hash) - case UIntLiteral(value, width) => id(2) ; hash(value) ; hash(width) + id(1); h.update(StructuralHash.primOp(op)); args.foreach(hash); consts.foreach(hash) + case UIntLiteral(value, width) => id(2); hash(value); hash(width) // We hash bundles as if fields are accessed by their index. // Thus we need to also hash field accesses that way. // This has the side-effect that `x.y` might hash to the same value as `z.r`, for example if the // types are `x: {y: UInt<1>, ...}` and `z: {r: UInt<1>, ...}` respectively. // They do not hash to the same value if the type of `z` is e.g., `z: {..., r: UInt<1>, ...}` // as that would have the `r` field at a different index. - case SubField(expr, name, _, _) => id(3) ; hash(expr) + case SubField(expr, name, _, _) => + id(3); hash(expr) // find field index and hash that instead of the field name val fields = expr.tpe match { case b: BundleType => b.fields @@ -209,93 +214,115 @@ class StructuralHash private(h: Hasher, renameModule: String => String) { } val index = fields.zipWithIndex.find(_._1.name == name).map(_._2).get hash(index) - case SubIndex(expr, value, _, _) => id(4) ; hash(expr) ; hash(value) - case SubAccess(expr, index, _, _) => id(5) ; hash(expr) ; hash(index) - case Mux(cond, tval, fval, _) => id(6) ; hash(cond) ; hash(tval) ; hash(fval) - case ValidIf(cond, value, _) => id(7) ; hash(cond) ; hash(value) - case SIntLiteral(value, width) => id(8) ; hash(value) ; hash(width) - case FixedLiteral(value, width, point) => id(9) ; hash(value) ; hash(width) ; hash(point) + case SubIndex(expr, value, _, _) => id(4); hash(expr); hash(value) + case SubAccess(expr, index, _, _) => id(5); hash(expr); hash(index) + case Mux(cond, tval, fval, _) => id(6); hash(cond); hash(tval); hash(fval) + case ValidIf(cond, value, _) => id(7); hash(cond); hash(value) + case SIntLiteral(value, width) => id(8); hash(value); hash(width) + case FixedLiteral(value, width, point) => id(9); hash(value); hash(width); hash(point) // WIR - case firrtl.WVoid => id(10) - case firrtl.WInvalid => id(11) + case firrtl.WVoid => id(10) + case firrtl.WInvalid => id(11) case firrtl.EmptyExpression => id(12) // VRandom is used in the Emitter - case firrtl.VRandom(width) => id(13) ; hash(width) + case firrtl.VRandom(width) => id(13); hash(width) // ids 14 ... 19 are reserved for future Expression nodes } private def hash(node: Statement): Unit = node match { // all info fields are ignore - case DefNode(_, name, value) => id(20) ; n(name) ; hash(value) - case Connect(_, loc, expr) => id(21) ; hash(loc) ; hash(expr) + case DefNode(_, name, value) => id(20); n(name); hash(value) + case Connect(_, loc, expr) => id(21); hash(loc); hash(expr) // we place the unique id 23 between conseq and alt to distinguish between them in case conseq is empty // we place the unique id 24 after alt to distinguish between alt and the next statement in case alt is empty - case Conditionally(_, pred, conseq, alt) => id(22) ; hash(pred) ; hash(conseq) ; id(23) ; hash(alt) ; id(24) - case EmptyStmt => // empty statements are ignored - case Block(stmts) => stmts.foreach(hash) // block structure is ignored - case Stop(_, ret, clk, en) => id(25) ; hash(ret) ; hash(clk) ; hash(en) - case Print(_, string, args, clk, en) => + case Conditionally(_, pred, conseq, alt) => id(22); hash(pred); hash(conseq); id(23); hash(alt); id(24) + case EmptyStmt => // empty statements are ignored + case Block(stmts) => stmts.foreach(hash) // block structure is ignored + case Stop(_, ret, clk, en) => id(25); hash(ret); hash(clk); hash(en) + case Print(_, string, args, clk, en) => // the string is part of the side effect and thus part of the circuit behavior - id(26) ; hash(string.string) ; hash(args.length) ; args.foreach(hash) ; hash(clk) ; hash(en) - case IsInvalid(_, expr) => id(27) ; hash(expr) - case DefWire(_, name, tpe) => id(28) ; n(name) ; hash(tpe) + id(26); hash(string.string); hash(args.length); args.foreach(hash); hash(clk); hash(en) + case IsInvalid(_, expr) => id(27); hash(expr) + case DefWire(_, name, tpe) => id(28); n(name); hash(tpe) case DefRegister(_, name, tpe, clock, reset, init) => - id(29) ; n(name) ; hash(tpe) ; hash(clock) ; hash(reset) ; hash(init) + id(29); n(name); hash(tpe); hash(clock); hash(reset); hash(init) case DefInstance(_, name, module, _) => // Module is in the global namespace which is why we cannot replace it with a numeric id. // However, it might have been renamed as part of the dedup consolidation. - id(30) ; n(name) ; hash(renameModule(module)) + id(30); n(name); hash(renameModule(module)) // descriptions on statements are ignores case firrtl.DescribedStmt(_, stmt) => hash(stmt) - case DefMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers, - readwriters, readUnderWrite) => - id(30) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency) - hash(readers.length) ; readers.foreach(hash) - hash(writers.length) ; writers.foreach(hash) - hash(readwriters.length) ; readwriters.foreach(hash) + case DefMemory( + _, + name, + dataType, + depth, + writeLatency, + readLatency, + readers, + writers, + readwriters, + readUnderWrite + ) => + id(30); n(name); hash(dataType); hash(depth); hash(writeLatency); hash(readLatency) + hash(readers.length); readers.foreach(hash) + hash(writers.length); writers.foreach(hash) + hash(readwriters.length); readwriters.foreach(hash) hash(readUnderWrite) - case PartialConnect(_, loc, expr) => id(31) ; hash(loc) ; hash(expr) - case Attach(_, exprs) => id(32) ; hash(exprs.length) ; exprs.foreach(hash) + case PartialConnect(_, loc, expr) => id(31); hash(loc); hash(expr) + case Attach(_, exprs) => id(32); hash(exprs.length); exprs.foreach(hash) // WIR case firrtl.CDefMemory(_, name, tpe, size, seq, readUnderWrite) => - id(33) ; n(name) ; hash(tpe); hash(size) ; hash(seq) ; hash(readUnderWrite) + id(33); n(name); hash(tpe); hash(size); hash(seq); hash(readUnderWrite) case firrtl.CDefMPort(_, name, _, mem, exps, direction) => // the type of the MPort depends only on the memory (in well types firrtl) and can thus be ignored - id(34) ; n(name) ; n(mem) ; hash(exps.length) ; exps.foreach(hash) ; hash(direction) + id(34); n(name); n(mem); hash(exps.length); exps.foreach(hash); hash(direction) // DefAnnotatedMemory from MemIR.scala - case firrtl.passes.memlib.DefAnnotatedMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers, - readwriters, readUnderWrite, maskGran, memRef) => - id(35) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency) - hash(readers.length) ; readers.foreach(hash) - hash(writers.length) ; writers.foreach(hash) - hash(readwriters.length) ; readwriters.foreach(hash) + case firrtl.passes.memlib.DefAnnotatedMemory( + _, + name, + dataType, + depth, + writeLatency, + readLatency, + readers, + writers, + readwriters, + readUnderWrite, + maskGran, + memRef + ) => + id(35); n(name); hash(dataType); hash(depth); hash(writeLatency); hash(readLatency) + hash(readers.length); readers.foreach(hash) + hash(writers.length); writers.foreach(hash) + hash(readwriters.length); readwriters.foreach(hash) hash(readUnderWrite.toString) - hash(maskGran.size) ; maskGran.foreach(hash) - hash(memRef.size) ; memRef.foreach{ case (a, b) => hash(a) ; hash(b) } + hash(maskGran.size); maskGran.foreach(hash) + hash(memRef.size); memRef.foreach { case (a, b) => hash(a); hash(b) } case Verification(op, _, clk, pred, en, msg) => - id(36) ; hash(StructuralHash.verificationOp(op)) ; hash(clk) ; hash(pred) ; hash(en) ; hash(msg.string) + id(36); hash(StructuralHash.verificationOp(op)); hash(clk); hash(pred); hash(en); hash(msg.string) // ids 37 ... 39 are reserved for future Statement nodes } // ReadUnderWrite is never used in place of a FirrtlNode and thus we can start a new id namespace private def hash(ruw: ReadUnderWrite.Value): Unit = ruw match { - case ReadUnderWrite.New => id(0) - case ReadUnderWrite.Old => id(1) + case ReadUnderWrite.New => id(0) + case ReadUnderWrite.Old => id(1) case ReadUnderWrite.Undefined => id(2) } private def hash(node: Width): Unit = node match { - case IntWidth(width) => id(40) ; hash(width) - case UnknownWidth => id(41) - case CalcWidth(arg) => id(42) ; hash(arg) + case IntWidth(width) => id(40); hash(width) + case UnknownWidth => id(41) + case CalcWidth(arg) => id(42); hash(arg) // we are hashing the name of the `VarWidth` instead of using `n` since these Vars exist in a different namespace - case VarWidth(name) => id(43) ; hash(name) + case VarWidth(name) => id(43); hash(name) // ids 44 + 45 are reserved for future Width nodes } private def hash(node: Orientation): Unit = node match { case Default => id(46) - case Flip => id(47) + case Flip => id(47) } private def hash(node: Field): Unit = { @@ -306,81 +333,81 @@ class StructuralHash private(h: Hasher, renameModule: String => String) { // has been used in the Dedup pass for a long time. // This position-based notion of equality requires us to replace field names with field indexes when hashing // SubField accesses. - id(48) ; hash(node.flip) ; hash(node.tpe) + id(48); hash(node.flip); hash(node.tpe) } private def hash(node: Type): Unit = node match { // Types - case UIntType(width: Width) => id(50) ; hash(width) - case SIntType(width: Width) => id(51) ; hash(width) - case FixedType(width, point) => id(52) ; hash(width) ; hash(point) - case BundleType(fields) => id(53) ; hash(fields.length) ; fields.foreach(hash) - case VectorType(tpe, size) => id(54) ; hash(tpe) ; hash(size) - case ClockType => id(55) - case ResetType => id(56) - case AsyncResetType => id(57) - case AnalogType(width) => id(58) ; hash(width) - case UnknownType => id(59) - case IntervalType(lower, upper, point) => id(60) ; hash(lower) ; hash(upper) ; hash(point) + case UIntType(width: Width) => id(50); hash(width) + case SIntType(width: Width) => id(51); hash(width) + case FixedType(width, point) => id(52); hash(width); hash(point) + case BundleType(fields) => id(53); hash(fields.length); fields.foreach(hash) + case VectorType(tpe, size) => id(54); hash(tpe); hash(size) + case ClockType => id(55) + case ResetType => id(56) + case AsyncResetType => id(57) + case AnalogType(width) => id(58); hash(width) + case UnknownType => id(59) + case IntervalType(lower, upper, point) => id(60); hash(lower); hash(upper); hash(point) // ids 61 ... 65 are reserved for future Type nodes } private def hash(node: Direction): Unit = node match { - case Input => id(66) + case Input => id(66) case Output => id(67) } private def hash(node: Port): Unit = { - id(68) ; n(node.name) ; hash(node.direction) ; hash(node.tpe) + id(68); n(node.name); hash(node.direction); hash(node.tpe) } private def hash(node: Param): Unit = node match { - case IntParam(name, value) => id(70) ; n(name) ; hash(value) - case DoubleParam(name, value) => id(71) ; n(name) ; hash(value) - case StringParam(name, value) => id(72) ; n(name) ; hash(value.string) - case RawStringParam(name, value) => id(73) ; n(name) ; hash(value) + case IntParam(name, value) => id(70); n(name); hash(value) + case DoubleParam(name, value) => id(71); n(name); hash(value) + case StringParam(name, value) => id(72); n(name); hash(value.string) + case RawStringParam(name, value) => id(73); n(name); hash(value) // id 74 is reserved for future use } private def hash(node: DefModule): Unit = node match { // the module name is ignored since it does not affect module functionality case Module(_, _name, ports, body) => - id(75) ; hash(ports.length) ; ports.foreach(hash) ; hash(body) + id(75); hash(ports.length); ports.foreach(hash); hash(body) // the module name is ignored since it does not affect module functionality case ExtModule(_, name, ports, defname, params) => - id(76) ; hash(ports.length) ; ports.foreach(hash) ; hash(defname) - hash(params.length) ; params.foreach(hash) + id(76); hash(ports.length); ports.foreach(hash); hash(defname) + hash(params.length); params.foreach(hash) } // id 127 is reserved for Circuit nodes private def hash(d: firrtl.MPortDir): Unit = d match { - case firrtl.MInfer => id(-70) - case firrtl.MRead => id(-71) - case firrtl.MWrite => id(-72) + case firrtl.MInfer => id(-70) + case firrtl.MRead => id(-71) + case firrtl.MWrite => id(-72) case firrtl.MReadWrite => id(-73) } private def hash(c: firrtl.constraint.Constraint): Unit = c match { case b: Bound => hash(b) /* uses ids -80 ... -84 */ case firrtl.constraint.IsAdd(known, maxs, mins, others) => - id(-85) ; hash(known.nonEmpty) ; known.foreach(hash) - hash(maxs.length) ; maxs.foreach(hash) - hash(mins.length) ; mins.foreach(hash) - hash(others.length) ; others.foreach(hash) - case firrtl.constraint.IsFloor(child, dummyArg) => id(-86) ; hash(child) ; hash(dummyArg) - case firrtl.constraint.IsKnown(decimal) => id(-87) ; hash(decimal) - case firrtl.constraint.IsNeg(child, dummyArg) => id(-88) ; hash(child) ; hash(dummyArg) - case firrtl.constraint.IsPow(child, dummyArg) => id(-89) ; hash(child) ; hash(dummyArg) - case firrtl.constraint.IsVar(str) => id(-90) ; n(str) + id(-85); hash(known.nonEmpty); known.foreach(hash) + hash(maxs.length); maxs.foreach(hash) + hash(mins.length); mins.foreach(hash) + hash(others.length); others.foreach(hash) + case firrtl.constraint.IsFloor(child, dummyArg) => id(-86); hash(child); hash(dummyArg) + case firrtl.constraint.IsKnown(decimal) => id(-87); hash(decimal) + case firrtl.constraint.IsNeg(child, dummyArg) => id(-88); hash(child); hash(dummyArg) + case firrtl.constraint.IsPow(child, dummyArg) => id(-89); hash(child); hash(dummyArg) + case firrtl.constraint.IsVar(str) => id(-90); n(str) } private def hash(b: Bound): Unit = b match { case UnknownBound => id(-80) - case CalcBound(arg) => id(-81) ; hash(arg) + case CalcBound(arg) => id(-81); hash(arg) // we are hashing the name of the `VarBound` instead of using `n` since these Vars exist in a different namespace - case VarBound(name) => id(-82) ; hash(name) - case Open(value) => id(-83) ; hash(value) - case Closed(value) => id(-84) ; hash(value) + case VarBound(name) => id(-82); hash(name) + case Open(value) => id(-83); hash(value) + case Closed(value) => id(-84); hash(value) } -}
\ No newline at end of file +} diff --git a/src/main/scala/firrtl/options/DependencyManager.scala b/src/main/scala/firrtl/options/DependencyManager.scala index ee6a7404..561e32ab 100644 --- a/src/main/scala/firrtl/options/DependencyManager.scala +++ b/src/main/scala/firrtl/options/DependencyManager.scala @@ -3,7 +3,7 @@ package firrtl.options import firrtl.AnnotationSeq -import firrtl.graph.{DiGraph, CyclicException} +import firrtl.graph.{CyclicException, DiGraph} import scala.collection.Set import scala.collection.immutable.{Set => ISet} @@ -22,7 +22,6 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends override def prerequisites = currentState - override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = Seq.empty @@ -34,13 +33,13 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends */ def targets: Seq[Dependency[B]] private lazy val _targets: LinkedHashSet[Dependency[B]] = targets - .foldLeft(new LinkedHashSet[Dependency[B]]()){ case (a, b) => a += b } + .foldLeft(new LinkedHashSet[Dependency[B]]()) { case (a, b) => a += b } /** A sequence of [[firrtl.Transform]]s that have been run. Internally, this will be converted to an ordered set. */ def currentState: Seq[Dependency[B]] private lazy val _currentState: LinkedHashSet[Dependency[B]] = currentState - .foldLeft(new LinkedHashSet[Dependency[B]]()){ case (a, b) => a += b } + .foldLeft(new LinkedHashSet[Dependency[B]]()) { case (a, b) => a += b } /** Existing transform objects that have already been constructed */ def knownObjects: Set[B] @@ -64,9 +63,10 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends * requirements. This is used to solve sub-problems arising from invalidations. */ protected def copy( - targets: Seq[Dependency[B]], + targets: Seq[Dependency[B]], currentState: Seq[Dependency[B]], - knownObjects: ISet[B] = dependencyToObject.values.toSet): B + knownObjects: ISet[B] = dependencyToObject.values.toSet + ): B /** Implicit conversion from Dependency to B */ private implicit def dToO(d: Dependency[B]): B = dependencyToObject.getOrElseUpdate(d, d.getObject()) @@ -77,14 +77,16 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends /** Modified breadth-first search that supports multiple starting nodes and a custom extractor that can be used to * generate/filter the edges to explore. Additionally, this will include edges to previously discovered nodes. */ - private def bfs( start: LinkedHashSet[Dependency[B]], - blacklist: LinkedHashSet[Dependency[B]], - extractor: B => Set[Dependency[B]] ): LinkedHashMap[B, LinkedHashSet[B]] = { + private def bfs( + start: LinkedHashSet[Dependency[B]], + blacklist: LinkedHashSet[Dependency[B]], + extractor: B => Set[Dependency[B]] + ): LinkedHashMap[B, LinkedHashSet[B]] = { val (queue, edges) = { - val a: Queue[Dependency[B]] = Queue(start.toSeq:_*) - val b: LinkedHashMap[B, LinkedHashSet[B]] = LinkedHashMap[B, LinkedHashSet[B]]( - start.map((dToO(_) -> LinkedHashSet[B]())).toSeq:_*) + val a: Queue[Dependency[B]] = Queue(start.toSeq: _*) + val b: LinkedHashMap[B, LinkedHashSet[B]] = + LinkedHashMap[B, LinkedHashSet[B]](start.map((dToO(_) -> LinkedHashSet[B]())).toSeq: _*) (a, b) } @@ -117,7 +119,8 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends val edges = bfs( start = _targets &~ _currentState, blacklist = _currentState, - extractor = (p: B) => p._prerequisites &~ _currentState) + extractor = (p: B) => p._prerequisites &~ _currentState + ) DiGraph(edges) } @@ -144,11 +147,14 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends val edges = { val x = new LinkedHashMap ++ _targets .map(dependencyToObject) - .map{ a => a -> prerequisiteGraph.getVertices.filter(a._optionalPrerequisiteOf(_)) } - x - .values + .map { a => a -> prerequisiteGraph.getVertices.filter(a._optionalPrerequisiteOf(_)) } + x.values .reduce(_ ++ _) - .foldLeft(x){ case (xx, y) => if (xx.contains(y)) { xx } else { xx ++ Map(y -> Set.empty[B]) } } + .foldLeft(x) { + case (xx, y) => + if (xx.contains(y)) { xx } + else { xx ++ Map(y -> Set.empty[B]) } + } } DiGraph(edges).reverse } @@ -165,23 +171,26 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends bfs( start = v.map(oToD(_)), blacklist = _currentState, - /* Explore all invalidated transforms **EXCEPT** the current transform! */ extractor = (p: B) => { val filtered = new LinkedHashSet[Dependency[B]] filtered ++= v.filter(p.invalidates).map(oToD(_)) filtered -= oToD(p) filtered - }) + } + ) ).reverse } /** Wrap a possible [[CyclicException]] thrown by a thunk in a [[DependencyManagerException]] */ - private def cyclePossible[A](a: String, diGraph: DiGraph[_])(thunk: => A): A = try { thunk } catch { + private def cyclePossible[A](a: String, diGraph: DiGraph[_])(thunk: => A): A = try { thunk } + catch { case e: CyclicException => throw new DependencyManagerException( s"""|No transform ordering possible due to cyclic dependency in $a with cycles: - |${diGraph.findSCCs.filter(_.size > 1).mkString(" - ", "\n - ", "")}""".stripMargin, e) + |${diGraph.findSCCs.filter(_.size > 1).mkString(" - ", "\n - ", "")}""".stripMargin, + e + ) } /** An ordering of [[firrtl.options.TransformLike TransformLike]]s that causes the requested [[DependencyManager.targets @@ -198,38 +207,39 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends */ val sorted = { val edges = { - val v = cyclePossible("invalidates", invalidateGraph){ invalidateGraph.linearize }.reverse + val v = cyclePossible("invalidates", invalidateGraph) { invalidateGraph.linearize }.reverse /* A comparison function that will sort vertices based on the topological sort of the invalidation graph */ val cmp = - (l: B, r: B) => v.foldLeft((Map.empty[B, Dependency[B] => Boolean], Set.empty[Dependency[B]])){ - case ((m, s), r) => (m + (r -> ((a: Dependency[B]) => !s(a))), s + r) }._1(l)(r) + (l: B, r: B) => + v.foldLeft((Map.empty[B, Dependency[B] => Boolean], Set.empty[Dependency[B]])) { + case ((m, s), r) => (m + (r -> ((a: Dependency[B]) => !s(a))), s + r) + }._1(l)(r) new LinkedHashMap() ++ v.map(vv => vv -> (new LinkedHashSet() ++ (dependencyGraph.getEdges(vv).toSeq.sortWith(cmp)))) } cyclePossible("prerequisites", dependencyGraph) { - DiGraph(edges) - .linearize - .reverse + DiGraph(edges).linearize.reverse .dropWhile(b => _currentState.contains(b)) } } /* [todo] Seq is inefficient here, but Array has ClassTag problems. Use something else? */ - val (s, l) = sorted.foldLeft((_currentState, Seq[B]())){ case ((state, out), in) => - val prereqs = in._prerequisites ++ - dependencyGraph.getEdges(in).toSeq.map(oToD) ++ - otherPrerequisites.getEdges(in).toSeq.map(oToD) - val preprocessing: Option[B] = { - if ((prereqs -- state).nonEmpty) { Some(this.copy(prereqs.toSeq, state.toSeq)) } - else { None } - } - /* "in" is added *after* invalidation because a transform my not invalidate itself! */ - ((state ++ prereqs).map(dToO).filterNot(in.invalidates).map(oToD) + in, out ++ preprocessing :+ in) + val (s, l) = sorted.foldLeft((_currentState, Seq[B]())) { + case ((state, out), in) => + val prereqs = in._prerequisites ++ + dependencyGraph.getEdges(in).toSeq.map(oToD) ++ + otherPrerequisites.getEdges(in).toSeq.map(oToD) + val preprocessing: Option[B] = { + if ((prereqs -- state).nonEmpty) { Some(this.copy(prereqs.toSeq, state.toSeq)) } + else { None } + } + /* "in" is added *after* invalidation because a transform my not invalidate itself! */ + ((state ++ prereqs).map(dToO).filterNot(in.invalidates).map(oToD) + in, out ++ preprocessing :+ in) } val postprocessing: Option[B] = { if ((_targets -- s).nonEmpty) { Some(this.copy(_targets.toSeq, s.toSeq)) } - else { None } + else { None } } l ++ postprocessing } @@ -252,20 +262,21 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends * applied while tracking the state of the underlying A. If the state ever disagrees with a prerequisite, then this * throws an exception. */ - flattenedTransformOrder - .map{ t => - val w = wrappers.foldLeft(t){ case (tx, wrapper) => wrapper(tx) } - wrapperToClass += (w -> t) - w - }.foldLeft((annotations, _currentState)){ case ((a, state), t) => - if (!t.prerequisites.toSet.subsetOf(state)) { - throw new DependencyManagerException( - s"""|Tried to execute '$t' for which run-time prerequisites were not satisfied: - | state: ${state.mkString("\n -", "\n -", "")} - | prerequisites: ${prerequisites.mkString("\n -", "\n -", "")}""".stripMargin) - } - (t.transform(a), ((state + wrapperToClass(t)).map(dToO).filterNot(t.invalidates).map(oToD))) - }._1 + flattenedTransformOrder.map { t => + val w = wrappers.foldLeft(t) { case (tx, wrapper) => wrapper(tx) } + wrapperToClass += (w -> t) + w + }.foldLeft((annotations, _currentState)) { + case ((a, state), t) => + if (!t.prerequisites.toSet.subsetOf(state)) { + throw new DependencyManagerException( + s"""|Tried to execute '$t' for which run-time prerequisites were not satisfied: + | state: ${state.mkString("\n -", "\n -", "")} + | prerequisites: ${prerequisites.mkString("\n -", "\n -", "")}""".stripMargin + ) + } + (t.transform(a), ((state + wrapperToClass(t)).map(dToO).filterNot(t.invalidates).map(oToD))) + }._1 } /** This colormap uses Colorbrewer's 4-class OrRd color scheme */ @@ -282,13 +293,13 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends def toGraphviz(digraph: DiGraph[B], attributes: String = "", tab: String = " "): Option[String] = { val edges = - digraph - .getEdgeMap - .collect{ case (v, edges) if edges.nonEmpty => (v -> edges) } - .map{ case (v, edges) => - s"""${transformName(v)} -> ${edges.map(e => transformName(e)).mkString("{ ", " ", " }")}""" } + digraph.getEdgeMap.collect { case (v, edges) if edges.nonEmpty => (v -> edges) }.map { + case (v, edges) => + s"""${transformName(v)} -> ${edges.map(e => transformName(e)).mkString("{ ", " ", " }")}""" + } - if (edges.isEmpty) { None } else { + if (edges.isEmpty) { None } + else { Some( s"""| { $attributes |${edges.mkString(tab, "\n" + tab, "")} @@ -298,16 +309,16 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends } val connections = - Seq( (prerequisiteGraph, "edge []"), - (optionalPrerequisiteOfGraph, """edge [style=bold color="#4292c6"]"""), - (invalidateGraph, """edge [minlen=2 style=dashed constraint=false color="#fb6a4a"]"""), - (optionalPrerequisitesGraph, """edge [style=dotted color="#a1d99b"]""") ) - .flatMap{ case (a, b) => toGraphviz(a, b) } + Seq( + (prerequisiteGraph, "edge []"), + (optionalPrerequisiteOfGraph, """edge [style=bold color="#4292c6"]"""), + (invalidateGraph, """edge [minlen=2 style=dashed constraint=false color="#fb6a4a"]"""), + (optionalPrerequisitesGraph, """edge [style=dotted color="#a1d99b"]""") + ).flatMap { case (a, b) => toGraphviz(a, b) } .mkString("\n") val nodes = - (prerequisiteGraph + optionalPrerequisiteOfGraph + invalidateGraph + otherPrerequisites) - .getVertices + (prerequisiteGraph + optionalPrerequisiteOfGraph + invalidateGraph + otherPrerequisites).getVertices .map(v => s"""${transformName(v)} [label="${v.name}"]""") s"""|digraph DependencyManager { @@ -322,9 +333,9 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends def transformOrderToGraphviz(colormap: Seq[String] = colormap): String = { def rotate[A](a: Seq[A]): Seq[A] = a match { - case Nil => Nil + case Nil => Nil case car :: cdr => cdr :+ car - case car => car + case car => car } val sorted = ArrayBuffer.empty[String] @@ -340,7 +351,7 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends |$tab labeljust=l |$tab node [fillcolor="${cm.head}"]""".stripMargin - val body = pm.transformOrder.map{ + val body = pm.transformOrder.map { case a: DependencyManager[A, B] => val (str, d) = rec(a, rotate(cm), tab + " ", offset + 1) offset = d @@ -369,9 +380,10 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends * @param size the number of nodes at the current level of the tree */ def customPrintHandling( - tab: String, + tab: String, charSet: CharSet, - size: Int): Option[PartialFunction[(B, Int), Seq[String]]] = None + size: Int + ): Option[PartialFunction[(B, Int), Seq[String]]] = None /** Helper utility when recursing during pretty printing * @param tab an indentation string to use for every line of output @@ -386,9 +398,9 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends val defaultHandling: PartialFunction[(B, Int), Seq[String]] = { case (a: DependencyManager[_, _], `last`) => Seq(s"$tab$l ${a.name}") ++ a.prettyPrintRec(s"""$tab${" " * c.size} """, charSet) - case (a: DependencyManager[_, _], _) => Seq(s"$tab$n ${a.name}") ++ a.prettyPrintRec(s"$tab$c ", charSet) - case (a, `last`) => Seq(s"$tab$l ${a.name}") - case (a, _) => Seq(s"$tab$n ${a.name}") + case (a: DependencyManager[_, _], _) => Seq(s"$tab$n ${a.name}") ++ a.prettyPrintRec(s"$tab$c ", charSet) + case (a, `last`) => Seq(s"$tab$l ${a.name}") + case (a, _) => Seq(s"$tab$n ${a.name}") } val handling = customPrintHandling(tab, charSet, transformOrder.size) match { @@ -396,8 +408,7 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends case None => defaultHandling } - transformOrder - .zipWithIndex + transformOrder.zipWithIndex .flatMap(handling) } @@ -406,8 +417,9 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends * @param charSet a collection of characters to use when printing */ def prettyPrint( - tab: String = "", - charSet: DependencyManagerUtils.CharSet = DependencyManagerUtils.PrettyCharSet): String = { + tab: String = "", + charSet: DependencyManagerUtils.CharSet = DependencyManagerUtils.PrettyCharSet + ): String = { (Seq(s"$tab$name") ++ prettyPrintRec(tab, charSet)).mkString("\n") @@ -422,9 +434,11 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends * @param targets the [[Phase]]s you want to run */ class PhaseManager( - val targets: Seq[PhaseManager.PhaseDependency], + val targets: Seq[PhaseManager.PhaseDependency], val currentState: Seq[PhaseManager.PhaseDependency] = Seq.empty, - val knownObjects: Set[Phase] = Set.empty) extends DependencyManager[AnnotationSeq, Phase] with Phase { + val knownObjects: Set[Phase] = Set.empty) + extends DependencyManager[AnnotationSeq, Phase] + with Phase { import PhaseManager.PhaseDependency protected def copy(a: Seq[PhaseDependency], b: Seq[PhaseDependency], c: ISet[Phase]) = new PhaseManager(a, b, c) @@ -444,6 +458,7 @@ object DependencyManagerUtils { * @see [[ASCIICharSet]] */ trait CharSet { + /** Used when printing the last node */ val lastNode: String @@ -456,15 +471,15 @@ object DependencyManagerUtils { /** Uses prettier characters, but possibly not supported by all fonts */ object PrettyCharSet extends CharSet { - val lastNode = "└──" - val notLastNode = "├──" + val lastNode = "└──" + val notLastNode = "├──" val continuation = "│ " } /** Basic ASCII output */ object ASCIICharSet extends CharSet { - val lastNode = "\\--" - val notLastNode = "|--" + val lastNode = "\\--" + val notLastNode = "|--" val continuation = "| " } diff --git a/src/main/scala/firrtl/options/ExitCodes.scala b/src/main/scala/firrtl/options/ExitCodes.scala index 0e91fdec..94e525de 100644 --- a/src/main/scala/firrtl/options/ExitCodes.scala +++ b/src/main/scala/firrtl/options/ExitCodes.scala @@ -6,7 +6,7 @@ package firrtl.options sealed trait ExitCode { val number: Int } /** [[ExitCode]] indicating success */ -object ExitSuccess extends ExitCode{ val number = 0 } +object ExitSuccess extends ExitCode { val number = 0 } /** An [[ExitCode]] indicative of failure. This must be non-zero and should not conflict with a reserved exit code. */ sealed trait ExitFailure extends ExitCode diff --git a/src/main/scala/firrtl/options/OptionParser.scala b/src/main/scala/firrtl/options/OptionParser.scala index 9360a961..e7ea68bf 100644 --- a/src/main/scala/firrtl/options/OptionParser.scala +++ b/src/main/scala/firrtl/options/OptionParser.scala @@ -9,7 +9,8 @@ import scopt.OptionParser case object OptionsHelpException extends Exception("Usage help invoked") /** OptionParser mixin that causes the OptionParser to not call exit (call `sys.exit`) if the `--help` option is - * passed */ + * passed + */ trait DoNotTerminateOnExit { this: OptionParser[_] => override def terminate(exitState: Either[String, Unit]): Unit = () } @@ -33,16 +34,18 @@ trait DuplicateHandling extends OptionParser[AnnotationSeq] { /** Message for found duplicate options */ def msg(x: String, y: String) = s"""Duplicate $x "$y" (did your custom Transform or OptionsManager add this?)""" - val longDups = options.map(_.name).groupBy(identity).collect{ case (k, v) if v.size > 1 && k != "" => k } - val shortDups = options.map(_.shortOpt).flatten.groupBy(identity).collect{ case (k, v) if v.size > 1 => k } - + val longDups = options.map(_.name).groupBy(identity).collect { case (k, v) if v.size > 1 && k != "" => k } + val shortDups = options.map(_.shortOpt).flatten.groupBy(identity).collect { case (k, v) if v.size > 1 => k } - if (longDups.nonEmpty) { + if (longDups.nonEmpty) { throw new OptionsException(msg("long option", longDups.map("--" + _).mkString(",")), new IllegalArgumentException) } if (shortDups.nonEmpty) { - throw new OptionsException(msg("short option", shortDups.map("-" + _).mkString(",")), new IllegalArgumentException) + throw new OptionsException( + msg("short option", shortDups.map("-" + _).mkString(",")), + new IllegalArgumentException + ) } super.parse(args, init) diff --git a/src/main/scala/firrtl/options/Phase.scala b/src/main/scala/firrtl/options/Phase.scala index 2a68251d..6a3f4a8c 100644 --- a/src/main/scala/firrtl/options/Phase.scala +++ b/src/main/scala/firrtl/options/Phase.scala @@ -12,7 +12,7 @@ import scala.reflect import scala.reflect.ClassTag object Dependency { - def apply[A <: DependencyAPI[_] : ClassTag]: Dependency[A] = { + def apply[A <: DependencyAPI[_]: ClassTag]: Dependency[A] = { val clazz = reflect.classTag[A].runtimeClass Dependency(Left(clazz.asInstanceOf[Class[A]])) } @@ -40,26 +40,30 @@ object Dependency { case class Dependency[+A <: DependencyAPI[_]](id: Either[Class[_ <: A], A with Singleton]) { def getObject(): A = id match { - case Left(c) => safeConstruct(c) + case Left(c) => safeConstruct(c) case Right(o) => o } def getSimpleName: String = id match { - case Left(c) => c.getSimpleName + case Left(c) => c.getSimpleName case Right(o) => o.getClass.getSimpleName } def getName: String = id match { - case Left(c) => c.getName + case Left(c) => c.getName case Right(o) => o.getClass.getName } /** Wrap an [[IllegalAccessException]] due to attempted object construction in a [[DependencyManagerException]] */ - private def safeConstruct[A](a: Class[_ <: A]): A = try { a.newInstance } catch { - case e: IllegalAccessException => throw new DependencyManagerException( - s"Failed to construct '$a'! (Did you try to construct an object?)", e) - case e: InstantiationException => throw new DependencyManagerException( - s"Failed to construct '$a'! (Did you try to construct an inner class or a class with parameters?)", e) + private def safeConstruct[A](a: Class[_ <: A]): A = try { a.newInstance } + catch { + case e: IllegalAccessException => + throw new DependencyManagerException(s"Failed to construct '$a'! (Did you try to construct an object?)", e) + case e: InstantiationException => + throw new DependencyManagerException( + s"Failed to construct '$a'! (Did you try to construct an inner class or a class with parameters?)", + e + ) } } @@ -124,7 +128,7 @@ trait DependencyAPI[A <: DependencyAPI[A]] { this: TransformLike[_] => /** All transform that must run before this transform * $seqNote */ - def prerequisites: Seq[Dependency[A]] = Seq.empty + def prerequisites: Seq[Dependency[A]] = Seq.empty private[options] lazy val _prerequisites: LinkedHashSet[Dependency[A]] = new LinkedHashSet() ++ prerequisites /** All transforms that, if a prerequisite of *another* transform, will run before this transform. @@ -184,8 +188,10 @@ trait DependencyAPI[A <: DependencyAPI[A]] { this: TransformLike[_] => /** A trait indicating that no invalidations occur, i.e., all previous transforms are preserved * @tparam A some [[TransformLike]] */ -@deprecated("Use an explicit `override def invalidates` returning false. This will be removed in FIRRTL 1.5.", - "FIRRTL 1.4") +@deprecated( + "Use an explicit `override def invalidates` returning false. This will be removed in FIRRTL 1.5.", + "FIRRTL 1.4" +) trait PreservesAll[A <: DependencyAPI[A]] { this: DependencyAPI[A] => override final def invalidates(a: A): Boolean = false diff --git a/src/main/scala/firrtl/options/Registration.scala b/src/main/scala/firrtl/options/Registration.scala index c832ec7c..55772c79 100644 --- a/src/main/scala/firrtl/options/Registration.scala +++ b/src/main/scala/firrtl/options/Registration.scala @@ -14,26 +14,26 @@ import scopt.{OptionDef, OptionParser, Read} * @param shortOption an optional single-dash option * @param helpValueName a string to show as a placeholder argument in help text */ -final class ShellOption[A: Read] ( - val longOption: String, +final class ShellOption[A: Read]( + val longOption: String, val toAnnotationSeq: A => AnnotationSeq, - val helpText: String, - val shortOption: Option[String] = None, - val helpValueName: Option[String] = None -) { + val helpText: String, + val shortOption: Option[String] = None, + val helpValueName: Option[String] = None) { /** Add this specific shell (command line) option to an option parser * @param p an option parser */ final def addOption(p: OptionParser[AnnotationSeq]): Unit = { val f = Seq( - (p: OptionDef[A, AnnotationSeq]) => p.action( (x, c) => toAnnotationSeq(x).reverse ++ c ), + (p: OptionDef[A, AnnotationSeq]) => p.action((x, c) => toAnnotationSeq(x).reverse ++ c), (p: OptionDef[A, AnnotationSeq]) => p.text(helpText), - (p: OptionDef[A, AnnotationSeq]) => p.unbounded()) ++ - shortOption.map( a => (p: OptionDef[A, AnnotationSeq]) => p.abbr(a) ) ++ - helpValueName.map( a => (p: OptionDef[A, AnnotationSeq]) => p.valueName(a) ) + (p: OptionDef[A, AnnotationSeq]) => p.unbounded() + ) ++ + shortOption.map(a => (p: OptionDef[A, AnnotationSeq]) => p.abbr(a)) ++ + helpValueName.map(a => (p: OptionDef[A, AnnotationSeq]) => p.valueName(a)) - f.foldLeft(p.opt[A](longOption))( (a, b) => b(a) ) + f.foldLeft(p.opt[A](longOption))((a, b) => b(a)) } } @@ -55,13 +55,15 @@ trait HasShellOptions { /** A [[Transform]] that includes an option that should be exposed at the top level. * * @note To complete registration, include an entry in - * src/main/resources/META-INF/services/firrtl.options.RegisteredTransform */ + * src/main/resources/META-INF/services/firrtl.options.RegisteredTransform + */ trait RegisteredTransform extends HasShellOptions { this: Transform => } /** A class that includes options that should be exposed as a group at the top level. * * @note To complete registration, include an entry in - * src/main/resources/META-INF/services/firrtl.options.RegisteredLibrary */ + * src/main/resources/META-INF/services/firrtl.options.RegisteredLibrary + */ trait RegisteredLibrary extends HasShellOptions { /** The name of this library. diff --git a/src/main/scala/firrtl/options/Shell.scala b/src/main/scala/firrtl/options/Shell.scala index 88301d30..b0ead81f 100644 --- a/src/main/scala/firrtl/options/Shell.scala +++ b/src/main/scala/firrtl/options/Shell.scala @@ -4,7 +4,7 @@ package firrtl.options import firrtl.AnnotationSeq -import logger.{LogLevelAnnotation, ClassLogLevelAnnotation, LogFileAnnotation, LogClassNamesAnnotation} +import logger.{ClassLogLevelAnnotation, LogClassNamesAnnotation, LogFileAnnotation, LogLevelAnnotation} import scopt.OptionParser @@ -62,28 +62,25 @@ class Shell(val applicationName: String) { parser.note("Shell Options") ProgramArgsAnnotation.addOptions(parser) - Seq( TargetDirAnnotation, - InputAnnotationFileAnnotation, - OutputAnnotationFileAnnotation ) + Seq(TargetDirAnnotation, InputAnnotationFileAnnotation, OutputAnnotationFileAnnotation) .foreach(_.addOptions(parser)) - parser.opt[Unit]("show-registrations") - .action{ (_, c) => + parser + .opt[Unit]("show-registrations") + .action { (_, c) => val rtString = registeredTransforms.map(r => s"\n - ${r.getClass.getName}").mkString val rlString = registeredLibraries.map(l => s"\n - ${l.getClass.getName}").mkString println(s"""|The following FIRRTL transforms registered command line options:$rtString |The following libraries registered command line options:$rlString""".stripMargin) - c } + c + } .unbounded() .text("print discovered registered libraries and transforms") parser.help("help").text("prints this usage text") parser.note("Logging Options") - Seq( LogLevelAnnotation, - ClassLogLevelAnnotation, - LogFileAnnotation, - LogClassNamesAnnotation ) + Seq(LogLevelAnnotation, ClassLogLevelAnnotation, LogFileAnnotation, LogClassNamesAnnotation) .foreach(_.addOptions(parser)) } diff --git a/src/main/scala/firrtl/options/Stage.scala b/src/main/scala/firrtl/options/Stage.scala index aa4809dd..77c8133b 100644 --- a/src/main/scala/firrtl/options/Stage.scala +++ b/src/main/scala/firrtl/options/Stage.scala @@ -37,10 +37,12 @@ abstract class Stage extends Phase { .foldLeft(annotations)((a, p) => p.transform(a)) Logger.makeScope(annotationsx) { - Seq( new phases.AddDefaults, - new phases.Checks, - new Phase { def transform(a: AnnotationSeq) = run(a) }, - new phases.WriteOutputAnnotations ) + Seq( + new phases.AddDefaults, + new phases.Checks, + new Phase { def transform(a: AnnotationSeq) = run(a) }, + new phases.WriteOutputAnnotations + ) .map(phases.DeletedWrapper(_)) .foldLeft(annotationsx)((a, p) => p.transform(a)) } @@ -61,6 +63,7 @@ abstract class Stage extends Phase { * @param stage the stage to run */ class StageMain(val stage: Stage) { + /** The main function that serves as this stage's command line interface. * @param args command line arguments */ diff --git a/src/main/scala/firrtl/options/StageAnnotations.scala b/src/main/scala/firrtl/options/StageAnnotations.scala index 32f8ff59..84168975 100644 --- a/src/main/scala/firrtl/options/StageAnnotations.scala +++ b/src/main/scala/firrtl/options/StageAnnotations.scala @@ -89,7 +89,9 @@ object TargetDirAnnotation extends HasShellOptions { toAnnotationSeq = (a: String) => Seq(TargetDirAnnotation(a)), helpText = "Work directory (default: '.')", shortOption = Some("td"), - helpValueName = Some("<directory>") ) ) + helpValueName = Some("<directory>") + ) + ) } @@ -101,10 +103,11 @@ case class ProgramArgsAnnotation(arg: String) extends NoTargetAnnotation with St object ProgramArgsAnnotation { - def addOptions(p: OptionParser[AnnotationSeq]): Unit = p.arg[String]("<arg>...") + def addOptions(p: OptionParser[AnnotationSeq]): Unit = p + .arg[String]("<arg>...") .unbounded() .optional() - .action( (x, c) => ProgramArgsAnnotation(x) +: c ) + .action((x, c) => ProgramArgsAnnotation(x) +: c) .text("optional unbounded args") } @@ -123,7 +126,9 @@ object InputAnnotationFileAnnotation extends HasShellOptions { toAnnotationSeq = (a: String) => Seq(InputAnnotationFileAnnotation(a)), helpText = "An input annotation file", shortOption = Some("faf"), - helpValueName = Some("<file>") ) ) + helpValueName = Some("<file>") + ) + ) } @@ -141,7 +146,9 @@ object OutputAnnotationFileAnnotation extends HasShellOptions { toAnnotationSeq = (a: String) => Seq(OutputAnnotationFileAnnotation(a)), helpText = "An output annotation file", shortOption = Some("foaf"), - helpValueName = Some("<file>") ) ) + helpValueName = Some("<file>") + ) + ) } @@ -156,6 +163,8 @@ case object WriteDeletedAnnotation extends NoTargetAnnotation with StageOption w new ShellOption[Unit]( longOption = "write-deleted", toAnnotationSeq = (_: Unit) => Seq(WriteDeletedAnnotation), - helpText = "Include deleted annotations in the output annotation file" ) ) + helpText = "Include deleted annotations in the output annotation file" + ) + ) } diff --git a/src/main/scala/firrtl/options/StageOptions.scala b/src/main/scala/firrtl/options/StageOptions.scala index f60a991c..6b9190a7 100644 --- a/src/main/scala/firrtl/options/StageOptions.scala +++ b/src/main/scala/firrtl/options/StageOptions.scala @@ -10,26 +10,28 @@ import java.io.File * @param programArgs explicit program arguments * @param outputAnnotationFileName an output annotation filename */ -class StageOptions private [firrtl] ( - val targetDir: String = TargetDirAnnotation().directory, - val annotationFilesIn: Seq[String] = Seq.empty, +class StageOptions private[firrtl] ( + val targetDir: String = TargetDirAnnotation().directory, + val annotationFilesIn: Seq[String] = Seq.empty, val annotationFileOut: Option[String] = None, - val programArgs: Seq[String] = Seq.empty, - val writeDeleted: Boolean = false ) { + val programArgs: Seq[String] = Seq.empty, + val writeDeleted: Boolean = false) { - private [options] def copy( - targetDir: String = targetDir, - annotationFilesIn: Seq[String] = annotationFilesIn, + private[options] def copy( + targetDir: String = targetDir, + annotationFilesIn: Seq[String] = annotationFilesIn, annotationFileOut: Option[String] = annotationFileOut, - programArgs: Seq[String] = programArgs, - writeDeleted: Boolean = writeDeleted ): StageOptions = { + programArgs: Seq[String] = programArgs, + writeDeleted: Boolean = writeDeleted + ): StageOptions = { new StageOptions( targetDir = targetDir, annotationFilesIn = annotationFilesIn, annotationFileOut = annotationFileOut, programArgs = programArgs, - writeDeleted = writeDeleted ) + writeDeleted = writeDeleted + ) } @@ -62,9 +64,9 @@ class StageOptions private [firrtl] ( }.toPath.normalize.toFile file.getParentFile match { - case null => + case null => case parent if (!parent.exists) => parent.mkdirs() - case _ => + case _ => } file.toString diff --git a/src/main/scala/firrtl/options/StageUtils.scala b/src/main/scala/firrtl/options/StageUtils.scala index 3983f653..2411da6e 100644 --- a/src/main/scala/firrtl/options/StageUtils.scala +++ b/src/main/scala/firrtl/options/StageUtils.scala @@ -2,16 +2,16 @@ package firrtl.options - /** Utilities related to working with a [[Stage]] */ object StageUtils { + /** Print a warning message (in yellow) * @param message error message */ def dramaticWarning(message: String): Unit = { - println(Console.YELLOW + "-"*78) + println(Console.YELLOW + "-" * 78) println(s"Warning: $message") - println("-"*78 + Console.RESET) + println("-" * 78 + Console.RESET) } /** Print an error message (in red) @@ -19,9 +19,9 @@ object StageUtils { * @note This does not stop the Driver. */ def dramaticError(message: String): Unit = { - println(Console.RED + "-"*78) + println(Console.RED + "-" * 78) println(s"Error: $message") - println("-"*78 + Console.RESET) + println("-" * 78 + Console.RESET) } /** Generate a message suggesting that the user look at the usage text. diff --git a/src/main/scala/firrtl/options/package.scala b/src/main/scala/firrtl/options/package.scala index 8cf2875b..f87fb8a8 100644 --- a/src/main/scala/firrtl/options/package.scala +++ b/src/main/scala/firrtl/options/package.scala @@ -5,17 +5,16 @@ package firrtl package object options { implicit object StageOptionsView extends OptionsView[StageOptions] { - def view(options: AnnotationSeq): StageOptions = options - .collect { case a: StageOption => a } + def view(options: AnnotationSeq): StageOptions = options.collect { case a: StageOption => a } .foldLeft(new StageOptions())((c, x) => x match { case TargetDirAnnotation(a) => c.copy(targetDir = a) /* Insert input files at the head of the Seq for speed and because order shouldn't matter */ - case InputAnnotationFileAnnotation(a) => c.copy(annotationFilesIn = a +: c.annotationFilesIn) + case InputAnnotationFileAnnotation(a) => c.copy(annotationFilesIn = a +: c.annotationFilesIn) case OutputAnnotationFileAnnotation(a) => c.copy(annotationFileOut = Some(a)) /* Do NOT reorder program args. The order may matter. */ case ProgramArgsAnnotation(a) => c.copy(programArgs = c.programArgs :+ a) - case WriteDeletedAnnotation => c.copy(writeDeleted = true) + case WriteDeletedAnnotation => c.copy(writeDeleted = true) } ) } diff --git a/src/main/scala/firrtl/options/phases/AddDefaults.scala b/src/main/scala/firrtl/options/phases/AddDefaults.scala index ab342b1e..0ef1832a 100644 --- a/src/main/scala/firrtl/options/phases/AddDefaults.scala +++ b/src/main/scala/firrtl/options/phases/AddDefaults.scala @@ -19,7 +19,7 @@ class AddDefaults extends Phase { override def invalidates(a: Phase) = false def transform(annotations: AnnotationSeq): AnnotationSeq = { - val td = annotations.collectFirst{ case a: TargetDirAnnotation => a}.isEmpty + val td = annotations.collectFirst { case a: TargetDirAnnotation => a }.isEmpty (if (td) Seq(TargetDirAnnotation()) else Seq()) ++ annotations diff --git a/src/main/scala/firrtl/options/phases/Checks.scala b/src/main/scala/firrtl/options/phases/Checks.scala index 9e671aa5..024c13a9 100644 --- a/src/main/scala/firrtl/options/phases/Checks.scala +++ b/src/main/scala/firrtl/options/phases/Checks.scala @@ -25,24 +25,27 @@ class Checks extends Phase { val td, outA = collection.mutable.ListBuffer[Annotation]() annotations.foreach { - case a: TargetDirAnnotation => td += a + case a: TargetDirAnnotation => td += a case a: OutputAnnotationFileAnnotation => outA += a case _ => } if (td.size != 1) { - val d = td.map{ case TargetDirAnnotation(x) => x } + val d = td.map { case TargetDirAnnotation(x) => x } throw new OptionsException( s"""|Exactly one target directory must be specified, but found `${d.mkString(", ")}` specified via: | - explicit target directory: -td, --target-dir, TargetDirAnnotation - | - fallback default value""".stripMargin )} + | - fallback default value""".stripMargin + ) + } if (outA.size > 1) { - val x = outA.map{ case OutputAnnotationFileAnnotation(x) => x } + val x = outA.map { case OutputAnnotationFileAnnotation(x) => x } throw new OptionsException( s"""|At most one output annotation file can be specified, but found '${x.mkString(", ")}' specified via: - | - an option or annotation: -foaf, --output-annotation-file, OutputAnnotationFileAnnotation""" - .stripMargin )} + | - an option or annotation: -foaf, --output-annotation-file, OutputAnnotationFileAnnotation""".stripMargin + ) + } annotations } diff --git a/src/main/scala/firrtl/options/phases/GetIncludes.scala b/src/main/scala/firrtl/options/phases/GetIncludes.scala index b9320585..dd08e09b 100644 --- a/src/main/scala/firrtl/options/phases/GetIncludes.scala +++ b/src/main/scala/firrtl/options/phases/GetIncludes.scala @@ -10,7 +10,7 @@ import firrtl.FileUtils import java.io.File import scala.collection.mutable -import scala.util.{Try, Failure} +import scala.util.{Failure, Try} /** Recursively expand all [[InputAnnotationFileAnnotation]]s in an [[AnnotationSeq]] */ class GetIncludes extends Phase { @@ -37,8 +37,7 @@ class GetIncludes extends Phase { * @param annos a sequence of annotations * @return the original annotation sequence with any discovered annotations added */ - private def getIncludes(includeGuard: mutable.Set[String] = mutable.Set()) - (annos: AnnotationSeq): AnnotationSeq = { + private def getIncludes(includeGuard: mutable.Set[String] = mutable.Set())(annos: AnnotationSeq): AnnotationSeq = { annos.flatMap { case a @ InputAnnotationFileAnnotation(value) => if (includeGuard.contains(value)) { diff --git a/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala b/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala index 7ee385b1..53306c8a 100644 --- a/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala +++ b/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala @@ -16,9 +16,7 @@ import scala.collection.mutable class WriteOutputAnnotations extends Phase { override def prerequisites = - Seq( Dependency[GetIncludes], - Dependency[AddDefaults], - Dependency[Checks] ) + Seq(Dependency[GetIncludes], Dependency[AddDefaults], Dependency[Checks]) override def optionalPrerequisiteOf = Seq.empty @@ -29,8 +27,10 @@ class WriteOutputAnnotations extends Phase { val sopts = Viewer[StageOptions].view(annotations) val filesWritten = mutable.HashMap.empty[String, Annotation] val serializable: AnnotationSeq = annotations.toSeq.flatMap { - case _: Unserializable => None - case a: DeletedAnnotation => if (sopts.writeDeleted) { Some(a) } else { None } + case _: Unserializable => None + case a: DeletedAnnotation => + if (sopts.writeDeleted) { Some(a) } + else { None } case a: CustomFileEmission => val filename = a.filename(annotations) val canonical = filename.getCanonicalPath() @@ -38,7 +38,7 @@ class WriteOutputAnnotations extends Phase { filesWritten.get(canonical) match { case None => val w = new BufferedWriter(new FileWriter(filename)) - a.getBytes.foreach( w.write(_) ) + a.getBytes.foreach(w.write(_)) w.close() filesWritten(canonical) = a case Some(first) => diff --git a/src/main/scala/firrtl/passes/CInferMDir.scala b/src/main/scala/firrtl/passes/CInferMDir.scala index b4819751..1fe8d57c 100644 --- a/src/main/scala/firrtl/passes/CInferMDir.scala +++ b/src/main/scala/firrtl/passes/CInferMDir.scala @@ -18,60 +18,61 @@ object CInferMDir extends Pass { def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match { case e: Reference => - mports get e.name match { + mports.get(e.name) match { case None => - case Some(p) => mports(e.name) = (p, dir) match { - case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MInfer, MWrite) => MWrite - case (MInfer, MRead) => MRead - case (MInfer, MReadWrite) => MReadWrite - case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MWrite, MWrite) => MWrite - case (MWrite, MRead) => MReadWrite - case (MWrite, MReadWrite) => MReadWrite - case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MRead, MWrite) => MReadWrite - case (MRead, MRead) => MRead - case (MRead, MReadWrite) => MReadWrite - case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") - case (MReadWrite, MWrite) => MReadWrite - case (MReadWrite, MRead) => MReadWrite - case (MReadWrite, MReadWrite) => MReadWrite - } + case Some(p) => + mports(e.name) = (p, dir) match { + case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MInfer, MWrite) => MWrite + case (MInfer, MRead) => MRead + case (MInfer, MReadWrite) => MReadWrite + case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MWrite, MWrite) => MWrite + case (MWrite, MRead) => MReadWrite + case (MWrite, MReadWrite) => MReadWrite + case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MRead, MWrite) => MReadWrite + case (MRead, MRead) => MRead + case (MRead, MReadWrite) => MReadWrite + case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir") + case (MReadWrite, MWrite) => MReadWrite + case (MReadWrite, MRead) => MReadWrite + case (MReadWrite, MReadWrite) => MReadWrite + } } e case e: SubAccess => infer_mdir_e(mports, dir)(e.expr) infer_mdir_e(mports, MRead)(e.index) // index can't be a write port e - case e => e map infer_mdir_e(mports, dir) + case e => e.map(infer_mdir_e(mports, dir)) } def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match { case sx: CDefMPort => - mports(sx.name) = sx.direction - sx map infer_mdir_e(mports, MRead) + mports(sx.name) = sx.direction + sx.map(infer_mdir_e(mports, MRead)) case sx: Connect => - infer_mdir_e(mports, MRead)(sx.expr) - infer_mdir_e(mports, MWrite)(sx.loc) - sx + infer_mdir_e(mports, MRead)(sx.expr) + infer_mdir_e(mports, MWrite)(sx.loc) + sx case sx: PartialConnect => - infer_mdir_e(mports, MRead)(sx.expr) - infer_mdir_e(mports, MWrite)(sx.loc) - sx - case sx => sx map infer_mdir_s(mports) map infer_mdir_e(mports, MRead) + infer_mdir_e(mports, MRead)(sx.expr) + infer_mdir_e(mports, MWrite)(sx.loc) + sx + case sx => sx.map(infer_mdir_s(mports)).map(infer_mdir_e(mports, MRead)) } def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match { - case sx: CDefMPort => sx copy (direction = mports(sx.name)) - case sx => sx map set_mdir_s(mports) + case sx: CDefMPort => sx.copy(direction = mports(sx.name)) + case sx => sx.map(set_mdir_s(mports)) } def infer_mdir(m: DefModule): DefModule = { val mports = new MPortDirMap - m map infer_mdir_s(mports) map set_mdir_s(mports) + m.map(infer_mdir_s(mports)).map(set_mdir_s(mports)) } def run(c: Circuit): Circuit = - c copy (modules = c.modules map infer_mdir) + c.copy(modules = c.modules.map(infer_mdir)) } diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 9903f445..97d614c1 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -12,9 +12,7 @@ object CheckChirrtl extends Pass with CheckHighFormLike { override def prerequisites = Dependency[CheckScalaVersion] :: Nil override val optionalPrerequisiteOf = firrtl.stage.Forms.ChirrtlForm ++ - Seq( Dependency(CInferTypes), - Dependency(CInferMDir), - Dependency(RemoveCHIRRTL) ) + Seq(Dependency(CInferTypes), Dependency(CInferMDir), Dependency(RemoveCHIRRTL)) override def invalidates(a: Transform) = false diff --git a/src/main/scala/firrtl/passes/CheckFlows.scala b/src/main/scala/firrtl/passes/CheckFlows.scala index 3a9cc212..bc455a20 100644 --- a/src/main/scala/firrtl/passes/CheckFlows.scala +++ b/src/main/scala/firrtl/passes/CheckFlows.scala @@ -13,79 +13,87 @@ object CheckFlows extends Pass { override def prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR override def optionalPrerequisiteOf = - Seq( Dependency[passes.InferBinaryPoints], - Dependency[passes.TrimIntervals], - Dependency[passes.InferWidths], - Dependency[transforms.InferResets] ) + Seq( + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals], + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] + ) override def invalidates(a: Transform) = false type FlowMap = collection.mutable.HashMap[String, Flow] implicit def toStr(g: Flow): String = g match { - case SourceFlow => "source" - case SinkFlow => "sink" + case SourceFlow => "source" + case SinkFlow => "sink" case UnknownFlow => "unknown" - case DuplexFlow => "duplex" + case DuplexFlow => "duplex" } - class WrongFlow(info:Info, mname: String, expr: String, wrong: Flow, right: Flow) extends PassException( - s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right.") + class WrongFlow(info: Info, mname: String, expr: String, wrong: Flow, right: Flow) + extends PassException( + s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right." + ) - def run (c:Circuit): Circuit = { + def run(c: Circuit): Circuit = { val errors = new Errors() def get_flow(e: Expression, flows: FlowMap): Flow = e match { - case (e: WRef) => flows(e.name) + case (e: WRef) => flows(e.name) case (e: WSubIndex) => get_flow(e.expr, flows) case (e: WSubAccess) => get_flow(e.expr, flows) - case (e: WSubField) => e.expr.tpe match {case t: BundleType => - val f = (t.fields find (_.name == e.name)).get - times(get_flow(e.expr, flows), f.flip) - } + case (e: WSubField) => + e.expr.tpe match { + case t: BundleType => + val f = (t.fields.find(_.name == e.name)).get + times(get_flow(e.expr, flows), f.flip) + } case _ => SourceFlow } def flip_q(t: Type): Boolean = { def flip_rec(t: Type, f: Orientation): Boolean = t match { - case tx:BundleType => tx.fields exists ( - field => flip_rec(field.tpe, times(f, field.flip)) - ) + case tx: BundleType => tx.fields.exists(field => flip_rec(field.tpe, times(f, field.flip))) case tx: VectorType => flip_rec(tx.tpe, f) case tx => f == Flip } flip_rec(t, Default) } - def check_flow(info:Info, mname: String, flows: FlowMap, desired: Flow)(e:Expression): Unit = { - val flow = get_flow(e,flows) + def check_flow(info: Info, mname: String, flows: FlowMap, desired: Flow)(e: Expression): Unit = { + val flow = get_flow(e, flows) (flow, desired) match { case (SourceFlow, SinkFlow) => errors.append(new WrongFlow(info, mname, e.serialize, desired, flow)) - case (SinkFlow, SourceFlow) => kind(e) match { - case PortKind | InstanceKind if !flip_q(e.tpe) => // OK! - case _ => - errors.append(new WrongFlow(info, mname, e.serialize, desired, flow)) - } + case (SinkFlow, SourceFlow) => + kind(e) match { + case PortKind | InstanceKind if !flip_q(e.tpe) => // OK! + case _ => + errors.append(new WrongFlow(info, mname, e.serialize, desired, flow)) + } case _ => } - } + } - def check_flows_e (info:Info, mname: String, flows: FlowMap)(e:Expression): Unit = { + def check_flows_e(info: Info, mname: String, flows: FlowMap)(e: Expression): Unit = { e match { - case e: Mux => e foreach check_flow(info, mname, flows, SourceFlow) - case e: DoPrim => e.args foreach check_flow(info, mname, flows, SourceFlow) + case e: Mux => e.foreach(check_flow(info, mname, flows, SourceFlow)) + case e: DoPrim => e.args.foreach(check_flow(info, mname, flows, SourceFlow)) case _ => } - e foreach check_flows_e(info, mname, flows) + e.foreach(check_flows_e(info, mname, flows)) } def check_flows_s(minfo: Info, mname: String, flows: FlowMap)(s: Statement): Unit = { - val info = get_info(s) match { case NoInfo => minfo case x => x } + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } s match { - case (s: DefWire) => flows(s.name) = DuplexFlow + case (s: DefWire) => flows(s.name) = DuplexFlow case (s: DefRegister) => flows(s.name) = DuplexFlow - case (s: DefMemory) => flows(s.name) = SourceFlow + case (s: DefMemory) => flows(s.name) = SourceFlow case (s: WDefInstance) => flows(s.name) = SourceFlow case (s: DefNode) => check_flow(info, mname, flows, SourceFlow)(s.value) @@ -94,7 +102,7 @@ object CheckFlows extends Pass { check_flow(info, mname, flows, SinkFlow)(s.loc) check_flow(info, mname, flows, SourceFlow)(s.expr) case (s: Print) => - s.args foreach check_flow(info, mname, flows, SourceFlow) + s.args.foreach(check_flow(info, mname, flows, SourceFlow)) check_flow(info, mname, flows, SourceFlow)(s.en) check_flow(info, mname, flows, SourceFlow)(s.clk) case (s: PartialConnect) => @@ -111,14 +119,14 @@ object CheckFlows extends Pass { check_flow(info, mname, flows, SourceFlow)(s.en) case _ => } - s foreach check_flows_e(info, mname, flows) - s foreach check_flows_s(minfo, mname, flows) + s.foreach(check_flows_e(info, mname, flows)) + s.foreach(check_flows_s(minfo, mname, flows)) } for (m <- c.modules) { val flows = new FlowMap - flows ++= (m.ports map (p => p.name -> to_flow(p.direction))) - m foreach check_flows_s(m.info, m.name, flows) + flows ++= (m.ports.map(p => p.name -> to_flow(p.direction))) + m.foreach(check_flows_s(m.info, m.name, flows)) } errors.trigger() c diff --git a/src/main/scala/firrtl/passes/CheckHighForm.scala b/src/main/scala/firrtl/passes/CheckHighForm.scala index 2f706d35..559c9060 100644 --- a/src/main/scala/firrtl/passes/CheckHighForm.scala +++ b/src/main/scala/firrtl/passes/CheckHighForm.scala @@ -27,66 +27,71 @@ trait CheckHighFormLike { this: Pass => scopes.find(_.contains(port.mem)).getOrElse(scopes.head) += port.name } def legalDecl(name: String): Boolean = !moduleNS.contains(name) - def legalRef(name: String): Boolean = scopes.exists(_.contains(name)) + def legalRef(name: String): Boolean = scopes.exists(_.contains(name)) def childScope(): ScopeView = new ScopeView(moduleNS, new NameSet +: scopes) } // Custom Exceptions - class NotUniqueException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Reference $name does not have a unique name.") - class InvalidLOCException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort.") - class NegUIntException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] UIntLiteral cannot be negative.") - class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Reference $name is not declared.") - class PoisonWithFlipException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Poison $name cannot be a bundle type with flips.") - class MemWithFlipException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Memory $name cannot be a bundle type with flips.") - class IllegalMemLatencyException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency.") - class RegWithFlipException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Register $name cannot be a bundle type with flips.") - class InvalidAccessException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Invalid access to non-reference.") - class ModuleNameNotUniqueException(info: Info, mname: String) extends PassException( - s"$info: Repeat definition of module $mname") - class DefnameConflictException(info: Info, mname: String, defname: String) extends PassException( - s"$info: defname $defname of extmodule $mname conflicts with an existing module") - class DefnameDifferentPortsException(info: Info, mname: String, defname: String) extends PassException( - s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname""") - class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException( - s"$info: Module $name is not defined.") - class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) extends PassException( - s"$info: [module $mname] Primop $op requires $n expression arguments.") - class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) extends PassException( - s"$info: [module $mname] Primop $op requires $n integer arguments.") - class NegWidthException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Width cannot be negative.") - class NegVecSizeException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Vector type size cannot be negative.") - class NegMemSizeException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Memory size cannot be negative or zero.") - class BadPrintfException(info: Info, mname: String, x: Char) extends PassException( - s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"") - class BadPrintfTrailingException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Bad printf format: trailing " + "\"%\"") - class BadPrintfIncorrectNumException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Bad printf format: incorrect number of arguments") - class InstanceLoop(info: Info, mname: String, loop: String) extends PassException( - s"$info: [module $mname] Has instance loop $loop") - class NoTopModuleException(info: Info, name: String) extends PassException( - s"$info: A single module must be named $name.") - class NegArgException(info: Info, mname: String, op: String, value: BigInt) extends PassException( - s"$info: [module $mname] Primop $op argument $value < 0.") - class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt) extends PassException( - s"$info: [module $mname] Primop $op lsb $lsb > $msb.") - class ResetInputException(info: Info, mname: String, expr: Expression) extends PassException( - s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}") - class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) extends PassException( - s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}") - + class NotUniqueException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Reference $name does not have a unique name.") + class InvalidLOCException(info: Info, mname: String) + extends PassException( + s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort." + ) + class NegUIntException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] UIntLiteral cannot be negative.") + class UndeclaredReferenceException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Reference $name is not declared.") + class PoisonWithFlipException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Poison $name cannot be a bundle type with flips.") + class MemWithFlipException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Memory $name cannot be a bundle type with flips.") + class IllegalMemLatencyException(info: Info, mname: String, name: String) + extends PassException( + s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency." + ) + class RegWithFlipException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Register $name cannot be a bundle type with flips.") + class InvalidAccessException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Invalid access to non-reference.") + class ModuleNameNotUniqueException(info: Info, mname: String) + extends PassException(s"$info: Repeat definition of module $mname") + class DefnameConflictException(info: Info, mname: String, defname: String) + extends PassException(s"$info: defname $defname of extmodule $mname conflicts with an existing module") + class DefnameDifferentPortsException(info: Info, mname: String, defname: String) + extends PassException( + s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname""" + ) + class ModuleNotDefinedException(info: Info, mname: String, name: String) + extends PassException(s"$info: Module $name is not defined.") + class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) + extends PassException(s"$info: [module $mname] Primop $op requires $n expression arguments.") + class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) + extends PassException(s"$info: [module $mname] Primop $op requires $n integer arguments.") + class NegWidthException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Width cannot be negative.") + class NegVecSizeException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Vector type size cannot be negative.") + class NegMemSizeException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Memory size cannot be negative or zero.") + class BadPrintfException(info: Info, mname: String, x: Char) + extends PassException(s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"") + class BadPrintfTrailingException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Bad printf format: trailing " + "\"%\"") + class BadPrintfIncorrectNumException(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Bad printf format: incorrect number of arguments") + class InstanceLoop(info: Info, mname: String, loop: String) + extends PassException(s"$info: [module $mname] Has instance loop $loop") + class NoTopModuleException(info: Info, name: String) + extends PassException(s"$info: A single module must be named $name.") + class NegArgException(info: Info, mname: String, op: String, value: BigInt) + extends PassException(s"$info: [module $mname] Primop $op argument $value < 0.") + class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt) + extends PassException(s"$info: [module $mname] Primop $op lsb $lsb > $msb.") + class ResetInputException(info: Info, mname: String, expr: Expression) + extends PassException(s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}") + class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) + extends PassException(s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}") // Is Chirrtl allowed for this check? If not, return an error def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] @@ -94,12 +99,12 @@ trait CheckHighFormLike { this: Pass => def run(c: Circuit): Circuit = { val errors = new Errors() val moduleGraph = new ModuleGraph - val moduleNames = (c.modules map (_.name)).toSet + val moduleNames = (c.modules.map(_.name)).toSet val intModuleNames = c.modules.view.collect({ case m: Module => m.name }).toSet - c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach { - m => errors.append(new ModuleNameNotUniqueException(m.info, m.name)) + c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach { m => + errors.append(new ModuleNameNotUniqueException(m.info, m.name)) } /** Strip all widths from types */ @@ -110,16 +115,18 @@ trait CheckHighFormLike { this: Pass => val extmoduleCollidingPorts = c.modules.collect { case a: ExtModule => a - }.groupBy(a => (a.defname, a.params.nonEmpty)).map { - /* There are no parameters, so all ports must match exactly. */ - case (k@ (_, false), a) => - k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo))).toSet - /* If there are parameters, then only port names must match because parameters could parameterize widths. - * This means that this check cannot produce false positives, but can have false negatives. - */ - case (k@ (_, true), a) => - k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo).mapType(stripWidth))).toSet - }.filter(_._2.size > 1) + }.groupBy(a => (a.defname, a.params.nonEmpty)) + .map { + /* There are no parameters, so all ports must match exactly. */ + case (k @ (_, false), a) => + k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo))).toSet + /* If there are parameters, then only port names must match because parameters could parameterize widths. + * This means that this check cannot produce false positives, but can have false negatives. + */ + case (k @ (_, true), a) => + k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo).mapType(stripWidth))).toSet + } + .filter(_._2.size > 1) c.modules.collect { case a: ExtModule => @@ -129,7 +136,8 @@ trait CheckHighFormLike { this: Pass => case _ => } a match { - case ExtModule(info, name, _, defname, params) if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) => + case ExtModule(info, name, _, defname, params) + if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) => errors.append(new DefnameDifferentPortsException(info, name, defname)) case _ => } @@ -147,14 +155,14 @@ trait CheckHighFormLike { this: Pass => } def nonNegativeConsts(): Unit = { - e.consts.filter(_ < 0).foreach { - negC => errors.append(new NegArgException(info, mname, e.op.toString, negC)) + e.consts.filter(_ < 0).foreach { negC => + errors.append(new NegArgException(info, mname, e.op.toString, negC)) } } e.op match { - case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | - Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze => + case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | + Clip | Wrap | Squeeze => correctNum(Option(2), 0) case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not => correctNum(Option(1), 0) @@ -175,7 +183,7 @@ trait CheckHighFormLike { this: Pass => case AsInterval => correctNum(Option(1), 3) case Andr | Orr | Xorr | Neg => - correctNum(None,0) + correctNum(None, 0) } } @@ -208,12 +216,12 @@ trait CheckHighFormLike { this: Pass => } def checkHighFormT(info: Info, mname: => String)(t: Type): Unit = { - t foreach checkHighFormT(info, mname) + t.foreach(checkHighFormT(info, mname)) t match { case tx: VectorType if tx.size < 0 => errors.append(new NegVecSizeException(info, mname)) case _: IntervalType => - case _ => t foreach checkHighFormW(info, mname) + case _ => t.foreach(checkHighFormW(info, mname)) } } @@ -235,12 +243,12 @@ trait CheckHighFormLike { this: Pass => errors.append(new NegUIntException(info, mname)) case ex: DoPrim => checkHighFormPrimop(info, mname, ex) case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => - case ex: SubAccess => validSubexp(info, mname)(ex.expr) + case ex: SubAccess => validSubexp(info, mname)(ex.expr) case ex: WSubAccess => validSubexp(info, mname)(ex.expr) - case ex => ex foreach validSubexp(info, mname) + case ex => ex.foreach(validSubexp(info, mname)) } - e foreach checkHighFormW(info, mname + "/" + e.serialize) - e foreach checkHighFormE(info, mname, names) + e.foreach(checkHighFormW(info, mname + "/" + e.serialize)) + e.foreach(checkHighFormE(info, mname, names)) } def checkName(info: Info, mname: String, names: ScopeView)(name: String): Unit = { @@ -253,14 +261,17 @@ trait CheckHighFormLike { this: Pass => if (!moduleNames(child)) errors.append(new ModuleNotDefinedException(info, parent, child)) // Check to see if a recursive module instantiation has occured - val childToParent = moduleGraph add (parent, child) + val childToParent = moduleGraph.add(parent, child) if (childToParent.nonEmpty) - errors.append(new InstanceLoop(info, parent, childToParent mkString "->")) + errors.append(new InstanceLoop(info, parent, childToParent.mkString("->"))) } def checkHighFormS(minfo: Info, mname: String, names: ScopeView)(s: Statement): Unit = { - val info = get_info(s) match {case NoInfo => minfo case x => x} - s foreach checkName(info, mname, names) + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } + s.foreach(checkName(info, mname, names)) s match { case DefRegister(info, name, tpe, _, reset, init) => if (hasFlip(tpe)) @@ -272,24 +283,24 @@ trait CheckHighFormLike { this: Pass => errors.append(new MemWithFlipException(info, mname, sx.name)) if (sx.depth <= 0) errors.append(new NegMemSizeException(info, mname)) - case sx: DefInstance => checkInstance(info, mname, sx.module) - case sx: WDefInstance => checkInstance(info, mname, sx.module) - case sx: Connect => checkValidLoc(info, mname, sx.loc) - case sx: PartialConnect => checkValidLoc(info, mname, sx.loc) - case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) - case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) } + case sx: DefInstance => checkInstance(info, mname, sx.module) + case sx: WDefInstance => checkInstance(info, mname, sx.module) + case sx: Connect => checkValidLoc(info, mname, sx.loc) + case sx: PartialConnect => checkValidLoc(info, mname, sx.loc) + case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) + case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) } case mport: CDefMPort => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) } names.expandMPortVisibility(mport) case sx => // Do Nothing } - s foreach checkHighFormT(info, mname) - s foreach checkHighFormE(info, mname, names) + s.foreach(checkHighFormT(info, mname)) + s.foreach(checkHighFormE(info, mname, names)) s match { - case Conditionally(_,_, conseq, alt) => + case Conditionally(_, _, conseq, alt) => checkHighFormS(minfo, mname, names.childScope())(conseq) checkHighFormS(minfo, mname, names.childScope())(alt) - case _ => s foreach checkHighFormS(minfo, mname, names) + case _ => s.foreach(checkHighFormS(minfo, mname, names)) } } @@ -313,10 +324,10 @@ trait CheckHighFormLike { this: Pass => def checkHighFormM(m: DefModule): Unit = { val names = ScopeView() - m foreach checkHighFormP(m.name, names) - m foreach checkHighFormS(m.info, m.name, names) + m.foreach(checkHighFormP(m.name, names)) + m.foreach(checkHighFormS(m.info, m.name, names)) m match { - case _: Module => + case _: Module => case ext: ExtModule => for ((port, expr) <- findBadResetTypePorts(ext, Output)) { errors.append(new ResetExtModuleOutputException(port.info, ext.name, expr)) @@ -324,7 +335,7 @@ trait CheckHighFormLike { this: Pass => } } - c.modules foreach checkHighFormM + c.modules.foreach(checkHighFormM) c.modules.filter(_.name == c.main) match { case Seq(topMod) => for ((port, expr) <- findBadResetTypePorts(topMod, Input)) { @@ -342,21 +353,23 @@ object CheckHighForm extends Pass with CheckHighFormLike { override def prerequisites = firrtl.stage.Forms.WorkingIR override def optionalPrerequisiteOf = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows), - Dependency[passes.InferWidths], - Dependency[transforms.InferResets] ) + Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows), + Dependency[passes.InferWidths], + Dependency[transforms.InferResets] + ) override def invalidates(a: Transform) = false - class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") + class IllegalChirrtlMemException(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = { val memName = s match { case cm: CDefMemory => cm.name - case cp: CDefMPort => cp.mem + case cp: CDefMPort => cp.mem } Some(new IllegalChirrtlMemException(info, mname, memName)) } diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 4a5577f9..96057831 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -22,10 +22,11 @@ object CheckInitialization extends Pass { private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) - class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( - s"$info : [module $mname] Reference $name is not fully initialized.\n" + - trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n") - ) + class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) + extends PassException( + s"$info : [module $mname] Reference $name is not fully initialized.\n" + + trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n") + ) private def getTrace(expr: WrappedExpression, voidExprs: Map[WrappedExpression, VoidExpr]): Seq[Statement] = { @tailrec @@ -81,7 +82,7 @@ object CheckInitialization extends Pass { case node: DefNode => // Ignore nodes case decl: IsDeclaration => val trace = getTrace(expr, voidExprs.toMap) - errors append new RefNotInitializedException(decl.info, m.name, decl.name, trace) + errors.append(new RefNotInitializedException(decl.info, m.name, decl.name, trace)) } } } diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala index c94928a1..956c1134 100644 --- a/src/main/scala/firrtl/passes/CheckTypes.scala +++ b/src/main/scala/firrtl/passes/CheckTypes.scala @@ -16,92 +16,105 @@ object CheckTypes extends Pass { override def prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR override def optionalPrerequisiteOf = - Seq( Dependency(passes.ResolveFlows), - Dependency(passes.CheckFlows), - Dependency[passes.InferWidths], - Dependency(passes.CheckWidths) ) + Seq( + Dependency(passes.ResolveFlows), + Dependency(passes.CheckFlows), + Dependency[passes.InferWidths], + Dependency(passes.CheckWidths) + ) override def invalidates(a: Transform) = false // Custom Exceptions - class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname ] Subfield $name is not in bundle.") - class SubfieldOnNonBundle(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Subfield $name is accessed on a non-bundle.") - class IndexTooLarge(info: Info, mname: String, value: Int) extends PassException( - s"$info: [module $mname] Index with value $value is too large.") - class IndexOnNonVector(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Index illegal on non-vector type.") - class AccessIndexNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Access index must be a UInt type.") - class IndexNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Index is not of UIntType.") - class EnableNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Enable is not of UIntType.") + class SubfieldNotInBundle(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname ] Subfield $name is not in bundle.") + class SubfieldOnNonBundle(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Subfield $name is accessed on a non-bundle.") + class IndexTooLarge(info: Info, mname: String, value: Int) + extends PassException(s"$info: [module $mname] Index with value $value is too large.") + class IndexOnNonVector(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Index illegal on non-vector type.") + class AccessIndexNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Access index must be a UInt type.") + class IndexNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Index is not of UIntType.") + class EnableNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Enable is not of UIntType.") class InvalidConnect(info: Info, mname: String, con: String, lhs: Expression, rhs: Expression) extends PassException({ - val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}" - val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}" - s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe" - }) - class InvalidRegInit(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Type of init must match type of DefRegister.") - class PrintfArgNotGround(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.") - class ReqClk(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Requires a clock typed signal.") - class RegReqClk(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Register $name requires a clock typed signal.") - class EnNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Enable must be a UIntType typed signal.") - class PredNotUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Predicate not a UIntType.") - class OpNotGround(info: Info, mname: String, op: String) extends PassException( - s"$info: [module $mname] Primop $op cannot operate on non-ground types.") - class OpNotUInt(info: Info, mname: String, op: String, e: String) extends PassException( - s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.") - class OpNotAllUInt(info: Info, mname: String, op: String) extends PassException( - s"$info: [module $mname] Primop $op requires all arguments to be UInt type.") - class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException( - s"$info: [module $mname] Primop $op requires all operands to have the same type.") - class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.") - class OpNotCorrectType(info:Info, mname: String, op: String, tpes: Seq[String]) extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.") - class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") - class NodePassiveType(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Node must be a passive type.") - class MuxSameType(info: Info, mname: String, t1: String, t2: String) extends PassException( - s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.") - class MuxPassiveTypes(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Must mux between passive types.") - class MuxCondUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] A mux condition must be of type UInt.") - class MuxClock(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Firrtl does not support muxing clocks.") - class ValidIfPassiveTypes(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Must validif a passive type.") - class ValidIfCondUInt(info: Info, mname: String) extends PassException( - s"$info: [module $mname] A validif condition must be of type UInt.") - class IllegalAnalogDeclaration(info: Info, mname: String, decName: String) extends PassException( - s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName.") - class IllegalAttachExp(info: Info, mname: String, expName: String) extends PassException( - s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.") - class IllegalResetType(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp.") - class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException( - s"$info: [module $mname] Uninferred type: $exp." - ) + val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}" + val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}" + s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe" + }) + class InvalidRegInit(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Type of init must match type of DefRegister.") + class PrintfArgNotGround(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.") + class ReqClk(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Requires a clock typed signal.") + class RegReqClk(info: Info, mname: String, name: String) + extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.") + class EnNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.") + class PredNotUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Predicate not a UIntType.") + class OpNotGround(info: Info, mname: String, op: String) + extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.") + class OpNotUInt(info: Info, mname: String, op: String, e: String) + extends PassException(s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.") + class OpNotAllUInt(info: Info, mname: String, op: String) + extends PassException(s"$info: [module $mname] Primop $op requires all arguments to be UInt type.") + class OpNotAllSameType(info: Info, mname: String, op: String) + extends PassException(s"$info: [module $mname] Primop $op requires all operands to have the same type.") + class OpNoMixFix(info: Info, mname: String, op: String) + extends PassException( + s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type." + ) + class OpNotCorrectType(info: Info, mname: String, op: String, tpes: Seq[String]) + extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.") + class OpNotAnalog(info: Info, mname: String, exp: String) + extends PassException(s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") + class NodePassiveType(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Node must be a passive type.") + class MuxSameType(info: Info, mname: String, t1: String, t2: String) + extends PassException(s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.") + class MuxPassiveTypes(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Must mux between passive types.") + class MuxCondUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.") + class MuxClock(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.") + class ValidIfPassiveTypes(info: Info, mname: String) + extends PassException(s"$info: [module $mname] Must validif a passive type.") + class ValidIfCondUInt(info: Info, mname: String) + extends PassException(s"$info: [module $mname] A validif condition must be of type UInt.") + class IllegalAnalogDeclaration(info: Info, mname: String, decName: String) + extends PassException( + s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName." + ) + class IllegalAttachExp(info: Info, mname: String, expName: String) + extends PassException( + s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName." + ) + class IllegalResetType(info: Info, mname: String, exp: String) + extends PassException( + s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp." + ) + class IllegalUnknownType(info: Info, mname: String, exp: String) + extends PassException( + s"$info: [module $mname] Uninferred type: $exp." + ) def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match { case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false - case _ => true + case _ => true } def legalResetType(tpe: Type): Boolean = tpe match { case UIntType(IntWidth(w)) if w == 1 => true - case AsyncResetType => true - case ResetType => true - case UIntType(UnknownWidth) => + case AsyncResetType => true + case ResetType => true + case UIntType(UnknownWidth) => // cannot catch here, though width may ultimately be wrong true case _ => false @@ -118,13 +131,13 @@ object CheckTypes extends Pass { fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point) case (_: AnalogType, _: AnalogType) => true case (AsyncResetType, AsyncResetType) => flip1 == flip2 - case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 - case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2 + case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2 + case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2 case (t1: BundleType, t2: BundleType) => - val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( - (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) ))) - t2.fields forall (f2 => - t1_fields get f2.name match { + val t1_fields = + (t1.fields.foldLeft(Map[String, (Type, Orientation)]()))((map, f1) => map + (f1.name -> ((f1.tpe, f1.flip)))) + t2.fields.forall(f2 => + t1_fields.get(f2.name) match { case None => true case Some((f1_tpe, f1_flip)) => bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip)) @@ -155,79 +168,155 @@ object CheckTypes extends Pass { def ut: UIntType = UIntType(UnknownWidth) def st: SIntType = SIntType(UnknownWidth) - def run (c:Circuit) : Circuit = { + def run(c: Circuit): Circuit = { val errors = new Errors() def passive(t: Type): Boolean = t match { - case _: UIntType |_: SIntType => true + case _: UIntType | _: SIntType => true case tx: VectorType => passive(tx.tpe) - case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe)) + case tx: BundleType => tx.fields.forall(x => x.flip == Default && passive(x.tpe)) case tx => true } def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = { - def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = { + def checkAllTypes( + exprs: Seq[Expression], + okUInt: Boolean, + okSInt: Boolean, + okClock: Boolean, + okFix: Boolean, + okAsync: Boolean, + okInterval: Boolean + ): Unit = { exprs.foldLeft((false, false, false, false, false, false)) { - case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match { - case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval) - case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval) - case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval) - case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval) - case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval) - case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true) - case UnknownType => - errors.append(new IllegalUnknownType(info, mname, e.serialize)) - (isUInt, isSInt, isClock, isFix, isAsync, isInterval) - case other => throwInternalError(s"Illegal Type: ${other.serialize}") - } + case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => + expr.tpe match { + case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval) + case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval) + case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval) + case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval) + case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval) + case i: IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true) + case UnknownType => + errors.append(new IllegalUnknownType(info, mname, e.serialize)) + (isUInt, isSInt, isClock, isFix, isAsync, isInterval) + case other => throwInternalError(s"Illegal Type: ${other.serialize}") + } } match { // (UInt, SInt, Clock, Fixed, Async, Interval) - case (isAll, false, false, false, false, false) if isAll == okUInt => - case (false, isAll, false, false, false, false) if isAll == okSInt => - case (false, false, isAll, false, false, false) if isAll == okClock => - case (false, false, false, isAll, false, false) if isAll == okFix => - case (false, false, false, false, isAll, false) if isAll == okAsync => + case (isAll, false, false, false, false, false) if isAll == okUInt => + case (false, isAll, false, false, false, false) if isAll == okSInt => + case (false, false, isAll, false, false, false) if isAll == okClock => + case (false, false, false, isAll, false, false) if isAll == okFix => + case (false, false, false, false, isAll, false) if isAll == okAsync => case (false, false, false, false, false, isAll) if isAll == okInterval => - case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) + case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize))) } } e.op match { case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval => - // All types are ok + // All types are ok case Dshl | Dshr => - checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) - checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false) + checkAllTypes( + Seq(e.args.head), + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) + checkAllTypes( + Seq(e.args(1)), + okUInt = true, + okSInt = false, + okClock = false, + okFix = false, + okAsync = false, + okInterval = false + ) case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) case Pad | Bits | Head | Tail => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = false + ) case Shl | Shr | Cat => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) case IncP | DecP | SetP => - checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true) + checkAllTypes( + e.args, + okUInt = false, + okSInt = false, + okClock = false, + okFix = true, + okAsync = false, + okInterval = true + ) case Wrap | Clip | Squeeze => - checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true) + checkAllTypes( + e.args, + okUInt = false, + okSInt = false, + okClock = false, + okFix = false, + okAsync = false, + okInterval = true + ) case _ => - checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false) + checkAllTypes( + e.args, + okUInt = true, + okSInt = true, + okClock = false, + okFix = false, + okAsync = false, + okInterval = false + ) } } - def check_types_e(info:Info, mname: String)(e: Expression): Unit = { + def check_types_e(info: Info, mname: String)(e: Expression): Unit = { e match { - case (e: WSubField) => e.expr.tpe match { - case (t: BundleType) => t.fields find (_.name == e.name) match { - case Some(_) => - case None => errors.append(new SubfieldNotInBundle(info, mname, e.name)) + case (e: WSubField) => + e.expr.tpe match { + case (t: BundleType) => + t.fields.find(_.name == e.name) match { + case Some(_) => + case None => errors.append(new SubfieldNotInBundle(info, mname, e.name)) + } + case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name)) + } + case (e: WSubIndex) => + e.expr.tpe match { + case (t: VectorType) if e.value < t.size => + case (t: VectorType) => + errors.append(new IndexTooLarge(info, mname, e.value)) + case _ => + errors.append(new IndexOnNonVector(info, mname)) } - case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name)) - } - case (e: WSubIndex) => e.expr.tpe match { - case (t: VectorType) if e.value < t.size => - case (t: VectorType) => - errors.append(new IndexTooLarge(info, mname, e.value)) - case _ => - errors.append(new IndexOnNonVector(info, mname)) - } case (e: WSubAccess) => e.expr.tpe match { case _: VectorType => @@ -256,11 +345,14 @@ object CheckTypes extends Pass { } case _ => } - e foreach check_types_e(info, mname) + e.foreach(check_types_e(info, mname)) } def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = { - val info = get_info(s) match { case NoInfo => minfo case x => x } + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } s match { case sx: Connect if !validConnect(sx) => val conMsg = sx.copy(info = NoInfo).serialize @@ -270,7 +362,7 @@ object CheckTypes extends Pass { errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr)) case sx: DefRegister => sx.tpe match { - case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) + case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname)) case t if !validConnect(sx.tpe, sx.init.tpe) => val conMsg = sx.copy(info = NoInfo).serialize @@ -285,11 +377,12 @@ object CheckTypes extends Pass { } case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) => errors.append(new PredNotUInt(info, mname)) - case sx: DefNode => sx.value.tpe match { - case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) - case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname)) - case t => - } + case sx: DefNode => + sx.value.tpe match { + case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) + case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname)) + case t => + } case sx: Attach => for (e <- sx.exprs) { e.tpe match { @@ -298,14 +391,14 @@ object CheckTypes extends Pass { } kind(e) match { case (InstanceKind | PortKind | WireKind) => - case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize)) + case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize)) } } case sx: Stop => if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) case sx: Print => - if (sx.args exists (x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st))) + if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st))) errors.append(new PrintfArgNotGround(info, mname)) if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) @@ -313,17 +406,18 @@ object CheckTypes extends Pass { if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname)) if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) - case sx: DefMemory => sx.dataType match { - case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) - case t => - } + case sx: DefMemory => + sx.dataType match { + case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) + case t => + } case _ => } - s foreach check_types_e(info, mname) - s foreach check_types_s(info, mname) + s.foreach(check_types_e(info, mname)) + s.foreach(check_types_s(info, mname)) } - c.modules foreach (m => m foreach check_types_s(m.info, m.name)) + c.modules.foreach(m => m.foreach(check_types_s(m.info, m.name))) errors.trigger() c } diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala index a7729ef8..f7fefa87 100644 --- a/src/main/scala/firrtl/passes/CheckWidths.scala +++ b/src/main/scala/firrtl/passes/CheckWidths.scala @@ -22,43 +22,49 @@ object CheckWidths extends Pass { /** The maximum allowed width for any circuit element */ val MaxWidth = 1000000 val DshlMaxWidth = getUIntWidth(MaxWidth) - class UninferredWidth (info: Info, target: String) extends PassException( - s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?) - |$target""".stripMargin) - class UninferredBound (info: Info, target: String, bound: String) extends PassException( - s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?) - |$target""".stripMargin) - class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException( - s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?) - |$target""".stripMargin) - class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException( - s"$info : [target $mname] Width too small for constant $b.") - class WidthTooBig(info: Info, mname: String, b: BigInt) extends PassException( - s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits") - class DshlTooBig(info: Info, mname: String) extends PassException( - s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits.") - class MultiBitAsClock(info: Info, mname: String) extends PassException( - s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.") - class MultiBitAsAsyncReset(info: Info, mname: String) extends PassException( - s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.") - class NegWidthException(info:Info, mname: String) extends PassException( - s"$info: [target $mname] Width cannot be negative or zero.") - class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) extends PassException( - s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp.") - class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( - s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.") - class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException( - s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.") - class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException( - s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.") + class UninferredWidth(info: Info, target: String) + extends PassException(s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?) + |$target""".stripMargin) + class UninferredBound(info: Info, target: String, bound: String) + extends PassException(s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?) + |$target""".stripMargin) + class InvalidRange(info: Info, target: String, i: IntervalType) + extends PassException(s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?) + |$target""".stripMargin) + class WidthTooSmall(info: Info, mname: String, b: BigInt) + extends PassException(s"$info : [target $mname] Width too small for constant $b.") + class WidthTooBig(info: Info, mname: String, b: BigInt) + extends PassException(s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits") + class DshlTooBig(info: Info, mname: String) + extends PassException( + s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits." + ) + class MultiBitAsClock(info: Info, mname: String) + extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.") + class MultiBitAsAsyncReset(info: Info, mname: String) + extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.") + class NegWidthException(info: Info, mname: String) + extends PassException(s"$info: [target $mname] Width cannot be negative or zero.") + class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) + extends PassException( + s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp." + ) + class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) + extends PassException(s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.") + class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) + extends PassException(s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.") + class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) + extends PassException( + s"$info: [target $mname] Attach source $source and expression $eName must have identical widths." + ) class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim) - extends PassException({ - val toSqz = squeeze.args.head.serialize - val toSqzTpe = squeeze.args.head.tpe.serialize - val sqzTo = squeeze.args(1).serialize - val sqzToTpe = squeeze.args(1).tpe.serialize - s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe" - }) + extends PassException({ + val toSqz = squeeze.args.head.serialize + val toSqzTpe = squeeze.args.head.tpe.serialize + val sqzTo = squeeze.args(1).serialize + val sqzToTpe = squeeze.args(1).tpe.serialize + s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe" + }) def run(c: Circuit): Circuit = { val errors = new Errors() @@ -77,35 +83,35 @@ object CheckWidths extends Pass { def hasWidth(tpe: Type): Boolean = tpe match { case GroundType(IntWidth(w)) => true - case GroundType(_) => false - case _ => throwInternalError(s"hasWidth - $tpe") + case GroundType(_) => false + case _ => throwInternalError(s"hasWidth - $tpe") } def check_width_t(info: Info, target: Target)(t: Type): Unit = { t match { case tt: BundleType => tt.fields.foreach(check_width_f(info, target)) //Supports when l = u (if closed) - case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i - case i:IntervalType if i.range == Some(Nil) => + case i @ IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i + case i: IntervalType if i.range == Some(Nil) => errors.append(new InvalidRange(info, target.prettyPrint(" "), i)) i - case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u => + case i @ IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u => errors.append(new InvalidRange(info, target.prettyPrint(" "), i)) i - case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i - case i@IntervalType(_: IsKnown, _, _) => + case i @ IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i + case i @ IntervalType(_: IsKnown, _, _) => errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper")) i - case i@IntervalType(_, _: IsKnown, _) => + case i @ IntervalType(_, _: IsKnown, _) => errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower")) i - case i@IntervalType(_, _, _) => + case i @ IntervalType(_, _, _) => errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower")) errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper")) i - case tt => tt foreach check_width_t(info, target) + case tt => tt.foreach(check_width_t(info, target)) } - t foreach check_width_w(info, target, t) + t.foreach(check_width_w(info, target, t)) } def check_width_f(info: Info, target: Target)(f: Field): Unit = @@ -120,7 +126,8 @@ object CheckWidths extends Pass { errors.append(new WidthTooSmall(info, target.serialize, v)) case e @ DoPrim(op, Seq(a, b), _, tpe) => (op, a.tpe, b.tpe) match { - case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) => + case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) + if (ua < lb) || (ub < la) => errors.append(new DisjointSqueeze(info, target.serialize, e)) case (Dshl, at, bt) if (hasWidth(at) && bitWidth(bt) >= DshlMaxWidth) => errors.append(new DshlTooBig(info, target.serialize)) @@ -159,7 +166,6 @@ object CheckWidths extends Pass { } } - def check_width_e_dfs(info: Info, target: Target, expr: Expression): Unit = { val stack = collection.mutable.ArrayStack(expr) def push(e: Expression): Unit = stack.push(e) @@ -171,25 +177,31 @@ object CheckWidths extends Pass { } def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = { - val info = get_info(s) match { case NoInfo => minfo case x => x } - val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target } - s foreach check_width_e(info, target, 4) - s foreach check_width_s(info, target) - s foreach check_width_t(info, subRef) + val info = get_info(s) match { + case NoInfo => minfo + case x => x + } + val subRef = s match { + case sx: HasName => target.ref(sx.name) + case _ => target + } + s.foreach(check_width_e(info, target, 4)) + s.foreach(check_width_s(info, target)) + s.foreach(check_width_t(info, subRef)) s match { case Attach(infox, exprs) => - exprs.tail.foreach ( e => + exprs.tail.foreach(e => if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe)) errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize)) ) case sx: DefRegister => sx.reset.tpe match { case UIntType(IntWidth(w)) if w == 1 => - case AsyncResetType => - case ResetType => - case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) + case AsyncResetType => + case ResetType => + case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name)) } - if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) { + if (!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) { val conMsg = sx.copy(info = NoInfo).serialize errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init)) } @@ -197,14 +209,15 @@ object CheckWidths extends Pass { } } - def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe) + def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = + check_width_t(p.info, target.ref(p.name))(p.tpe) def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = { - m foreach check_width_p(m.info, circuit.module(m.name)) - m foreach check_width_s(m.info, circuit.module(m.name)) + m.foreach(check_width_p(m.info, circuit.module(m.name))) + m.foreach(check_width_s(m.info, circuit.module(m.name))) } - c.modules foreach check_width_m(CircuitTarget(c.main)) + c.modules.foreach(check_width_m(CircuitTarget(c.main))) errors.trigger() c } diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 544f90a6..55a9c53a 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -10,15 +10,16 @@ import firrtl.options.Dependency object CommonSubexpressionElimination extends Pass { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats] ) + Seq( + Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats] + ) override def optionalPrerequisiteOf = - Seq( Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = false @@ -27,24 +28,26 @@ object CommonSubexpressionElimination extends Pass { val nodes = collection.mutable.HashMap[String, Expression]() def eliminateNodeRef(e: Expression): Expression = e match { - case WRef(name, tpe, kind, flow) => nodes get name match { - case Some(expression) => expressions get expression match { - case Some(cseName) if cseName != name => - WRef(cseName, tpe, kind, flow) + case WRef(name, tpe, kind, flow) => + nodes.get(name) match { + case Some(expression) => + expressions.get(expression) match { + case Some(cseName) if cseName != name => + WRef(cseName, tpe, kind, flow) + case _ => e + } case _ => e } - case _ => e - } - case _ => e map eliminateNodeRef + case _ => e.map(eliminateNodeRef) } def eliminateNodeRefs(s: Statement): Statement = { - s map eliminateNodeRef match { + s.map(eliminateNodeRef) match { case x: DefNode => nodes(x.name) = x.value expressions.getOrElseUpdate(x.value, x.name) x - case other => other map eliminateNodeRefs + case other => other.map(eliminateNodeRefs) } } @@ -54,7 +57,7 @@ object CommonSubexpressionElimination extends Pass { def run(c: Circuit): Circuit = { val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => Module(m.info, m.name, m.ports, cse(m.body)) + case m: Module => Module(m.info, m.name, m.ports, cse(m.body)) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala index 4a426209..baf7d4d5 100644 --- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -7,7 +7,7 @@ import firrtl.PrimOps._ import firrtl.ir._ import firrtl._ import firrtl.Mappers._ -import firrtl.Utils.{sub_type, module_type, field_type, max, throwInternalError} +import firrtl.Utils.{field_type, max, module_type, sub_type, throwInternalError} import firrtl.options.Dependency /** Replaces FixedType with SIntType, and correctly aligns all binary points @@ -15,71 +15,74 @@ import firrtl.options.Dependency object ConvertFixedToSInt extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses), - Dependency[ExpandWhensAndCheck], - Dependency[RemoveIntervals] ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck], + Dependency[RemoveIntervals] + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform) = false def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) - if((point - p) > 0) { + if ((point - p) > 0) { DoPrim(Shl, Seq(e), Seq(point - p), UnknownType) } else if (point - p < 0) { DoPrim(Shr, Seq(e), Seq(p - point), UnknownType) } else e case FixedType(w, p) => throwInternalError(s"alignArg: shouldn't be here - $e") - case _ => e + case _ => e } def calcPoint(es: Seq[Expression]): BigInt = es.map(_.tpe match { case FixedType(IntWidth(w), IntWidth(p)) => p - case _ => BigInt(0) + case _ => BigInt(0) }).reduce(max(_, _)) def toSIntType(t: Type): Type = t match { case FixedType(IntWidth(w), IntWidth(p)) => SIntType(IntWidth(w)) - case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t") - case _ => t map toSIntType + case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t") + case _ => t.map(toSIntType) } def run(c: Circuit): Circuit = { - val moduleTypes = mutable.HashMap[String,Type]() - def onModule(m:DefModule) : DefModule = { - val types = mutable.HashMap[String,Type]() - def updateExpType(e:Expression): Expression = e match { - case DoPrim(Mul, args, consts, tpe) => e map updateExpType - case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType - case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType - case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType - case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType + val moduleTypes = mutable.HashMap[String, Type]() + def onModule(m: DefModule): DefModule = { + val types = mutable.HashMap[String, Type]() + def updateExpType(e: Expression): Expression = e match { + case DoPrim(Mul, args, consts, tpe) => e.map(updateExpType) + case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe).map(updateExpType) + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe).map(updateExpType) + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe).map(updateExpType) + case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p).map(updateExpType) case DoPrim(op, args, consts, tpe) => val point = calcPoint(args) val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType) - newExp map updateExpType match { + newExp.map(updateExpType) match { case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) - case e => e + case e => e } case Mux(cond, tval, fval, tpe) => val point = calcPoint(Seq(tval, fval)) val newExp = Mux(cond, alignArg(tval, point), alignArg(fval, point), UnknownType) - newExp map updateExpType + newExp.map(updateExpType) case e: UIntLiteral => e case e: SIntLiteral => e - case _ => e map updateExpType match { - case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) - case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) - case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g) - case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g) - case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g) - } + case _ => + e.map(updateExpType) match { + case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) + case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) + case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g) + case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g) + case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g) + } } def updateStmtType(s: Statement): Statement = s match { case DefRegister(info, name, tpe, clock, reset, init) => val newType = toSIntType(tpe) types(name) = newType - DefRegister(info, name, newType, clock, reset, init) map updateExpType + DefRegister(info, name, newType, clock, reset, init).map(updateExpType) case DefWire(info, name, tpe) => val newType = toSIntType(tpe) types(name) = newType @@ -101,37 +104,34 @@ object ConvertFixedToSInt extends Pass { case Connect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) - Connect(info, loc, newExp) map updateExpType + Connect(info, loc, newExp).map(updateExpType) case PartialConnect(info, loc, exp) => val point = calcPoint(Seq(loc)) val newExp = alignArg(exp, point) - PartialConnect(info, loc, newExp) map updateExpType + PartialConnect(info, loc, newExp).map(updateExpType) // check Connect case, need to shl - case s => (s map updateStmtType) map updateExpType + case s => (s.map(updateStmtType)).map(updateExpType) } m.ports.foreach(p => types(p.name) = p.tpe) m match { - case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body)) - case m:ExtModule => m + case Module(info, name, ports, body) => Module(info, name, ports, updateStmtType(body)) + case m: ExtModule => m } } - val newModules = for(m <- c.modules) yield { - val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe))) + val newModules = for (m <- c.modules) yield { + val newPorts = m.ports.map(p => Port(p.info, p.name, p.direction, toSIntType(p.tpe))) m match { - case Module(info, name, ports, body) => Module(info,name,newPorts,body) - case ext: ExtModule => ext.copy(ports = newPorts) + case Module(info, name, ports, body) => Module(info, name, newPorts, body) + case ext: ExtModule => ext.copy(ports = newPorts) } } newModules.foreach(m => moduleTypes(m.name) = module_type(m)) /* @todo This should be moved outside */ - (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main )) + (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main)) } } - - - // vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/ExpandConnects.scala b/src/main/scala/firrtl/passes/ExpandConnects.scala index d28e6399..4f849c5a 100644 --- a/src/main/scala/firrtl/passes/ExpandConnects.scala +++ b/src/main/scala/firrtl/passes/ExpandConnects.scala @@ -9,8 +9,7 @@ import firrtl.Mappers._ object ExpandConnects extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped + Seq(Dependency(PullMuxes), Dependency(ReplaceAccesses)) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform) = a match { case ResolveFlows => true @@ -19,62 +18,65 @@ object ExpandConnects extends Pass { def run(c: Circuit): Circuit = { def expand_connects(m: Module): Module = { - val flows = collection.mutable.LinkedHashMap[String,Flow]() + val flows = collection.mutable.LinkedHashMap[String, Flow]() def expand_s(s: Statement): Statement = { - def set_flow(e: Expression): Expression = e map set_flow match { + def set_flow(e: Expression): Expression = e.map(set_flow) match { case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name)) case ex: WSubField => val f = get_field(ex.expr.tpe, ex.name) val flowx = times(flow(ex.expr), f.flip) WSubField(ex.expr, ex.name, ex.tpe, flowx) - case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr)) + case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr)) case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr)) case ex => ex } s match { - case sx: DefWire => flows(sx.name) = DuplexFlow; sx - case sx: DefRegister => flows(sx.name) = DuplexFlow; sx + case sx: DefWire => flows(sx.name) = DuplexFlow; sx + case sx: DefRegister => flows(sx.name) = DuplexFlow; sx case sx: WDefInstance => flows(sx.name) = SourceFlow; sx - case sx: DefMemory => flows(sx.name) = SourceFlow; sx + case sx: DefMemory => flows(sx.name) = SourceFlow; sx case sx: DefNode => flows(sx.name) = SourceFlow; sx case sx: IsInvalid => - val invalids = create_exps(sx.expr).flatMap { case expx => - flow(set_flow(expx)) match { + val invalids = create_exps(sx.expr).flatMap { + case expx => + flow(set_flow(expx)) match { case DuplexFlow => Some(IsInvalid(sx.info, expx)) - case SinkFlow => Some(IsInvalid(sx.info, expx)) - case _ => None - } + case SinkFlow => Some(IsInvalid(sx.info, expx)) + case _ => None + } } invalids.size match { - case 0 => EmptyStmt - case 1 => invalids.head - case _ => Block(invalids) + case 0 => EmptyStmt + case 1 => invalids.head + case _ => Block(invalids) } case sx: Connect => val locs = create_exps(sx.loc) val exps = create_exps(sx.expr) - Block(locs.zip(exps).map { case (locx, expx) => - to_flip(flow(locx)) match { + Block(locs.zip(exps).map { + case (locx, expx) => + to_flip(flow(locx)) match { case Default => Connect(sx.info, locx, expx) - case Flip => Connect(sx.info, expx, locx) - } + case Flip => Connect(sx.info, expx, locx) + } }) case sx: PartialConnect => val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default) val locs = create_exps(sx.loc) val exps = create_exps(sx.expr) - val stmts = ls map { case (x, y) => - locs(x).tpe match { - case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y))) - case _ => - to_flip(flow(locs(x))) match { - case Default => Connect(sx.info, locs(x), exps(y)) - case Flip => Connect(sx.info, exps(y), locs(x)) - } - } + val stmts = ls.map { + case (x, y) => + locs(x).tpe match { + case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y))) + case _ => + to_flip(flow(locs(x))) match { + case Default => Connect(sx.info, locs(x), exps(y)) + case Flip => Connect(sx.info, exps(y), locs(x)) + } + } } Block(stmts) - case sx => sx map expand_s + case sx => sx.map(expand_s) } } @@ -83,8 +85,8 @@ object ExpandConnects extends Pass { } val modulesx = c.modules.map { - case (m: ExtModule) => m - case (m: Module) => expand_connects(m) + case (m: ExtModule) => m + case (m: Module) => expand_connects(m) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index ab7f02db..14d5d3ef 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -28,21 +28,23 @@ import collection.mutable object ExpandWhens extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Resolved + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses) + ) ++ firrtl.stage.Forms.Resolved override def invalidates(a: Transform): Boolean = a match { case CheckInitialization | ResolveKinds | InferTypes => true - case _ => false + case _ => false } /** Returns circuit with when and last connection semantics resolved */ def run(c: Circuit): Circuit = { - val modulesx = c.modules map { + val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => onModule(m) + case m: Module => onModule(m) } Circuit(c.info, modulesx, c.main) } @@ -74,13 +76,12 @@ object ExpandWhens extends Pass { // Does an expression contain WVoid inserted in this pass? def containsVoid(e: Expression): Boolean = e match { - case WVoid => true + case WVoid => true case ValidIf(_, value, _) => memoizedVoid(value) - case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) - case _ => false + case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv) + case _ => false } - // Memoizes the node that holds a particular expression, if any val nodes = new NodeLookup @@ -95,18 +96,15 @@ object ExpandWhens extends Pass { * @param p predicate so far, used to update simulation constructs * @param s statement to expand */ - def expandWhens(netlist: Netlist, - defaults: Defaults, - p: Expression) - (s: Statement): Statement = s match { + def expandWhens(netlist: Netlist, defaults: Defaults, p: Expression)(s: Statement): Statement = s match { // For each non-register declaration, update netlist with value WVoid for each sink reference // Return self, unchanged case stmt @ (_: DefNode | EmptyStmt) => stmt case w: DefWire => - netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow) map (ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> WVoid)) w case w: DefMemory => - netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow) map (ref => we(ref) -> WVoid)) + netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> WVoid)) w case w: WDefInstance => netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid)) @@ -151,82 +149,88 @@ object ExpandWhens extends Pass { // Process combined maps because we only want to create 1 mux for each node // present in the conseq and/or alt - val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) => - // Defaults in netlist get priority over those in defaults - val default = netlist get lvalue match { - case Some(v) => Some(v) - case None => getDefault(lvalue, defaults) - } - // info0 and info1 correspond to Mux infos, use info0 only if ValidIf - val (res, info0, info1) = default match { - case Some(defaultValue) => - val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) - val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) - (trueValue, falseValue) match { - case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) - case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) - case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) - case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) - } - case None => - // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt - (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) - } + val memos = (conseqNetlist ++ altNetlist).map { + case (lvalue, _) => + // Defaults in netlist get priority over those in defaults + val default = netlist.get(lvalue) match { + case Some(v) => Some(v) + case None => getDefault(lvalue, defaults) + } + // info0 and info1 correspond to Mux infos, use info0 only if ValidIf + val (res, info0, info1) = default match { + case Some(defaultValue) => + val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue)) + val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue)) + (trueValue, falseValue) match { + case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo) + case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo) + case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo) + case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo) + } + case None => + // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt + (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo) + } - res match { - // Don't create a node to hold mux trees with void values - // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches) - case e if containsVoid(e) => - netlist(lvalue) = e - memoizedVoid += e // remember that this was void - EmptyStmt - case _: ValidIf | _: Mux | _: DoPrim => nodes get res match { - case Some(name) => - netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + res match { + // Don't create a node to hold mux trees with void values + // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches) + case e if containsVoid(e) => + netlist(lvalue) = e + memoizedVoid += e // remember that this was void + EmptyStmt + case _: ValidIf | _: Mux | _: DoPrim => + nodes.get(res) match { + case Some(name) => + netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + EmptyStmt + case None => + val name = namespace.newTemp + nodes(res) = name + netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) + // Use MultiInfo constructor to preserve NoInfos + val info = new MultiInfo(List(sx.info, info0, info1)) + DefNode(info, name, res) + } + case _ => + netlist(lvalue) = res EmptyStmt - case None => - val name = namespace.newTemp - nodes(res) = name - netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow) - // Use MultiInfo constructor to preserve NoInfos - val info = new MultiInfo(List(sx.info, info0, info1)) - DefNode(info, name, res) } - case _ => - netlist(lvalue) = res - EmptyStmt - } } Block(Seq(conseqStmt, altStmt) ++ memos) - case block: Block => block map expandWhens(netlist, defaults, p) + case block: Block => block.map(expandWhens(netlist, defaults, p)) case _ => throwInternalError() } val netlist = new Netlist // Add ports to netlist - netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) => - getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid) + netlist ++= (m.ports.flatMap { + case Port(_, name, dir, tpe) => + getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> WVoid) }) // Do traversal and construct mutable datastructures val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body) val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet - val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ - combineAttaches(attaches.toSeq) ++ simlist) + val newBody = Block( + Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++ + combineAttaches(attaches.toSeq) ++ simlist + ) Module(m.info, m.name, m.ports, newBody) } - /** Returns all references to all sink leaf subcomponents of a reference */ private def getSinkRefs(n: String, t: Type, g: Flow): Seq[Expression] = { val exps = create_exps(WRef(n, t, ExpKind, g)) - exps.flatMap { case exp => - exp.tpe match { - case AnalogType(w) => None - case _ => flow(exp) match { - case (DuplexFlow | SinkFlow) => Some(exp) - case _ => None + exps.flatMap { + case exp => + exp.tpe match { + case AnalogType(w) => None + case _ => + flow(exp) match { + case (DuplexFlow | SinkFlow) => Some(exp) + case _ => None + } } - } } } @@ -238,7 +242,7 @@ object ExpandWhens extends Pass { def handleInvalid(k: WrappedExpression, info: Info): Statement = if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1) netlist.map { - case (k, WInvalid) => handleInvalid(k, NoInfo) + case (k, WInvalid) => handleInvalid(k, NoInfo) case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info) case (k, v) => val (info, expr) = unwrap(v) @@ -261,7 +265,7 @@ object ExpandWhens extends Pass { case Seq() => // None of these expressions is present in the attachMap AttachAcc(exprs, attachMap.size) case accs => // At least one expression present in the attachMap - val sorted = accs sortBy (_.idx) + val sorted = accs.sortBy(_.idx) AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx) } attachMap ++= acc.exprs.map(_ -> acc) @@ -274,10 +278,11 @@ object ExpandWhens extends Pass { private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = { defaults match { case Nil => None - case head :: tail => head get lvalue match { - case Some(p) => Some(p) - case None => getDefault(lvalue, tail) - } + case head :: tail => + head.get(lvalue) match { + case Some(p) => Some(p) + case None => getDefault(lvalue, tail) + } } } @@ -290,10 +295,12 @@ object ExpandWhens extends Pass { class ExpandWhensAndCheck extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses) + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true @@ -301,6 +308,6 @@ class ExpandWhensAndCheck extends Transform with DependencyAPIMigration { } override def execute(a: CircuitState): CircuitState = - Seq(ExpandWhens, CheckInitialization).foldLeft(a){ case (acc, tx) => tx.transform(acc) } + Seq(ExpandWhens, CheckInitialization).foldLeft(a) { case (acc, tx) => tx.transform(acc) } } diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala index a16205a7..f393d8a5 100644 --- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala +++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala @@ -13,9 +13,7 @@ import firrtl.options.Dependency class InferBinaryPoints extends Pass { override def prerequisites = - Seq( Dependency(ResolveKinds), - Dependency(InferTypes), - Dependency(ResolveFlows) ) + Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows)) override def optionalPrerequisiteOf = Seq.empty @@ -23,12 +21,12 @@ class InferBinaryPoints extends Pass { private val constraintSolver = new ConstraintSolver() - private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match { - case (UIntType(w1), UIntType(w2)) => - case (SIntType(w1), SIntType(w2)) => - case (ClockType, ClockType) => - case (ResetType, _) => - case (_, ResetType) => + private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1, t2) match { + case (UIntType(w1), UIntType(w2)) => + case (SIntType(w1), SIntType(w2)) => + case (ClockType, ClockType) => + case (ResetType, _) => + case (_, ResetType) => case (AsyncResetType, AsyncResetType) => case (FixedType(w1, p1), FixedType(w2, p2)) => constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) @@ -36,78 +34,86 @@ class InferBinaryPoints extends Pass { constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) case (AnalogType(w1), AnalogType(w2)) => case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields) foreach { case (f1, f2) => - (f1.flip, f2.flip) match { - case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) - case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) - case _ => sys.error("Shouldn't be here") - } + (t1.fields.zip(t2.fields)).foreach { + case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") + } } case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other") } - private def addDecConstraints(t: Type): Type = t map addDecConstraints - private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match { + private def addDecConstraints(t: Type): Type = t.map(addDecConstraints) + private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s.map(addDecConstraints) match { case c: Connect => val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) val exps = create_exps(c.expr) - (locs zip exps) foreach { case (loc, exp) => - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + (locs.zip(exps)).foreach { + case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } c case pc: PartialConnect => val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) val locs = create_exps(pc.loc) val exps = create_exps(pc.expr) - ls foreach { case (x, y) => - val loc = locs(x) - val exp = exps(y) - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + ls.foreach { + case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } pc case r: DefRegister => - addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) + addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) r - case x => x map addStmtConstraints(mt) + case x => x.map(addStmtConstraints(mt)) } private def fixWidth(w: Width): Width = constraintSolver.get(w) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => w - case _ => sys.error("Shouldn't be here") + case None => w + case _ => sys.error("Shouldn't be here") } - private def fixType(t: Type): Type = t map fixType map fixWidth match { + private def fixType(t: Type): Type = t.map(fixType).map(fixWidth) match { case IntervalType(l, u, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => p - case _ => sys.error("Shouldn't be here") + case None => p + case _ => sys.error("Shouldn't be here") } IntervalType(l, u, px) case FixedType(w, p) => val px = constraintSolver.get(p) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => p - case _ => sys.error("Shouldn't be here") + case None => p + case _ => sys.error("Shouldn't be here") } FixedType(w, px) case x => x } - private def fixStmt(s: Statement): Statement = s map fixStmt map fixType - private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) - def run (c: Circuit): Circuit = { + private def fixStmt(s: Statement): Statement = s.map(fixStmt).map(fixType) + private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe)) + def run(c: Circuit): Circuit = { val ct = CircuitTarget(c.main) - c.modules foreach (m => m map addStmtConstraints(ct.module(m.name))) - c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)}) + c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name)))) + c.modules.foreach(_.ports.foreach { p => addDecConstraints(p.tpe) }) constraintSolver.solve() - InferTypes.run(c.copy(modules = c.modules map (_ - map fixPort - map fixStmt))) + InferTypes.run( + c.copy(modules = + c.modules.map( + _.map(fixPort) + .map(fixStmt) + ) + ) + ) } } diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala index 6cc9f2b9..4d14e7ff 100644 --- a/src/main/scala/firrtl/passes/InferTypes.scala +++ b/src/main/scala/firrtl/passes/InferTypes.scala @@ -23,16 +23,16 @@ object InferTypes extends Pass { def remove_unknowns_b(b: Bound): Bound = b match { case UnknownBound => VarBound(namespace.newName("b")) - case k => k + case k => k } def remove_unknowns_w(w: Width): Width = w match { case UnknownWidth => VarWidth(namespace.newName("w")) - case wx => wx + case wx => wx } def remove_unknowns(t: Type): Type = { - t map remove_unknowns map remove_unknowns_w match { + t.map(remove_unknowns).map(remove_unknowns_w) match { case IntervalType(l, u, p) => IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p) case x => x @@ -41,18 +41,18 @@ object InferTypes extends Pass { // we first need to remove the unknown widths and bounds from all ports, // as their type will determine the module types - val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) }) + val portsKnown = c.modules.map(_.map { p: Port => p.copy(tpe = remove_unknowns(p.tpe)) }) val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap def infer_types_e(types: TypeLookup)(e: Expression): Expression = - e map infer_types_e(types) match { - case e: WRef => e copy (tpe = types(e.name)) - case e: WSubField => e copy (tpe = field_type(e.expr.tpe, e.name)) - case e: WSubIndex => e copy (tpe = sub_type(e.expr.tpe)) - case e: WSubAccess => e copy (tpe = sub_type(e.expr.tpe)) - case e: DoPrim => PrimOps.set_primop_type(e) - case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval)) - case e: ValidIf => e copy (tpe = e.value.tpe) + e.map(infer_types_e(types)) match { + case e: WRef => e.copy(tpe = types(e.name)) + case e: WSubField => e.copy(tpe = field_type(e.expr.tpe, e.name)) + case e: WSubIndex => e.copy(tpe = sub_type(e.expr.tpe)) + case e: WSubAccess => e.copy(tpe = sub_type(e.expr.tpe)) + case e: DoPrim => PrimOps.set_primop_type(e) + case e: Mux => e.copy(tpe = mux_type_and_widths(e.tval, e.fval)) + case e: ValidIf => e.copy(tpe = e.value.tpe) case e @ (_: UIntLiteral | _: SIntLiteral) => e } @@ -60,37 +60,37 @@ object InferTypes extends Pass { case sx: WDefInstance => val t = mtypes(sx.module) types(sx.name) = t - sx copy (tpe = t) + sx.copy(tpe = t) case sx: DefWire => val t = remove_unknowns(sx.tpe) types(sx.name) = t - sx copy (tpe = t) + sx.copy(tpe = t) case sx: DefNode => - val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode] + val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode] val t = remove_unknowns(sxx.value.tpe) types(sx.name) = t sxx case sx: DefRegister => val t = remove_unknowns(sx.tpe) types(sx.name) = t - sx copy (tpe = t) map infer_types_e(types) + sx.copy(tpe = t).map(infer_types_e(types)) case sx: DefMemory => // we need to remove the unknowns from the data type so that all ports get the same VarWidth val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType)) types(sx.name) = MemPortUtils.memType(knownDataType) knownDataType - case sx => sx map infer_types_s(types) map infer_types_e(types) + case sx => sx.map(infer_types_s(types)).map(infer_types_e(types)) } def infer_types_p(types: TypeLookup)(p: Port): Port = { val t = remove_unknowns(p.tpe) types(p.name) = t - p copy (tpe = t) + p.copy(tpe = t) } def infer_types(m: DefModule): DefModule = { val types = new TypeLookup - m map infer_types_p(types) map infer_types_s(types) + m.map(infer_types_p(types)).map(infer_types_s(types)) } c.copy(modules = portsKnown.map(infer_types)) @@ -108,45 +108,45 @@ object CInferTypes extends Pass { private type TypeLookup = collection.mutable.HashMap[String, Type] def run(c: Circuit): Circuit = { - val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap - - def infer_types_e(types: TypeLookup)(e: Expression) : Expression = - e map infer_types_e(types) match { - case (e: Reference) => e copy (tpe = types.getOrElse(e.name, UnknownType)) - case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name)) - case (e: SubIndex) => e copy (tpe = sub_type(e.expr.tpe)) - case (e: SubAccess) => e copy (tpe = sub_type(e.expr.tpe)) - case (e: DoPrim) => PrimOps.set_primop_type(e) - case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval)) - case (e: ValidIf) => e copy (tpe = e.value.tpe) - case e @ (_: UIntLiteral | _: SIntLiteral) => e + val mtypes = (c.modules.map(m => m.name -> module_type(m))).toMap + + def infer_types_e(types: TypeLookup)(e: Expression): Expression = + e.map(infer_types_e(types)) match { + case (e: Reference) => e.copy(tpe = types.getOrElse(e.name, UnknownType)) + case (e: SubField) => e.copy(tpe = field_type(e.expr.tpe, e.name)) + case (e: SubIndex) => e.copy(tpe = sub_type(e.expr.tpe)) + case (e: SubAccess) => e.copy(tpe = sub_type(e.expr.tpe)) + case (e: DoPrim) => PrimOps.set_primop_type(e) + case (e: Mux) => e.copy(tpe = mux_type(e.tval, e.fval)) + case (e: ValidIf) => e.copy(tpe = e.value.tpe) + case e @ (_: UIntLiteral | _: SIntLiteral) => e } def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match { case sx: DefRegister => types(sx.name) = sx.tpe - sx map infer_types_e(types) + sx.map(infer_types_e(types)) case sx: DefWire => types(sx.name) = sx.tpe sx case sx: DefNode => - val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode] + val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode] types(sxx.name) = sxx.value.tpe sxx case sx: DefMemory => types(sx.name) = MemPortUtils.memType(sx) sx case sx: CDefMPort => - val t = types getOrElse(sx.mem, UnknownType) + val t = types.getOrElse(sx.mem, UnknownType) types(sx.name) = t - sx copy (tpe = t) + sx.copy(tpe = t) case sx: CDefMemory => types(sx.name) = sx.tpe sx case sx: DefInstance => types(sx.name) = mtypes(sx.module) sx - case sx => sx map infer_types_s(types) map infer_types_e(types) + case sx => sx.map(infer_types_s(types)).map(infer_types_e(types)) } def infer_types_p(types: TypeLookup)(p: Port): Port = { @@ -156,9 +156,9 @@ object CInferTypes extends Pass { def infer_types(m: DefModule): DefModule = { val types = new TypeLookup - m map infer_types_p(types) map infer_types_s(types) + m.map(infer_types_p(types)).map(infer_types_s(types)) } - c copy (modules = c.modules map infer_types) + c.copy(modules = c.modules.map(infer_types)) } } diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 3720523b..eae9690f 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -14,7 +14,7 @@ import firrtl.options.Dependency object InferWidths { def apply(): InferWidths = new InferWidths() - def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver) + def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver) def execute(state: CircuitState): CircuitState = new InferWidths().execute(state) } @@ -22,12 +22,14 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = { val newLoc :: newExp :: Nil = Seq(loc, exp).map { target => renameMap.get(target) match { - case None => Some(target) - case Some(Seq()) => None + case None => Some(target) + case Some(Seq()) => None case Some(Seq(one)) => Some(one) case Some(many) => - throw new Exception(s"Target below is an AggregateType, which " + - "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + throw new Exception( + s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint() + ) } } @@ -60,28 +62,31 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg * * Uses firrtl.constraint package to infer widths */ -class InferWidths extends Transform - with ResolvedAnnotationPaths - with DependencyAPIMigration { +class InferWidths extends Transform with ResolvedAnnotationPaths with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows), - Dependency[passes.InferBinaryPoints], - Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR + Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows), + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals] + ) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform) = false val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation]) - private def addTypeConstraints - (r1: ReferenceTarget, r2: ReferenceTarget) - (t1: Type, t2: Type) - (implicit constraintSolver: ConstraintSolver) - : Unit = (t1,t2) match { + private def addTypeConstraints( + r1: ReferenceTarget, + r2: ReferenceTarget + )(t1: Type, + t2: Type + )( + implicit constraintSolver: ConstraintSolver + ): Unit = (t1, t2) match { case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) - case (ClockType, ClockType) => + case (ClockType, ClockType) => case (FixedType(w1, p1), FixedType(w2, p2)) => constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint("")) constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) @@ -93,101 +98,119 @@ class InferWidths extends Transform constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint("")) constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint("")) case (t1: BundleType, t2: BundleType) => - (t1.fields zip t2.fields) foreach { case (f1, f2) => - (f1.flip, f2.flip) match { - case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) - case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) - case _ => sys.error("Shouldn't be here") - } + (t1.fields.zip(t2.fields)).foreach { + case (f1, f2) => + (f1.flip, f2.flip) match { + case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe) + case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe) + case _ => sys.error("Shouldn't be here") + } } case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe) case (AsyncResetType, AsyncResetType) => Nil - case (ResetType, _) => Nil - case (_, ResetType) => Nil + case (ResetType, _) => Nil + case (_, ResetType) => Nil } - private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver) - : Expression = e map addExpConstraints match { - case m@Mux(p, tVal, fVal, t) => - constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") - m - case other => other - } + private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver): Expression = + e.map(addExpConstraints) match { + case m @ Mux(p, tVal, fVal, t) => + constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W") + m + case other => other + } - private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver) - : Statement = s map addExpConstraints match { + private def addStmtConstraints( + mt: ModuleTarget + )(s: Statement + )( + implicit constraintSolver: ConstraintSolver + ): Statement = s.map(addExpConstraints) match { case c: Connect => val n = get_size(c.loc.tpe) val locs = create_exps(c.loc) val exps = create_exps(c.expr) - (locs zip exps).foreach { case (loc, exp) => - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } - } + (locs.zip(exps)).foreach { + case (loc, exp) => + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } + } c case pc: PartialConnect => val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default) val locs = create_exps(pc.loc) val exps = create_exps(pc.expr) - ls foreach { case (x, y) => - val loc = locs(x) - val exp = exps(y) - to_flip(flow(loc)) match { - case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) - case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) - } + ls.foreach { + case (x, y) => + val loc = locs(x) + val exp = exps(y) + to_flip(flow(loc)) match { + case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe) + case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe) + } } pc case r: DefRegister => - if (r.reset.tpe != AsyncResetType ) { + if (r.reset.tpe != AsyncResetType) { addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1))) } addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe) r - case a@Attach(_, exprs) => - val widths = exprs map (e => (e, getWidth(e.tpe))) + case a @ Attach(_, exprs) => + val widths = exprs.map(e => (e, getWidth(e.tpe))) val maxWidth = IsMax(widths.map(x => width2constraint(x._2))) - widths.foreach { case (e, w) => - constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint("")) + widths.foreach { + case (e, w) => + constraintSolver.addGeq( + w, + CalcWidth(maxWidth), + Target.asTarget(mt)(e).prettyPrint(""), + mt.ref(a.serialize).prettyPrint("") + ) } a case c: Conditionally => addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1))) - c map addStmtConstraints(mt) - case x => x map addStmtConstraints(mt) + c.map(addStmtConstraints(mt)) + case x => x.map(addStmtConstraints(mt)) } private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match { case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt) - case None => w - case _ => sys.error("Shouldn't be here") + case None => w + case _ => sys.error("Shouldn't be here") } - private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match { + private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t.map(fixType).map(fixWidth) match { case IntervalType(l, u, p) => val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match { case (Some(x: Bound), Some(y: Bound)) => (x, y) case (None, None) => (l, u) - case x => sys.error(s"Shouldn't be here: $x") - + case x => sys.error(s"Shouldn't be here: $x") } IntervalType(lx, ux, fixWidth(p)) case FixedType(w, p) => FixedType(w, fixWidth(p)) - case x => x + case x => x } - private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType + private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = + s.map(fixStmt).map(fixType) private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = { Port(p.info, p.name, p.direction, fixType(p.tpe)) } - def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = { + def run(c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = { val ct = CircuitTarget(c.main) - c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name))) + c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name)))) constraintSolver.solve() - val ret = InferTypes.run(c.copy(modules = c.modules map (_ - map fixPort - map fixStmt))) + val ret = InferTypes.run( + c.copy(modules = + c.modules.map( + _.map(fixPort) + .map(fixStmt) + ) + ) + ) constraintSolver.clear() ret } @@ -200,15 +223,16 @@ class InferWidths extends Transform def getDeclTypes(modName: String)(stmt: Statement): Unit = { val pairOpt = stmt match { - case w: DefWire => Some(w.name -> w.tpe) - case r: DefRegister => Some(r.name -> r.tpe) - case n: DefNode => Some(n.name -> n.value.tpe) + case w: DefWire => Some(w.name -> w.tpe) + case r: DefRegister => Some(r.name -> r.tpe) + case n: DefNode => Some(n.name -> n.value.tpe) case i: WDefInstance => Some(i.name -> i.tpe) - case m: DefMemory => Some(m.name -> MemPortUtils.memType(m)) + case m: DefMemory => Some(m.name -> MemPortUtils.memType(m)) case other => None } - pairOpt.foreach { case (ref, tpe) => - typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe) + pairOpt.foreach { + case (ref, tpe) => + typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe) } stmt.foreachStmt(getDeclTypes(modName)) } @@ -223,14 +247,20 @@ class InferWidths extends Transform } state.annotations.foreach { - case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => - val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target => - val baseType = typeMap.getOrElse(target.copy(component = Seq.empty), - throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint())) + case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal => + val locType :: expType :: Nil = Seq(anno.loc, anno.exp).map { target => + val baseType = typeMap.getOrElse( + target.copy(component = Seq.empty), + throw new Exception( + s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint() + ) + ) val leafType = target.componentType(baseType) if (leafType.isInstanceOf[AggregateType]) { - throw new Exception(s"Target below is an AggregateType, which " + - "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()) + throw new Exception( + s"Target below is an AggregateType, which " + + "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint() + ) } leafType diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index ad963b19..316878fb 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -32,89 +32,100 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe override def invalidates(a: Transform): Boolean = a == ResolveKinds - private [firrtl] val inlineDelim: String = "_" + private[firrtl] val inlineDelim: String = "_" val options = Seq( new ShellOption[Seq[String]]( longOption = "inline", - toAnnotationSeq = (a: Seq[String]) => a.map { value => - value.split('.') match { - case Array(circuit) => - InlineAnnotation(CircuitName(circuit)) - case Array(circuit, module) => - InlineAnnotation(ModuleName(module, CircuitName(circuit))) - case Array(circuit, module, inst) => - InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit)))) - } - } :+ RunFirrtlTransformAnnotation(new InlineInstances), + toAnnotationSeq = (a: Seq[String]) => + a.map { value => + value.split('.') match { + case Array(circuit) => + InlineAnnotation(CircuitName(circuit)) + case Array(circuit, module) => + InlineAnnotation(ModuleName(module, CircuitName(circuit))) + case Array(circuit, module, inst) => + InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit)))) + } + } :+ RunFirrtlTransformAnnotation(new InlineInstances), helpText = "Inline selected modules", shortOption = Some("fil"), - helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]") ) ) - - private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = - anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) { - case ((modNames, instNames), ann) => ann match { - case InlineAnnotation(CircuitName(c)) => - (circuit.modules.collect { - case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) - }.toSet, instNames) - case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) - case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) - case _ => (modNames, instNames) - } - } - - def execute(state: CircuitState): CircuitState = { - // TODO Add error check for more than one annotation for inlining - val (modNames, instNames) = collectAnns(state.circuit, state.annotations) - if (modNames.nonEmpty || instNames.nonEmpty) { - run(state.circuit, modNames, instNames, state.annotations) - } else { - state - } - } - - // Checks the following properties: - // 1) All annotated modules exist - // 2) All annotated modules are InModules (can be inlined) - // 3) All annotated instances exist, and their modules can be inline - def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = { - val errors = mutable.ArrayBuffer[PassException]() - val moduleMap = InstanceKeyGraph(c).moduleMap - def checkExists(name: String): Unit = - if (!moduleMap.contains(name)) - errors += new PassException(s"Annotated module does not exist: $name") - def checkExternal(name: String): Unit = moduleMap(name) match { - case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name") - case _ => - } - def checkInstance(cn: ComponentName): Unit = { - var containsCN = false - def onStmt(name: String)(s: Statement): Statement = { - s match { - case WDefInstance(_, inst_name, module_name, tpe) => - if (name == inst_name) { - containsCN = true - checkExternal(module_name) - } - case _ => + helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]") + ) + ) + + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) { + case ((modNames, instNames), ann) => + ann match { + case InlineAnnotation(CircuitName(c)) => + ( + circuit.modules.collect { + case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) + }.toSet, + instNames + ) + case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => (modNames, instNames) + } + } + + def execute(state: CircuitState): CircuitState = { + // TODO Add error check for more than one annotation for inlining + val (modNames, instNames) = collectAnns(state.circuit, state.annotations) + if (modNames.nonEmpty || instNames.nonEmpty) { + run(state.circuit, modNames, instNames, state.annotations) + } else { + state + } + } + + // Checks the following properties: + // 1) All annotated modules exist + // 2) All annotated modules are InModules (can be inlined) + // 3) All annotated instances exist, and their modules can be inline + def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = { + val errors = mutable.ArrayBuffer[PassException]() + val moduleMap = InstanceKeyGraph(c).moduleMap + def checkExists(name: String): Unit = + if (!moduleMap.contains(name)) + errors += new PassException(s"Annotated module does not exist: $name") + def checkExternal(name: String): Unit = moduleMap(name) match { + case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name") + case _ => + } + def checkInstance(cn: ComponentName): Unit = { + var containsCN = false + def onStmt(name: String)(s: Statement): Statement = { + s match { + case WDefInstance(_, inst_name, module_name, tpe) => + if (name == inst_name) { + containsCN = true + checkExternal(module_name) } - s map onStmt(name) - } - onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body) - if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}") + case _ => + } + s.map(onStmt(name)) } + onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body) + if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}") + } - moduleNames.foreach{mn => checkExists(mn.name)} - if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) - moduleNames.foreach{mn => checkExternal(mn.name)} - if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) - instanceNames.foreach{cn => checkInstance(cn)} - if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) - } - + moduleNames.foreach { mn => checkExists(mn.name) } + if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) + moduleNames.foreach { mn => checkExternal(mn.name) } + if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) + instanceNames.foreach { cn => checkInstance(cn) } + if (errors.nonEmpty) throw new PassExceptions(errors.toSeq) + } - def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: AnnotationSeq): CircuitState = { + def run( + c: Circuit, + modsToInline: Set[ModuleName], + instsToInline: Set[ComponentName], + annos: AnnotationSeq + ): CircuitState = { def getInstancesOf(c: Circuit, modules: Set[String]): Set[(OfModule, Instance)] = c.modules.foldLeft(Set[(OfModule, Instance)]()) { (set, d) => d match { @@ -125,7 +136,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) => instances += (OfModule(m.name) -> Instance(instName)) s - case sx => sx map findInstances + case sx => sx.map(findInstances) } findInstances(m.body) instances.toSet ++ set @@ -135,7 +146,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe // Check annotations and circuit match up check(c, modsToInline, instsToInline) val flatModules = modsToInline.map(m => m.name) - val flatInstances: Set[(OfModule, Instance)] = instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules) + val flatInstances: Set[(OfModule, Instance)] = + instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules) val iGraph = InstanceKeyGraph(c) val namespaceMap = collection.mutable.Map[String, Namespace]() // Map of Module name to Map of instance name to Module name @@ -144,11 +156,13 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe /** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */ def appendNamePrefix( currentModule: IsModule, - nextModule: IsModule, - prefix: String, - ns: Namespace, - renames: mutable.HashMap[String, String], - renameMap: RenameMap)(s: Statement): Statement = { + nextModule: IsModule, + prefix: String, + ns: Namespace, + renames: mutable.HashMap[String, String], + renameMap: RenameMap + )(s: Statement + ): Statement = { def onName(ofModuleOpt: Option[String])(name: String) = { if (prefix.nonEmpty && !ns.tryName(prefix + name)) { throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists") @@ -164,25 +178,29 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe } s match { - case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) - case other => s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) + case s: WDefInstance => + s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) + case other => + s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap)) } } /** Modify all references */ def appendRefPrefix( currentModule: IsModule, - renames: mutable.HashMap[String, String])(s: Statement): Statement = { - def onExpr(e: Expression): Expression = e match { - case wr@ WRef(name, _, _, _) => - renames.get(name) match { - case Some(prefixedName) => wr.copy(name = prefixedName) - case None => wr - } - case ex => ex.map(onExpr) - } - s.map(onExpr).map(appendRefPrefix(currentModule, renames)) + renames: mutable.HashMap[String, String] + )(s: Statement + ): Statement = { + def onExpr(e: Expression): Expression = e match { + case wr @ WRef(name, _, _, _) => + renames.get(name) match { + case Some(prefixedName) => wr.copy(name = prefixedName) + case None => wr + } + case ex => ex.map(onExpr) } + s.map(onExpr).map(appendRefPrefix(currentModule, renames)) + } val cache = mutable.HashMap.empty[ModuleTarget, Statement] @@ -194,16 +212,19 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe val (renamesMap, renamesSeq) = { val mutableDiGraph = new MutableDiGraph[(OfModule, Instance)] // compute instance graph - instMaps.foreach { case (grandParentOfMod, parents) => - parents.foreach { case (parentInst, parentOfMod) => - val from = grandParentOfMod -> parentInst - mutableDiGraph.addVertex(from) - instMaps(parentOfMod).foreach { case (childInst, _) => - val to = parentOfMod -> childInst - mutableDiGraph.addVertex(to) - mutableDiGraph.addEdge(from, to) + instMaps.foreach { + case (grandParentOfMod, parents) => + parents.foreach { + case (parentInst, parentOfMod) => + val from = grandParentOfMod -> parentInst + mutableDiGraph.addVertex(from) + instMaps(parentOfMod).foreach { + case (childInst, _) => + val to = parentOfMod -> childInst + mutableDiGraph.addVertex(to) + mutableDiGraph.addEdge(from, to) + } } - } } val diGraph = DiGraph(mutableDiGraph) @@ -226,10 +247,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe } def fixupRefs( - instMap: collection.Map[Instance, OfModule], - currentModule: IsModule)(e: Expression): Expression = { + instMap: collection.Map[Instance, OfModule], + currentModule: IsModule + )(e: Expression + ): Expression = { e match { - case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) => + case wsf @ WSubField(wr @ WRef(ref, _, InstanceKind, _), field, tpe, gen) => val inst = currentModule.instOf(ref, instMap(Instance(ref)).value) val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance)) val port = inst.ref(field) @@ -242,12 +265,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe } case None => wsf } - case wr@ WRef(name, _, InstanceKind, _) => + case wr @ WRef(name, _, InstanceKind, _) => val inst = currentModule.instOf(name, instMap(Instance(name)).value) val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance)) val comp = currentModule.ref(name) renamesOpt.flatMap(_.get(comp)).getOrElse(Seq(comp)) match { - case Seq(car: ReferenceTarget) => wr.copy(name=car.ref) + case Seq(car: ReferenceTarget) => wr.copy(name = car.ref) } case ex => ex.map(fixupRefs(instMap, currentModule)) } @@ -258,7 +281,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe val ns = namespaceMap.getOrElseUpdate(currentModuleName, Namespace(iGraph.moduleMap(currentModuleName))) val instMap = instMaps(OfModule(currentModuleName)) s match { - case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) => + case wDef @ WDefInstance(_, instName, modName, _) + if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) => val renames = renamesMap(OfModule(currentModuleName) -> Instance(instName)) val toInline = iGraph.moduleMap(modName) match { case m: ExtModule => throw new PassException(s"Cannot inline external module ${m.name}") @@ -269,7 +293,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe val bodyx = { val module = currentModule.copy(module = modName) - cache.getOrElseUpdate(module, Block(ports :+ toInline.body) map onStmt(module)) + cache.getOrElseUpdate(module, Block(ports :+ toInline.body).map(onStmt(module))) } val names = "" +: Uniquify @@ -294,14 +318,14 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe renamedBody case sx => sx - .map(fixupRefs(instMap, currentModule)) - .map(onStmt(currentModule)) + .map(fixupRefs(instMap, currentModule)) + .map(onStmt(currentModule)) } } val flatCircuit = c.copy(modules = c.modules.flatMap { case m if flatModules.contains(m.name) => None - case m => + case m => Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main))))) }) diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala index 8b7b733a..5d59e075 100644 --- a/src/main/scala/firrtl/passes/Legalize.scala +++ b/src/main/scala/firrtl/passes/Legalize.scala @@ -1,11 +1,11 @@ package firrtl.passes import firrtl.PrimOps._ -import firrtl.Utils.{BoolType, error, zero} +import firrtl.Utils.{error, zero, BoolType} import firrtl.ir._ import firrtl.options.Dependency import firrtl.transforms.ConstantPropagation -import firrtl.{Transform, bitWidth} +import firrtl.{bitWidth, Transform} import firrtl.Mappers._ // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts @@ -62,30 +62,31 @@ object Legalize extends Pass { } else { val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) val expr = t match { - case UIntType(_) => bits - case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) + case UIntType(_) => bits + case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) } Connect(c.info, c.loc, expr) } } - def run (c: Circuit): Circuit = { - def legalizeE(expr: Expression): Expression = expr map legalizeE match { - case prim: DoPrim => prim.op match { - case Shr => legalizeShiftRight(prim) - case Pad => legalizePad(prim) - case Bits | Head | Tail => legalizeBitExtract(prim) - case _ => prim - } + def run(c: Circuit): Circuit = { + def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match { + case prim: DoPrim => + prim.op match { + case Shr => legalizeShiftRight(prim) + case Pad => legalizePad(prim) + case Bits | Head | Tail => legalizeBitExtract(prim) + case _ => prim + } case e => e // respect pre-order traversal } - def legalizeS (s: Statement): Statement = { + def legalizeS(s: Statement): Statement = { val legalizedStmt = s match { case c: Connect => legalizeConnect(c) case _ => s } - legalizedStmt map legalizeS map legalizeE + legalizedStmt.map(legalizeS).map(legalizeE) } - c copy (modules = c.modules map (_ map legalizeS)) + c.copy(modules = c.modules.map(_.map(legalizeS))) } } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index ace4f3e8..ad608cec 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -3,8 +3,26 @@ package firrtl.passes import firrtl.analyses.{InstanceKeyGraph, SymbolTable} -import firrtl.annotations.{CircuitTarget, MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, ReferenceTarget} -import firrtl.{CircuitForm, CircuitState, DependencyAPIMigration, InstanceKind, Kind, MemKind, PortKind, RenameMap, Transform, UnknownForm, Utils} +import firrtl.annotations.{ + CircuitTarget, + MemoryInitAnnotation, + MemoryRandomInitAnnotation, + ModuleTarget, + ReferenceTarget +} +import firrtl.{ + CircuitForm, + CircuitState, + DependencyAPIMigration, + InstanceKind, + Kind, + MemKind, + PortKind, + RenameMap, + Transform, + UnknownForm, + Utils +} import firrtl.ir._ import firrtl.options.Dependency import firrtl.stage.TransformManager.TransformDependency @@ -20,18 +38,19 @@ import scala.collection.mutable object LowerTypes extends Transform with DependencyAPIMigration { override def prerequisites: Seq[TransformDependency] = Seq( Dependency(RemoveAccesses), // we require all SubAccess nodes to have been removed - Dependency(CheckTypes), // we require all types to be correct - Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType) - Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded + Dependency(CheckTypes), // we require all types to be correct + Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType) + Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded ) - override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty + override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty override def invalidates(a: Transform): Boolean = a match { case ResolveFlows => true // we generate UnknownFlow for now (could be fixed) - case _ => false + case _ => false } /** Delimiter used in lowering names */ val delim = "_" + /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name * @param e [[firrtl.ir.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]] * @return Lowered name of e @@ -39,8 +58,8 @@ object LowerTypes extends Transform with DependencyAPIMigration { */ def loweredName(e: Expression): String = e match { case e: Reference => e.name - case e: SubField => s"${loweredName(e.expr)}$delim${e.name}" - case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}" + case e: SubField => s"${loweredName(e.expr)}$delim${e.name}" + case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}" } def loweredName(s: Seq[String]): String = s.mkString(delim) @@ -48,7 +67,7 @@ object LowerTypes extends Transform with DependencyAPIMigration { // When memories are lowered to ground type, we have to fix the init annotation or error on it. val (memInitAnnos, otherAnnos) = state.annotations.partition { case _: MemoryRandomInitAnnotation => false - case _: MemoryInitAnnotation => true + case _: MemoryInitAnnotation => true case _ => false } val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule) @@ -61,14 +80,18 @@ object LowerTypes extends Transform with DependencyAPIMigration { val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3) // chain module renames in topological order - val moduleRenames = resultAndRenames.map{ case(m,r, _) => m.name -> r }.toMap + val moduleRenames = resultAndRenames.map { case (m, r, _) => m.name -> r }.toMap val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator - val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a,b) => a.andThen(b)) + val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a, b) => a.andThen(b)) state.copy(circuit = result, renames = Some(renames), annotations = newAnnos) } - private def onModule(c: CircuitTarget, m: DefModule, memoryInit: Seq[MemoryInitAnnotation]): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = { + private def onModule( + c: CircuitTarget, + m: DefModule, + memoryInit: Seq[MemoryInitAnnotation] + ): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = { val renameMap = RenameMap() val ref = c.module(m.name) @@ -86,26 +109,36 @@ object LowerTypes extends Transform with DependencyAPIMigration { } // We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names. - private def lowerPorts(ref: ModuleTarget, m: DefModule, renameMap: RenameMap): - (DefModule, Seq[(String, Seq[Reference])]) = { + private def lowerPorts( + ref: ModuleTarget, + m: DefModule, + renameMap: RenameMap + ): (DefModule, Seq[(String, Seq[Reference])]) = { val namespace = mutable.HashSet[String]() ++ m.ports.map(_.name) val loweredPortsAndRefs = m.ports.flatMap { p => - val fieldsAndRefs = DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set()) - fieldsAndRefs.map { case (f, ref) => - (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind))) + val fieldsAndRefs = + DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set()) + fieldsAndRefs.map { + case (f, ref) => + (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind))) } } val newM = m match { - case e : ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1)) - case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1)) + case e: ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1)) + case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1)) } (newM, loweredPortsAndRefs.map(_._2)) } - private def onStatement(s: Statement)(implicit symbols: LoweringTable, memInit: Seq[MemoryInitAnnotation]): Statement = s match { + private def onStatement( + s: Statement + )( + implicit symbols: LoweringTable, + memInit: Seq[MemoryInitAnnotation] + ): Statement = s match { // declarations - case d : DefWire => - Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name=name, tpe=tpe) }) + case d: DefWire => + Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name = name, tpe = tpe) }) case d @ DefRegister(info, _, _, clock, reset, _) => // clock and reset are always of ground type val loweredClock = onExpression(clock) @@ -113,41 +146,41 @@ object LowerTypes extends Transform with DependencyAPIMigration { // It is important to first lower the declaration, because the reset can refer to the register itself! val loweredRegs = symbols.lower(d.name, d.tpe, firrtl.RegKind) val inits = Utils.create_exps(d.init).map(onExpression) - Block( - loweredRegs.zip(inits).map { case ((name, tpe, _), init) => + Block(loweredRegs.zip(inits).map { + case ((name, tpe, _), init) => DefRegister(info, name, tpe, loweredClock, loweredReset, init) }) - case d : DefNode => + case d: DefNode => val values = Utils.create_exps(d.value).map(onExpression) - Block( - symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map{ case((name, tpe, _), value) => + Block(symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map { + case ((name, tpe, _), value) => assert(tpe == value.tpe) DefNode(d.info, name, value) }) - case d : DefMemory => + case d: DefMemory => // TODO: as an optimization, we could just skip ground type memories here. // This would require that we don't error in getReferences() but instead return the old reference. val mems = symbols.lower(d) - if(mems.length > 1 && memInit.exists(_.target.ref == d.name)) { + if (mems.length > 1 && memInit.exists(_.target.ref == d.name)) { val mod = memInit.find(_.target.ref == d.name).get.target.encapsulatingModule val msg = s"[module $mod] Cannot initialize memory ${d.name} of non ground type ${d.dataType.serialize}" throw new RuntimeException(msg) } Block(mems) - case d : DefInstance => symbols.lower(d) + case d: DefInstance => symbols.lower(d) // connections case Connect(info, loc, expr) => - if(!expr.tpe.isInstanceOf[GroundType]) { + if (!expr.tpe.isInstanceOf[GroundType]) { throw new RuntimeException(s"LowerTypes expects Connects to have been expanded! ${expr.tpe.serialize}") } val rhs = onExpression(expr) // We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated. val lhs = symbols.getReferences(loc.asInstanceOf[RefLikeExpression]) Block(lhs.map(loc => Connect(info, loc, rhs))) - case p : PartialConnect => + case p: PartialConnect => throw new RuntimeException(s"LowerTypes expects PartialConnects to be resolved! $p") case IsInvalid(info, expr) => - if(!expr.tpe.isInstanceOf[GroundType]) { + if (!expr.tpe.isInstanceOf[GroundType]) { throw new RuntimeException(s"LowerTypes expects IsInvalids to have been expanded! ${expr.tpe.serialize}") } // We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated. @@ -172,15 +205,18 @@ object LowerTypes extends Transform with DependencyAPIMigration { // Holds the first level of the module-level namespace. // (i.e. everything that can be addressed directly by a Reference node) private class LoweringSymbolTable extends SymbolTable { - def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name) + def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name) def declareInstance(name: String, module: String): Unit = symbols.append(name) private val symbols = mutable.ArrayBuffer[String]() def getSymbolNames: Iterable[String] = symbols } // Lowers types and keeps track of references to lowered types. -private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: ModuleTarget, - portNameToExprs: Seq[(String, Seq[Reference])]) { +private class LoweringTable( + table: LoweringSymbolTable, + renameMap: RenameMap, + m: ModuleTarget, + portNameToExprs: Seq[(String, Seq[Reference])]) { private val portNames: Set[String] = portNameToExprs.map(_._2.head.name).toSet private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames // Serialized old access string to new ground type reference. @@ -196,10 +232,11 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: nameToExprs ++= refs.map { case (name, r) => name -> List(r) } newInst } + /** used to lower nodes, registers and wires */ def lower(name: String, tpe: Type, kind: Kind, flip: Orientation = Default): Seq[(String, Type, Orientation)] = { val fieldsAndRefs = DestructTypes.destruct(m, Field(name, flip, tpe), namespace, renameMap, portNames) - nameToExprs ++= fieldsAndRefs.map{ case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) } + nameToExprs ++= fieldsAndRefs.map { case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) } fieldsAndRefs.map { case (f, _) => (f.name, f.tpe, f.flip) } } def lower(p: Port): Seq[Port] = { @@ -211,10 +248,10 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: // We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left. private def serialize(expr: RefLikeExpression): String = expr match { - case Reference(name, _, _, _) => name - case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name + case Reference(name, _, _, _) => name + case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name case SubIndex(expr, index, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "[" + index.toString + "]" - case a : SubAccess => + case a: SubAccess => throw new RuntimeException(s"LowerTypes expects all SubAccesses to have been expanded! ${a.serialize}") } } @@ -230,13 +267,18 @@ private object DestructTypes { * - generates a list of all old reference name that now refer to the particular ground type field * - updates namespace with all possibly conflicting names */ - def destruct(m: ModuleTarget, ref: Field, namespace: Namespace, renameMap: RenameMap, reserved: Set[String]): - Seq[(Field, String)] = { + def destruct( + m: ModuleTarget, + ref: Field, + namespace: Namespace, + renameMap: RenameMap, + reserved: Set[String] + ): Seq[(Field, String)] = { // field renames (uniquify) are computed bottom up val (rename, _) = uniquify(ref, namespace, reserved) // early exit for ground types that do not need renaming - if(ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) { + if (ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) { return List((ref, ref.name)) } @@ -253,8 +295,13 @@ private object DestructTypes { * Note that the list of fields is only of the child fields, and needs a SubField node * instead of a flat Reference when turning them into access expressions. */ - def destructInstance(m: ModuleTarget, instance: DefInstance, namespace: Namespace, renameMap: RenameMap, - reserved: Set[String]): (DefInstance, Seq[(String, SubField)]) = { + def destructInstance( + m: ModuleTarget, + instance: DefInstance, + namespace: Namespace, + renameMap: RenameMap, + reserved: Set[String] + ): (DefInstance, Seq[(String, SubField)]) = { val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved) val newName = rename.map(_.name).getOrElse(instance.name) @@ -266,14 +313,14 @@ private object DestructTypes { } // rename all references to the instance if necessary - if(newName != instance.name) { + if (newName != instance.name) { renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module)) } // The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered. val newInstance = instance.copy(name = newName, tpe = BundleType(children.map(_._1))) val instanceRef = Reference(newName, newInstance.tpe, InstanceKind) - val refs = children.map{ case(c,r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) } + val refs = children.map { case (c, r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) } (newInstance, refs) } @@ -285,8 +332,13 @@ private object DestructTypes { * e.g. ("mem_a.r.clk", "mem.r.clk") and ("mem_b.r.clk", "mem.r.clk") * Thus it is appropriate to groupBy old reference string instead of just inserting into a hash table. */ - def destructMemory(m: ModuleTarget, mem: DefMemory, namespace: Namespace, renameMap: RenameMap, - reserved: Set[String]): (Seq[DefMemory], Seq[(String, SubField)]) = { + def destructMemory( + m: ModuleTarget, + mem: DefMemory, + namespace: Namespace, + renameMap: RenameMap, + reserved: Set[String] + ): (Seq[DefMemory], Seq[(String, SubField)]) = { // Uniquify the lowered memory names: When memories get split up into ground types, the access order is changes. // E.g. `mem.r.data.x` becomes `mem_x.r.data`. // This is why we need to create the new bundle structure before we can resolve any name clashes. @@ -301,48 +353,50 @@ private object DestructTypes { // the "old dummy field" is used as a template for the new memory port types val oldDummyField = Field("dummy", Default, MemPortUtils.memType(mem.copy(dataType = BoolType))) - val newMemAndSubFields = res.map { case (field, refs) => - val newMem = mem.copy(name = field.name, dataType = field.tpe) - val newMemRef = m.ref(field.name) - val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType - if(memWasRenamed) { renameMap.record(oldMemRef, newMemRef) } - - val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind) - val refSuffixes = refs.map(_.component).filterNot(_.isEmpty) - - val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port => - val oldPortRef = oldMemRef.field(port.name) - val newPortRef = newMemRef.field(port.name) - - val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe - val newPortAccess = SubField(newMemReference, port.name, newPortType) - - port.tpe.asInstanceOf[BundleType].fields.map { portField => - val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata" - val isMaskField = portField.name == "mask" || portField.name == "wmask" - val isDataOrMaskField = isDataField || isMaskField - val oldFieldRefs = if(memWasRenamed && isDataOrMaskField) { - // there might have been multiple different fields which now alias to the same lowered field. - val oldPortFieldBaseRef = oldPortRef.field(portField.name) - refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s)) - } else { - List(oldPortRef.field(portField.name)) + val newMemAndSubFields = res.map { + case (field, refs) => + val newMem = mem.copy(name = field.name, dataType = field.tpe) + val newMemRef = m.ref(field.name) + val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType + if (memWasRenamed) { renameMap.record(oldMemRef, newMemRef) } + + val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind) + val refSuffixes = refs.map(_.component).filterNot(_.isEmpty) + + val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port => + val oldPortRef = oldMemRef.field(port.name) + val newPortRef = newMemRef.field(port.name) + + val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe + val newPortAccess = SubField(newMemReference, port.name, newPortType) + + port.tpe.asInstanceOf[BundleType].fields.map { portField => + val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata" + val isMaskField = portField.name == "mask" || portField.name == "wmask" + val isDataOrMaskField = isDataField || isMaskField + val oldFieldRefs = if (memWasRenamed && isDataOrMaskField) { + // there might have been multiple different fields which now alias to the same lowered field. + val oldPortFieldBaseRef = oldPortRef.field(portField.name) + refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s)) + } else { + List(oldPortRef.field(portField.name)) + } + + val newPortType = if (isDataField) { newMem.dataType } + else { portField.tpe } + val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType) + + // record renames only for the data field which is the only port field of non-ground type + val newPortFieldRef = newPortRef.field(portField.name) + if (memWasRenamed && isDataOrMaskField) { + oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) } + } + + val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs) + (oldFieldStringRef, newPortFieldAccess) } - - val newPortType = if(isDataField) { newMem.dataType } else { portField.tpe } - val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType) - - // record renames only for the data field which is the only port field of non-ground type - val newPortFieldRef = newPortRef.field(portField.name) - if(memWasRenamed && isDataOrMaskField) { - oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) } - } - - val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs) - (oldFieldStringRef, newPortFieldAccess) } - } - (newMem, subFields) + (newMem, subFields) } (newMemAndSubFields.map(_._1), newMemAndSubFields.flatMap(_._2)) @@ -356,22 +410,30 @@ private object DestructTypes { Field(mem.name, Default, BundleType(fields)) } - private def recordRenames(fieldToRefs: Seq[(Field, Seq[ReferenceTarget])], renameMap: RenameMap, parent: ParentRef): - Unit = { + private def recordRenames( + fieldToRefs: Seq[(Field, Seq[ReferenceTarget])], + renameMap: RenameMap, + parent: ParentRef + ): Unit = { // TODO: if we group by ReferenceTarget, we could reduce the number of calls to `record`. Is it worth it? - fieldToRefs.foreach { case(field, refs) => - val fieldRef = parent.ref(field.name) - refs.foreach{ r => renameMap.record(r, fieldRef) } + fieldToRefs.foreach { + case (field, refs) => + val fieldRef = parent.ref(field.name) + refs.foreach { r => renameMap.record(r, fieldRef) } } } private def extractGroundTypeRefString(refs: Seq[ReferenceTarget]): String = { - if (refs.isEmpty) { "" } else { + if (refs.isEmpty) { "" } + else { // Since we depend on ExpandConnects any reference we encounter will be of ground type // and thus the one with the longest access path. - refs.reduceLeft((x, y) => if (x.component.length > y.component.length) x else y) + refs + .reduceLeft((x, y) => if (x.component.length > y.component.length) x else y) // convert references to strings relative to the module - .serialize.dropWhile(_ != '>').tail + .serialize + .dropWhile(_ != '>') + .tail } } @@ -385,14 +447,19 @@ private object DestructTypes { * @return a sequence of ground type fields with new names and, for each field, * a sequence of old references that should to be renamed to point to the particular field */ - private def destruct(prefix: String, oldParent: ParentRef, oldField: Field, - isVecField: Boolean, rename: Option[RenameNode]): Seq[(Field, Seq[ReferenceTarget])] = { + private def destruct( + prefix: String, + oldParent: ParentRef, + oldField: Field, + isVecField: Boolean, + rename: Option[RenameNode] + ): Seq[(Field, Seq[ReferenceTarget])] = { val newName = rename.map(_.name).getOrElse(oldField.name) val oldRef = oldParent.ref(oldField.name, isVecField) oldField.tpe match { - case _ : GroundType => List((oldField.copy(name = prefix + newName), List(oldRef))) - case _ : BundleType | _ : VectorType => + case _: GroundType => List((oldField.copy(name = prefix + newName), List(oldRef))) + case _: BundleType | _: VectorType => val newPrefix = prefix + newName + LowerTypes.delim val isVecField = oldField.tpe.isInstanceOf[VectorType] val fields = getFields(oldField.tpe) @@ -401,7 +468,7 @@ private object DestructTypes { destruct(newPrefix, RefParentRef(oldRef), f, isVecField, rename.flatMap(_.children.get(f.name))) } // the bundle/vec reference refers to all children - children.map{ case(c, r) => (c, r :+ oldRef) } + children.map { case (c, r) => (c, r :+ oldRef) } } } @@ -409,7 +476,8 @@ private object DestructTypes { /** Implements the core functionality of the old Uniquify pass: rename bundle fields and top-level references * where necessary in order to avoid name clashes when lowering aggregate type with the `_` delimiter. - * We don't actually do the rename here but just calculate a rename tree. */ + * We don't actually do the rename here but just calculate a rename tree. + */ private def uniquify(ref: Field, namespace: Namespace, reserved: Set[String]): (Option[RenameNode], Seq[String]) = { // ensure that there are no name clashes with the list of reserved (port) names val newRefName = findValidPrefix(ref.name, reserved.contains) @@ -426,23 +494,23 @@ private object DestructTypes { // We added f.name in previous map, delete if we change it val renamed = prefix != ref.name if (renamed) { - if(!reserved.contains(ref.name)) namespace -= ref.name + if (!reserved.contains(ref.name)) namespace -= ref.name namespace += prefix } val suffixes = renamedFieldNames.map(f => prefix + LowerTypes.delim + f) val anyChildRenamed = renamedFields.exists(_._1.isDefined) - val rename = if(renamed || anyChildRenamed){ - val children = renamedFields.map(_._1).zip(fields).collect{ case (Some(r), f) => f.name -> r }.toMap + val rename = if (renamed || anyChildRenamed) { + val children = renamedFields.map(_._1).zip(fields).collect { case (Some(r), f) => f.name -> r }.toMap Some(RenameNode(prefix, children)) } else { None } (rename, suffixes :+ prefix) - case v : VectorType=> + case v: VectorType => // if Vecs are to be lowered, we can just treat them like a bundle uniquify(ref.copy(tpe = vecToBundle(v)), namespace, reserved) - case _ : GroundType => - if(newRefName == ref.name) { + case _: GroundType => + if (newRefName == ref.name) { (None, List(ref.name)) } else { (Some(RenameNode(newRefName, Map())), List(newRefName)) @@ -452,22 +520,23 @@ private object DestructTypes { } /** Appends delim to prefix until no collisions of prefix + elts in names We don't add an _ in the collision check - * because elts could be Seq("") In this case, we're just really checking if prefix itself collides */ + * because elts could be Seq("") In this case, we're just really checking if prefix itself collides + */ @tailrec private def findValidPrefix(prefix: String, inNamespace: String => Boolean, elts: Seq[String] = List("")): String = { elts.find(elt => inNamespace(prefix + elt)) match { case Some(_) => findValidPrefix(prefix + "_", inNamespace, elts) - case None => prefix + case None => prefix } } private def getFields(tpe: Type): Seq[Field] = tpe match { case BundleType(fields) => fields - case v : VectorType => vecToBundle(v).fields + case v: VectorType => vecToBundle(v).fields } private def vecToBundle(v: VectorType): BundleType = { - BundleType(( 0 until v.size).map(i => Field(i.toString, Default, v.tpe))) + BundleType((0 until v.size).map(i => Field(i.toString, Default, v.tpe))) } /** Used to abstract over module and reference parents. @@ -480,6 +549,7 @@ private object DestructTypes { } private case class RefParentRef(r: ReferenceTarget) extends ParentRef { override def ref(name: String, asVecField: Boolean): ReferenceTarget = - if(asVecField) { r.index(name.toInt) } else { r.field(name) } + if (asVecField) { r.index(name.toInt) } + else { r.field(name) } } } diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index ca5c2544..79560605 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -15,23 +15,21 @@ object PadWidths extends Pass { override def prerequisites = ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf)).toSeq override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation]) override def optionalPrerequisiteOf = - Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { case _: firrtl.transforms.ConstantPropagation | Legalize => true case _ => false } - private def width(t: Type): Int = bitWidth(t).toInt + private def width(t: Type): Int = bitWidth(t).toInt private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width private def fixup(i: Int)(e: Expression) = { @@ -54,31 +52,31 @@ object PadWidths extends Pass { } // Recursive, updates expression so children exp's have correct widths - private def onExp(e: Expression): Expression = e map onExp match { + private def onExp(e: Expression): Expression = e.map(onExp) match { case Mux(cond, tval, fval, tpe) => Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe) - case ex: ValidIf => ex copy (value = fixup(width(ex.tpe))(ex.value)) - case ex: DoPrim => ex.op match { - case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | - Add | Sub | Mul | Div | Rem | Shr => - // sensitive ops - ex map fixup((ex.args map width foldLeft 0)(math.max)) - case Dshl => - // special case as args aren't all same width - ex copy (op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) - case _ => ex - } + case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value)) + case ex: DoPrim => + ex.op match { + case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Mul | Div | Rem | Shr => + // sensitive ops + ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max))) + case Dshl => + // special case as args aren't all same width + ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1))) + case _ => ex + } case ex => ex } // Recursive. Fixes assignments and register initialization widths - private def onStmt(s: Statement): Statement = s map onExp match { + private def onStmt(s: Statement): Statement = s.map(onExp) match { case sx: Connect => - sx copy (expr = fixup(width(sx.loc))(sx.expr)) + sx.copy(expr = fixup(width(sx.loc))(sx.expr)) case sx: DefRegister => - sx copy (init = fixup(width(sx.tpe))(sx.init)) - case sx => sx map onStmt + sx.copy(init = fixup(width(sx.tpe))(sx.init)) + case sx => sx.map(onStmt) } - def run(c: Circuit): Circuit = c copy (modules = c.modules map (_ map onStmt)) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(_.map(onStmt))) } diff --git a/src/main/scala/firrtl/passes/Pass.scala b/src/main/scala/firrtl/passes/Pass.scala index 036bd06a..b5eac4ed 100644 --- a/src/main/scala/firrtl/passes/Pass.scala +++ b/src/main/scala/firrtl/passes/Pass.scala @@ -8,7 +8,7 @@ import firrtl.{CircuitState, FirrtlUserException, Transform} * Has an [[UnknownForm]], because larger [[Transform]] should specify form */ trait Pass extends Transform with DependencyAPIMigration { - def run(c: Circuit): Circuit + def run(c: Circuit): Circuit def execute(state: CircuitState): CircuitState = state.copy(circuit = run(state.circuit)) } diff --git a/src/main/scala/firrtl/passes/PullMuxes.scala b/src/main/scala/firrtl/passes/PullMuxes.scala index b805b5fc..27543d63 100644 --- a/src/main/scala/firrtl/passes/PullMuxes.scala +++ b/src/main/scala/firrtl/passes/PullMuxes.scala @@ -11,38 +11,50 @@ object PullMuxes extends Pass { override def invalidates(a: Transform) = false def run(c: Circuit): Circuit = { - def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { - case ex: WSubField => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubField(exx.tval, ex.name, ex.tpe, ex.flow), - WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubIndex => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow), - WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubAccess => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow), - WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex => ex - } - def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e - val modulesx = c.modules.map { - case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body)) - case (m:ExtModule) => m - } - Circuit(c.info, modulesx, c.main) - } + def pull_muxes_e(e: Expression): Expression = e.map(pull_muxes_e) match { + case ex: WSubField => + ex.expr match { + case exx: Mux => + Mux( + exx.cond, + WSubField(exx.tval, ex.name, ex.tpe, ex.flow), + WSubField(exx.fval, ex.name, ex.tpe, ex.flow), + ex.tpe + ) + case exx: ValidIf => ValidIf(exx.cond, WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex: WSubIndex => + ex.expr match { + case exx: Mux => + Mux( + exx.cond, + WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow), + WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), + ex.tpe + ) + case exx: ValidIf => ValidIf(exx.cond, WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex: WSubAccess => + ex.expr match { + case exx: Mux => + Mux( + exx.cond, + WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow), + WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), + ex.tpe + ) + case exx: ValidIf => ValidIf(exx.cond, WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex => ex + } + def pull_muxes(s: Statement): Statement = s.map(pull_muxes).map(pull_muxes_e) + val modulesx = c.modules.map { + case (m: Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body)) + case (m: ExtModule) => m + } + Circuit(c.info, modulesx, c.main) + } } diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index 18db5939..015346ff 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -2,7 +2,7 @@ package firrtl.passes -import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubIndex, WSubField} +import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubField, WSubIndex} import firrtl.PrimOps.{And, Eq} import firrtl.ir._ import firrtl.Mappers._ @@ -17,10 +17,12 @@ import scala.collection.mutable object RemoveAccesses extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ZeroLengthVecs), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ZeroLengthVecs), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects) + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { case Uniquify | ResolveKinds | ResolveFlows => true @@ -28,8 +30,8 @@ object RemoveAccesses extends Pass { } private def AND(e1: Expression, e2: Expression) = - if(e1 == one) e2 - else if(e2 == one) e1 + if (e1 == one) e2 + else if (e2 == one) e1 else DoPrim(And, Seq(e1, e2), Nil, BoolType) private def EQV(e1: Expression, e2: Expression): Expression = @@ -45,30 +47,35 @@ object RemoveAccesses extends Pass { * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) */ private def getLocations(e: Expression): Seq[Location] = e match { - case e: WRef => create_exps(e).map(Location(_,one)) + case e: WRef => create_exps(e).map(Location(_, one)) case e: WSubIndex => val ls = getLocations(e.expr) val start = get_point(e) val end = start + get_size(e.tpe) val stride = get_size(e.expr.tpe) - for ((l, i) <- ls.zipWithIndex - if ((i % stride) >= start) & ((i % stride) < end)) yield l + for ( + (l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end) + ) yield l case e: WSubField => val ls = getLocations(e.expr) val start = get_point(e) val end = start + get_size(e.tpe) val stride = get_size(e.expr.tpe) - for ((l, i) <- ls.zipWithIndex - if ((i % stride) >= start) & ((i % stride) < end)) yield l + for ( + (l, i) <- ls.zipWithIndex + if ((i % stride) >= start) & ((i % stride) < end) + ) yield l case e: WSubAccess => val ls = getLocations(e.expr) val stride = get_size(e.tpe) val wrap = e.expr.tpe.asInstanceOf[VectorType].size - ls.zipWithIndex map {case (l, i) => - val c = (i / stride) % wrap - val basex = l.base - val guardx = AND(l.guard,EQV(UIntLiteral(c),e.index)) - Location(basex,guardx) + ls.zipWithIndex.map { + case (l, i) => + val c = (i / stride) % wrap + val basex = l.base + val guardx = AND(l.guard, EQV(UIntLiteral(c), e.index)) + Location(basex, guardx) } } @@ -78,10 +85,10 @@ object RemoveAccesses extends Pass { var ret: Boolean = false def rec_has_access(e: Expression): Expression = { e match { - case _ : WSubAccess => ret = true + case _: WSubAccess => ret = true case _ => } - e map rec_has_access + e.map(rec_has_access) } rec_has_access(e) ret @@ -90,7 +97,7 @@ object RemoveAccesses extends Pass { // This improves the performance of this pass private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]() private def create_exps(e: Expression) = - createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e)) + createExpsCache.getOrElseUpdate(e, firrtl.Utils.create_exps(e)) def run(c: Circuit): Circuit = { def remove_m(m: Module): Module = { @@ -105,21 +112,21 @@ object RemoveAccesses extends Pass { */ val stmts = mutable.ArrayBuffer[Statement]() def removeSource(e: Expression): Expression = e match { - case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(e) => + case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) => val rs = getLocations(e) - rs find (x => x.guard != one) match { + rs.find(x => x.guard != one) match { case None => throwInternalError(s"removeSource: shouldn't be here - $e") case Some(_) => val (wire, temp) = create_temp(e) val temps = create_exps(temp) def getTemp(i: Int) = temps(i % temps.size) stmts += wire - rs.zipWithIndex foreach { + rs.zipWithIndex.foreach { case (x, i) if i < temps.size => - stmts += IsInvalid(get_info(s),getTemp(i)) - stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) + stmts += IsInvalid(get_info(s), getTemp(i)) + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) case (x, i) => - stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) + stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt) } temp } @@ -129,14 +136,16 @@ object RemoveAccesses extends Pass { /** Replaces a subaccess in a given sink expression */ def removeSink(info: Info, loc: Expression): Expression = loc match { - case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(loc) => + case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) => val ls = getLocations(loc) - if (ls.size == 1 & weq(ls.head.guard,one)) loc + if (ls.size == 1 & weq(ls.head.guard, one)) loc else { val (wire, temp) = create_temp(loc) stmts += wire - ls foreach (x => stmts += - Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt)) + ls.foreach(x => + stmts += + Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt) + ) temp } case _ => loc @@ -150,7 +159,7 @@ object RemoveAccesses extends Pass { case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow)) //case w: WSubIndex => removeSource(w) //case w: WSubField => removeSource(w) - case x => x map fixSource + case x => x.map(fixSource) } /** Recursively walks a sink expression and fixes all subaccesses @@ -159,13 +168,13 @@ object RemoveAccesses extends Pass { */ def fixSink(e: Expression): Expression = e match { case w: WSubAccess => WSubAccess(fixSink(w.expr), fixSource(w.index), w.tpe, w.flow) - case x => x map fixSink + case x => x.map(fixSink) } val sx = s match { case Connect(info, loc, exp) => Connect(info, removeSink(info, fixSink(loc)), fixSource(exp)) - case sxx => sxx map fixSource map onStmt + case sxx => sxx.map(fixSource).map(onStmt) } stmts += sx if (stmts.size != 1) Block(stmts.toSeq) else stmts(0) @@ -173,9 +182,9 @@ object RemoveAccesses extends Pass { Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body))) } - c copy (modules = c.modules map { + c.copy(modules = c.modules.map { case m: ExtModule => m - case m: Module => remove_m(m) + case m: Module => remove_m(m) }) } } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 61fd6258..624138ab 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -17,8 +17,7 @@ case class DataRef(exp: Expression, source: String, sink: String, mask: String, object RemoveCHIRRTL extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.ChirrtlForm ++ - Seq( Dependency(passes.CInferTypes), - Dependency(passes.CInferMDir) ) + Seq(Dependency(passes.CInferTypes), Dependency(passes.CInferMDir)) override def invalidates(a: Transform) = false @@ -31,10 +30,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match { case _: GroundType => Seq(ex) - case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex) - case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex) + case t: BundleType => + (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq( + ex + ) + case t: VectorType => + ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) => + exps ++ create_all_exps(SubIndex(ex, i, t.tpe)) + ) ++ Seq(ex) case UnknownType => Seq(ex) } @@ -42,17 +45,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { case ex: Mux => val e1s = create_exps(ex.tval) val e2s = create_exps(ex.fval) - (e1s zip e2s) map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) } + (e1s.zip(e2s)).map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) } case ex: ValidIf => - create_exps(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe)) - case ex => ex.tpe match { - case _: GroundType => Seq(ex) - case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_exps(SubField(ex, f.name, f.tpe))) - case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_exps(SubIndex(ex, i, t.tpe))) - case UnknownType => Seq(ex) - } + create_exps(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe)) + case ex => + ex.tpe match { + case _: GroundType => Seq(ex) + case t: BundleType => + (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_exps(SubField(ex, f.name, f.tpe))) + case t: VectorType => + ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) => exps ++ create_exps(SubIndex(ex, i, t.tpe))) + case UnknownType => Seq(ex) + } } private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]()) @@ -61,40 +65,48 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { s match { case sx: CDefMemory if sx.seq => smems += sx.name case sx: CDefMPort => - val p = mports getOrElse (sx.mem, EMPs) + val p = mports.getOrElse(sx.mem, EMPs) sx.direction match { - case MRead => p.readers += MPort(sx.name, sx.exps(1)) - case MWrite => p.writers += MPort(sx.name, sx.exps(1)) + case MRead => p.readers += MPort(sx.name, sx.exps(1)) + case MWrite => p.writers += MPort(sx.name, sx.exps(1)) case MReadWrite => p.readwriters += MPort(sx.name, sx.exps(1)) - case MInfer => // direction may not be inferred if it's not being used + case MInfer => // direction may not be inferred if it's not being used } mports(sx.mem) = p case _ => } - s map collect_smems_and_mports(mports, smems) + s.map(collect_smems_and_mports(mports, smems)) } - def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap, - refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match { + def collect_refs( + mports: MPortMap, + smems: SeqMemSet, + types: MPortTypeMap, + refs: DataRefMap, + raddrs: AddrMap, + renames: RenameMap + )(s: Statement + ): Statement = s match { case sx: CDefMemory => types(sx.name) = sx.tpe - val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1))) + val taddr = UIntType(IntWidth(1.max(getUIntWidth(sx.size - 1)))) val tdata = sx.tpe - def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap (r => Seq( - IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)), - IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType)) - )) - def set_enable(vec: scala.collection.Seq[MPort], en: String) = vec.toSeq.map (r => - Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero) + def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap(r => + Seq( + IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)), + IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType)) + ) ) + def set_enable(vec: scala.collection.Seq[MPort], en: String) = + vec.toSeq.map(r => Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero)) def set_write(vec: scala.collection.Seq[MPort], data: String, mask: String) = vec.toSeq.flatMap { r => val tmask = createMask(sx.tpe) val portRef = SubField(Reference(sx.name, ut), r.name, ut) Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask))) } - val rds = (mports getOrElse (sx.name, EMPs)).readers - val wrs = (mports getOrElse (sx.name, EMPs)).writers - val rws = (mports getOrElse (sx.name, EMPs)).readwriters + val rds = (mports.getOrElse(sx.name, EMPs)).readers + val wrs = (mports.getOrElse(sx.name, EMPs)).writers + val rws = (mports.getOrElse(sx.name, EMPs)).readwriters val stmts = set_poison(rds) ++ set_enable(rds, "en") ++ set_poison(wrs) ++ @@ -104,8 +116,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { set_enable(rws, "wmode") ++ set_enable(rws, "en") ++ set_write(rws, "wdata", "wmask") - val mem = DefMemory(sx.info, sx.name, sx.tpe, sx.size, 1, if (sx.seq) 1 else 0, - rds.map(_.name).toSeq, wrs.map(_.name).toSeq, rws.map(_.name).toSeq, sx.readUnderWrite) + val mem = DefMemory( + sx.info, + sx.name, + sx.tpe, + sx.size, + 1, + if (sx.seq) 1 else 0, + rds.map(_.name).toSeq, + wrs.map(_.name).toSeq, + rws.map(_.name).toSeq, + sx.readUnderWrite + ) Block(mem +: stmts) case sx: CDefMPort => types.get(sx.mem) match { @@ -130,8 +152,8 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val es = create_all_exps(WRef(sx.name, sx.tpe)) val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe)) val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe)) - ((es zip rs) zip ws) map { - case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize)) + ((es.zip(rs)).zip(ws)).map { + case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize)) } case MWrite => refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false) @@ -142,7 +164,7 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") val es = create_all_exps(WRef(sx.name, sx.tpe)) val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) - (es zip ws) map { + (es.zip(ws)).map { case (e, w) => renames.rename(e.serialize, w.serialize) } case MRead => @@ -157,63 +179,69 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") val es = create_all_exps(WRef(sx.name, sx.tpe)) val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) - (es zip rs) map { + (es.zip(rs)).map { case (e, r) => renames.rename(e.serialize, r.serialize) } case MInfer => // do nothing if it's not being used } - Block(List() ++ - (addrs.map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++ - (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++ - (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++ - masks.map(lhs => Connect(sx.info, lhs, zero)) + Block( + List() ++ + (addrs.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++ + (clks.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++ + (ens.map(x => Connect(sx.info, SubField(portRef, x, ut), one))) ++ + masks.map(lhs => Connect(sx.info, lhs, zero)) ) - case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames) + case sx => sx.map(collect_refs(mports, smems, types, refs, raddrs, renames)) } def get_mask(refs: DataRefMap)(e: Expression): Expression = - e map get_mask(refs) match { - case ex: Reference => refs get ex.name match { - case None => ex - case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe)) - } + e.map(get_mask(refs)) match { + case ex: Reference => + refs.get(ex.name) match { + case None => ex + case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe)) + } case ex => ex } def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = { var has_write_mport = false var has_readwrite_mport: Option[Expression] = None - var has_read_mport: Option[Expression] = None + var has_read_mport: Option[Expression] = None def remove_chirrtl_e(g: Flow)(e: Expression): Expression = e match { - case Reference(name, tpe, _, _) => refs get name match { - case Some(p) => g match { - case SinkFlow => - has_write_mport = true - if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType)) - SubField(p.exp, p.sink, tpe) - case SourceFlow => - SubField(p.exp, p.source, tpe) - } - case None => g match { - case SinkFlow => raddrs get name match { - case Some(en) => has_read_mport = Some(en) ; e - case None => e - } - case SourceFlow => e + case Reference(name, tpe, _, _) => + refs.get(name) match { + case Some(p) => + g match { + case SinkFlow => + has_write_mport = true + if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType)) + SubField(p.exp, p.sink, tpe) + case SourceFlow => + SubField(p.exp, p.source, tpe) + } + case None => + g match { + case SinkFlow => + raddrs.get(name) match { + case Some(en) => has_read_mport = Some(en); e + case None => e + } + case SourceFlow => e + } } - } - case SubAccess(expr, index, tpe, _) => SubAccess( - remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe) - case ex => ex map remove_chirrtl_e(g) - } - s match { + case SubAccess(expr, index, tpe, _) => + SubAccess(remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe) + case ex => ex.map(remove_chirrtl_e(g)) + } + s match { case DefNode(info, name, value) => val valuex = remove_chirrtl_e(SourceFlow)(value) val sx = DefNode(info, name, valuex) // Check node is used for read port address remove_chirrtl_e(SinkFlow)(Reference(name, value.tpe)) has_read_mport match { - case None => sx + case None => sx case Some(en) => Block(sx, Connect(info, en, one)) } case Connect(info, loc, expr) => @@ -222,14 +250,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val sx = Connect(info, locx, rocx) val stmts = ArrayBuffer[Statement]() has_read_mport match { - case None => + case None => case Some(en) => stmts += Connect(info, en, one) } if (has_write_mport) { val locs = create_exps(get_mask(refs)(loc)) - stmts ++= (locs map (x => Connect(info, x, one))) + stmts ++= (locs.map(x => Connect(info, x, one))) has_readwrite_mport match { - case None => + case None => case Some(wmode) => stmts += Connect(info, wmode, one) } } @@ -240,20 +268,20 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val sx = PartialConnect(info, locx, rocx) val stmts = ArrayBuffer[Statement]() has_read_mport match { - case None => + case None => case Some(en) => stmts += Connect(info, en, one) } if (has_write_mport) { val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default) val locs = create_exps(get_mask(refs)(loc)) - stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) }) + stmts ++= (ls.map { case (x, _) => Connect(info, locs(x), one) }) has_readwrite_mport match { - case None => + case None => case Some(wmode) => stmts += Connect(info, wmode, one) } } if (stmts.isEmpty) sx else Block(sx +: stmts.toSeq) - case sx => sx map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(SourceFlow) + case sx => sx.map(remove_chirrtl_s(refs, raddrs)).map(remove_chirrtl_e(SourceFlow)) } } @@ -264,16 +292,16 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration { val refs = new DataRefMap val raddrs = new AddrMap renames.setModule(m.name) - (m map collect_smems_and_mports(mports, smems) - map collect_refs(mports, smems, types, refs, raddrs, renames) - map remove_chirrtl_s(refs, raddrs)) + (m.map(collect_smems_and_mports(mports, smems)) + .map(collect_refs(mports, smems, types, refs, raddrs, renames)) + .map(remove_chirrtl_s(refs, raddrs))) } def execute(state: CircuitState): CircuitState = { val c = state.circuit val renames = RenameMap() renames.setCircuit(c.main) - val result = c copy (modules = c.modules map remove_chirrtl_m(renames)) + val result = c.copy(modules = c.modules.map(remove_chirrtl_m(renames))) state.copy(circuit = result, renames = Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala index eabf667c..eb25dcc4 100644 --- a/src/main/scala/firrtl/passes/RemoveEmpty.scala +++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala @@ -15,7 +15,7 @@ object RemoveEmpty extends Pass with DependencyAPIMigration { private def onModule(m: DefModule): DefModule = { m match { - case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body)) + case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body)) case m: ExtModule => m } } diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala index 7059526c..657b4356 100644 --- a/src/main/scala/firrtl/passes/RemoveIntervals.scala +++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala @@ -13,14 +13,13 @@ import firrtl.options.Dependency import scala.math.BigDecimal.RoundingMode._ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) - extends PassException({ - val toWrap = wrap.args.head.serialize - val toWrapTpe = wrap.args.head.tpe.serialize - val wrapTo = wrap.args(1).serialize - val wrapToTpe = wrap.args(1).tpe.serialize - s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" - }) - + extends PassException({ + val toWrap = wrap.args.head.serialize + val toWrapTpe = wrap.args.head.tpe.serialize + val wrapTo = wrap.args(1).serialize + val wrapToTpe = wrap.args(1).tpe.serialize + s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe" + }) /** Replaces IntervalType with SIntType, three AST walks: * 1) Align binary points @@ -39,48 +38,50 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim) class RemoveIntervals extends Pass { override def prerequisites: Seq[Dependency[Transform]] = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses), - Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck] + ) ++ firrtl.stage.Forms.Deduped override def invalidates(transform: Transform): Boolean = { transform match { case InferTypes | ResolveKinds => true - case _ => false + case _ => false } } def run(c: Circuit): Circuit = { val alignedCircuit = c val errors = new Errors() - val wiredCircuit = alignedCircuit map makeWireModule - val replacedCircuit = wiredCircuit map replaceModuleInterval(errors) + val wiredCircuit = alignedCircuit.map(makeWireModule) + val replacedCircuit = wiredCircuit.map(replaceModuleInterval(errors)) errors.trigger() replacedCircuit } /* Replace interval types */ private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule = - m map replaceStmtInterval(errors, m.name) map replacePortInterval + m.map(replaceStmtInterval(errors, m.name)).map(replacePortInterval) private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = { val info = s match { case h: HasInfo => h.info case _ => NoInfo } - s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname) + s.map(replaceTypeInterval).map(replaceStmtInterval(errors, mname)).map(replaceExprInterval(errors, info, mname)) } private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match { case _: WRef | _: WSubIndex | _: WSubField => e case o => - o map replaceExprInterval(errors, info, mname) match { + o.map(replaceExprInterval(errors, info, mname)) match { case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe) - case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) - case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) + case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) + case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) => // Output interval (pre-calculated) val clipLo = tpe.minAdjusted.get @@ -94,13 +95,13 @@ class RemoveIntervals extends Pass { val ltOpt = clipLo <= inLow (gtOpt, ltOpt) match { // input range within output range -> no optimization - case (true, true) => a1 + case (true, true) => a1 case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1) case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1) - case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) + case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1)) } - case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => + case sqz @ DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) => // Using (conditional) reassign interval w/o adding mux val a1tpe = a1.tpe.asInstanceOf[IntervalType] val a2tpe = a2.tpe.asInstanceOf[IntervalType] @@ -117,54 +118,55 @@ class RemoveIntervals extends Pass { val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2))) DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2))) } - case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match { - // If a2 type is Interval wrap around range. If UInt, wrap around width - case t: IntervalType => - // Need to match binary points before getting *adjusted! - val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { - case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) - case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") - } - val (inLo, inHi) = a1.tpe match { - case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) - case _ => sys.error("Shouldn't be here") - } - // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) - val range = wrapHi - wrapLo - val ltOpt = Add(a1, (range + 1).S) - val gtOpt = Sub(a1, (range + 1).S) - // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. - // If x < wl - // output: wh - (wl - x) + 1 AKA x + r + 1 - // worst case: wh - (wl - xl) + 1 = wl - // -> xl + wr + 1 = wl - // If x > wh - // output: wl + (x - wh) - 1 AKA x - r - 1 - // worst case: wl + (xh - wh) - 1 = wh - // -> xh - wr - 1 = wh - val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) - (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { - case (true, true, _, _) => a1 - case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) - case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) - // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) - case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) - case _ => - errors.append(new WrapWithRemainder(info, mname, w)) - default - } - case _ => sys.error("Shouldn't be here") - } + case w @ DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => + a2.tpe match { + // If a2 type is Interval wrap around range. If UInt, wrap around width + case t: IntervalType => + // Need to match binary points before getting *adjusted! + val (wrapLo, wrapHi) = t.copy(point = tpe.point) match { + case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get) + case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType") + } + val (inLo, inHi) = a1.tpe match { + case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get) + case _ => sys.error("Shouldn't be here") + } + // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap) + val range = wrapHi - wrapLo + val ltOpt = Add(a1, (range + 1).S) + val gtOpt = Sub(a1, (range + 1).S) + // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it. + // If x < wl + // output: wh - (wl - x) + 1 AKA x + r + 1 + // worst case: wh - (wl - xl) + 1 = wl + // -> xl + wr + 1 = wl + // If x > wh + // output: wl + (x - wh) - 1 AKA x - r - 1 + // worst case: wl + (xh - wh) - 1 = wh + // -> xh - wr - 1 = wh + val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S) + (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match { + case (true, true, _, _) => a1 + case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1) + case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1) + // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work) + case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1)) + case _ => + errors.append(new WrapWithRemainder(info, mname, w)) + default + } + case _ => sys.error("Shouldn't be here") + } case other => other } } - private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval) private def replaceTypeInterval(t: Type): Type = t match { - case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) + case i @ IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width) case i: IntervalType => sys.error(s"Shouldn't be here: $i") - case v => v map replaceTypeInterval + case v => v.map(replaceTypeInterval) } /** Replace Interval Nodes with Interval Wires @@ -174,15 +176,16 @@ class RemoveIntervals extends Pass { * @param m module to replace nodes with wire + connection * @return */ - private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt + private def makeWireModule(m: DefModule): DefModule = m.map(makeWireStmt) private def makeWireStmt(s: Statement): Statement = s match { - case DefNode(info, name, value) => value.tpe match { - case IntervalType(l, u, p) => - val newType = IntervalType(l, u, p) - Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value))) - case other => s - } - case other => other map makeWireStmt + case DefNode(info, name, value) => + value.tpe match { + case IntervalType(l, u, p) => + val newType = IntervalType(l, u, p) + Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value))) + case other => s + } + case other => other.map(makeWireStmt) } } diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 895cb10f..7e82b37b 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -26,14 +26,13 @@ object RemoveValidIf extends Pass { case ClockType => ClockZero case _: FixedType => FixedZero case AsyncResetType => AsyncZero - case other => throwInternalError(s"Unexpected type $other") + case other => throwInternalError(s"Unexpected type $other") } override def prerequisites = firrtl.stage.Forms.LowForm override def optionalPrerequisiteOf = - Seq( Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { case Legalize | _: firrtl.transforms.ConstantPropagation => true @@ -42,24 +41,25 @@ object RemoveValidIf extends Pass { // Recursive. Removes ValidIfs private def onExp(e: Expression): Expression = { - e map onExp match { + e.map(onExp) match { case ValidIf(_, value, _) => value - case x => x + case x => x } } // Recursive. Replaces IsInvalid with connecting zero - private def onStmt(s: Statement): Statement = s map onStmt map onExp match { - case invalid @ IsInvalid(info, loc) => loc.tpe match { - case _: AnalogType => EmptyStmt - case tpe => Connect(info, loc, getGroundZero(tpe)) - } + private def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp) match { + case invalid @ IsInvalid(info, loc) => + loc.tpe match { + case _: AnalogType => EmptyStmt + case tpe => Connect(info, loc, getGroundZero(tpe)) + } case other => other } private def onModule(m: DefModule): DefModule = { m match { - case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body)) + case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body)) case m: ExtModule => m } } diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala index e31d9410..4a3cd697 100644 --- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala +++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala @@ -18,15 +18,16 @@ object ReplaceAccesses extends Pass { override def invalidates(a: Transform) = false def run(c: Circuit): Circuit = { - def onStmt(s: Statement): Statement = s map onStmt map onExp - def onExp(e: Expression): Expression = e match { - case WSubAccess(ex, UIntLiteral(value, _), t, g) => ex.tpe match { - case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g) - case _ => e map onExp - } - case _ => e map onExp + def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp) + def onExp(e: Expression): Expression = e match { + case WSubAccess(ex, UIntLiteral(value, _), t, g) => + ex.tpe match { + case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g) + case _ => e.map(onExp) + } + case _ => e.map(onExp) } - c copy (modules = c.modules map (_ map onStmt)) + c.copy(modules = c.modules.map(_.map(onStmt))) } } diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala index 85a0a26f..48b9479c 100644 --- a/src/main/scala/firrtl/passes/ResolveFlows.scala +++ b/src/main/scala/firrtl/passes/ResolveFlows.scala @@ -14,17 +14,22 @@ object ResolveFlows extends Pass { override def invalidates(a: Transform) = false def resolve_e(g: Flow)(e: Expression): Expression = e match { - case ex: WRef => ex copy (flow = g) - case WSubField(exp, name, tpe, _) => WSubField( - Utils.field_flip(exp.tpe, name) match { - case Default => resolve_e(g)(exp) - case Flip => resolve_e(Utils.swap(g))(exp) - }, name, tpe, g) + case ex: WRef => ex.copy(flow = g) + case WSubField(exp, name, tpe, _) => + WSubField( + Utils.field_flip(exp.tpe, name) match { + case Default => resolve_e(g)(exp) + case Flip => resolve_e(Utils.swap(g))(exp) + }, + name, + tpe, + g + ) case WSubIndex(exp, value, tpe, _) => WSubIndex(resolve_e(g)(exp), value, tpe, g) case WSubAccess(exp, index, tpe, _) => WSubAccess(resolve_e(g)(exp), resolve_e(SourceFlow)(index), tpe, g) - case _ => e map resolve_e(g) + case _ => e.map(resolve_e(g)) } def resolve_s(s: Statement): Statement = s match { @@ -35,11 +40,11 @@ object ResolveFlows extends Pass { Connect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr)) case PartialConnect(info, loc, expr) => PartialConnect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr)) - case sx => sx map resolve_e(SourceFlow) map resolve_s + case sx => sx.map(resolve_e(SourceFlow)).map(resolve_s) } - def resolve_flow(m: DefModule): DefModule = m map resolve_s + def resolve_flow(m: DefModule): DefModule = m.map(resolve_s) def run(c: Circuit): Circuit = - c copy (modules = c.modules map resolve_flow) + c.copy(modules = c.modules.map(resolve_flow)) } diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala index 67360b74..fcbac163 100644 --- a/src/main/scala/firrtl/passes/ResolveKinds.scala +++ b/src/main/scala/firrtl/passes/ResolveKinds.scala @@ -20,21 +20,21 @@ object ResolveKinds extends Pass { } def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match { - case ex: WRef => ex copy (kind = kinds(ex.name)) - case _ => e map resolve_expr(kinds) + case ex: WRef => ex.copy(kind = kinds(ex.name)) + case _ => e.map(resolve_expr(kinds)) } def resolve_stmt(kinds: KindMap)(s: Statement): Statement = { s match { - case sx: DefWire => kinds(sx.name) = WireKind - case sx: DefNode => kinds(sx.name) = NodeKind - case sx: DefRegister => kinds(sx.name) = RegKind + case sx: DefWire => kinds(sx.name) = WireKind + case sx: DefNode => kinds(sx.name) = NodeKind + case sx: DefRegister => kinds(sx.name) = RegKind case sx: WDefInstance => kinds(sx.name) = InstanceKind - case sx: DefMemory => kinds(sx.name) = MemKind + case sx: DefMemory => kinds(sx.name) = MemKind case _ => } s.map(resolve_stmt(kinds)) - .map(resolve_expr(kinds)) + .map(resolve_expr(kinds)) } def resolve_kinds(m: DefModule): DefModule = { @@ -44,5 +44,5 @@ object ResolveKinds extends Pass { } def run(c: Circuit): Circuit = - c copy (modules = c.modules map resolve_kinds) + c.copy(modules = c.modules.map(resolve_kinds)) } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index c536cd5d..a65f8921 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -7,7 +7,7 @@ import firrtl.{SystemVerilogEmitter, Transform, VerilogEmitter} import firrtl.ir._ import firrtl.options.Dependency import firrtl.Mappers._ -import firrtl.Utils.{kind, flow, get_info} +import firrtl.Utils.{flow, get_info, kind} // Datastructures import scala.collection.mutable @@ -17,65 +17,63 @@ import scala.collection.mutable object SplitExpressions extends Pass { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(firrtl.passes.RemoveValidIf), - Dependency(firrtl.passes.memlib.VerilogMemDelays) ) + Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays)) override def optionalPrerequisiteOf = - Seq( Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = a match { case ResolveKinds => true case _ => false } - private def onModule(m: Module): Module = { - val namespace = Namespace(m) - def onStmt(s: Statement): Statement = { - val v = mutable.ArrayBuffer[Statement]() - // Splits current expression if needed - // Adds named temporaries to v - def split(e: Expression): Expression = e match { - case e: DoPrim => - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), flow(e)) - case e: Mux => - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), flow(e)) - case e: ValidIf => - val name = namespace.newTemp - v += DefNode(get_info(s), name, e) - WRef(name, e.tpe, kind(e), flow(e)) - case _ => e - } - - // Recursive. Splits compound nodes - def onExp(e: Expression): Expression = - e map onExp match { - case ex: DoPrim => ex map split - case ex => ex - } + private def onModule(m: Module): Module = { + val namespace = Namespace(m) + def onStmt(s: Statement): Statement = { + val v = mutable.ArrayBuffer[Statement]() + // Splits current expression if needed + // Adds named temporaries to v + def split(e: Expression): Expression = e match { + case e: DoPrim => + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), flow(e)) + case e: Mux => + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), flow(e)) + case e: ValidIf => + val name = namespace.newTemp + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), flow(e)) + case _ => e + } - s map onExp match { - case x: Block => x map onStmt - case EmptyStmt => EmptyStmt - case x => - v += x - v.size match { - case 1 => v.head - case _ => Block(v.toSeq) - } + // Recursive. Splits compound nodes + def onExp(e: Expression): Expression = + e.map(onExp) match { + case ex: DoPrim => ex.map(split) + case ex => ex } + + s.map(onExp) match { + case x: Block => x.map(onStmt) + case EmptyStmt => EmptyStmt + case x => + v += x + v.size match { + case 1 => v.head + case _ => Block(v.toSeq) + } } - Module(m.info, m.name, m.ports, onStmt(m.body)) - } - def run(c: Circuit): Circuit = { - val modulesx = c.modules map { - case m: Module => onModule(m) - case m: ExtModule => m - } - Circuit(c.info, modulesx, c.main) - } + } + Module(m.info, m.name, m.ports, onStmt(m.body)) + } + def run(c: Circuit): Circuit = { + val modulesx = c.modules.map { + case m: Module => onModule(m) + case m: ExtModule => m + } + Circuit(c.info, modulesx, c.main) + } } diff --git a/src/main/scala/firrtl/passes/ToWorkingIR.scala b/src/main/scala/firrtl/passes/ToWorkingIR.scala index c271302a..03faaf3c 100644 --- a/src/main/scala/firrtl/passes/ToWorkingIR.scala +++ b/src/main/scala/firrtl/passes/ToWorkingIR.scala @@ -6,5 +6,5 @@ import firrtl.Transform object ToWorkingIR extends Pass { override def prerequisites = firrtl.stage.Forms.MinimalHighForm override def invalidates(a: Transform) = false - def run(c:Circuit): Circuit = c + def run(c: Circuit): Circuit = c } diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala index 822a8125..0a05bd4e 100644 --- a/src/main/scala/firrtl/passes/TrimIntervals.scala +++ b/src/main/scala/firrtl/passes/TrimIntervals.scala @@ -23,10 +23,7 @@ import firrtl.Transform class TrimIntervals extends Pass { override def prerequisites = - Seq( Dependency(ResolveKinds), - Dependency(InferTypes), - Dependency(ResolveFlows), - Dependency[InferBinaryPoints] ) + Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows), Dependency[InferBinaryPoints]) override def optionalPrerequisiteOf = Seq.empty @@ -34,48 +31,51 @@ class TrimIntervals extends Pass { def run(c: Circuit): Circuit = { // Open -> closed - val firstPass = InferTypes.run(c map replaceModuleInterval) + val firstPass = InferTypes.run(c.map(replaceModuleInterval)) // Align binary points and adjust range accordingly (loss of precision changes range) - firstPass map alignModuleBP + firstPass.map(alignModuleBP) } /* Replace interval types */ - private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval + private def replaceModuleInterval(m: DefModule): DefModule = m.map(replaceStmtInterval).map(replacePortInterval) - private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval + private def replaceStmtInterval(s: Statement): Statement = s.map(replaceTypeInterval).map(replaceStmtInterval) - private def replacePortInterval(p: Port): Port = p map replaceTypeInterval + private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval) private def replaceTypeInterval(t: Type): Type = t match { - case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) => + case i @ IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) => IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p)) case i: IntervalType => i - case v => v map replaceTypeInterval + case v => v.map(replaceTypeInterval) } /* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */ - private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP - - private def alignStmtBP(s: Statement): Statement = s map alignExpBP match { - case c@Connect(info, loc, expr) => loc.tpe match { - case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr)) - case _ => c - } - case c@PartialConnect(info, loc, expr) => loc.tpe match { - case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr)) - case _ => c - } - case other => other map alignStmtBP + private def alignModuleBP(m: DefModule): DefModule = m.map(alignStmtBP) + + private def alignStmtBP(s: Statement): Statement = s.map(alignExpBP) match { + case c @ Connect(info, loc, expr) => + loc.tpe match { + case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr)) + case _ => c + } + case c @ PartialConnect(info, loc, expr) => + loc.tpe match { + case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr)) + case _ => c + } + case other => other.map(alignStmtBP) } // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned // Note - Mul does not need its binary points aligned, because multiplication is cool like that - private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/) + private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq /*, Wrap, Clip, Squeeze*/ ) - private def alignExpBP(e: Expression): Expression = e map alignExpBP match { + private def alignExpBP(e: Expression): Expression = e.map(alignExpBP) match { case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg) - case DoPrim(o, args, consts, t) if opsToFix.contains(o) && - (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => + case DoPrim(o, args, consts, t) + if opsToFix.contains(o) && + (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size => val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _) DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t) case Mux(cond, tval, fval, t: IntervalType) => @@ -85,9 +85,9 @@ class TrimIntervals extends Pass { } private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match { case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e - case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current => + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current => DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired))) - case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current => + case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current => val shiftAmt = current - desired val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt) val shiftMul = Closed(BigDecimal(1) / shiftGain) diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index b9cd32fa..10198b33 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -2,7 +2,6 @@ package firrtl.passes - import scala.annotation.tailrec import firrtl._ import firrtl.ir._ @@ -35,12 +34,11 @@ import MemPortUtils.memType object Uniquify extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(ResolveKinds), - Dependency(InferTypes) ) ++ firrtl.stage.Forms.WorkingIR + Seq(Dependency(ResolveKinds), Dependency(InferTypes)) ++ firrtl.stage.Forms.WorkingIR override def invalidates(a: Transform): Boolean = a match { case ResolveKinds | InferTypes => true - case _ => false + case _ => false } private case class UniquifyException(msg: String) extends FirrtlInternalException(msg) @@ -55,12 +53,13 @@ object Uniquify extends Transform with DependencyAPIMigration { */ @tailrec def findValidPrefix( - prefix: String, - elts: Seq[String], - namespace: collection.mutable.HashSet[String]): String = { - elts find (elt => namespace.contains(prefix + elt)) match { + prefix: String, + elts: Seq[String], + namespace: collection.mutable.HashSet[String] + ): String = { + elts.find(elt => namespace.contains(prefix + elt)) match { case Some(_) => findValidPrefix(prefix + "_", elts, namespace) - case None => prefix + case None => prefix } } @@ -70,16 +69,16 @@ object Uniquify extends Transform with DependencyAPIMigration { * => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, foo bar 1 a, foo bar 1 b, foo c * }}} */ - private [firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match { + private[firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match { case t: BundleType => - t.fields flatMap { f => - (enumerateNames(f.tpe) map (f.name +: _)) ++ Seq(Seq(f.name)) + t.fields.flatMap { f => + (enumerateNames(f.tpe).map(f.name +: _)) ++ Seq(Seq(f.name)) } case t: VectorType => - ((0 until t.size) map (i => Seq(i.toString))) ++ - ((0 until t.size) flatMap { i => - enumerateNames(t.tpe) map (i.toString +: _) - }) + ((0 until t.size).map(i => Seq(i.toString))) ++ + ((0 until t.size).flatMap { i => + enumerateNames(t.tpe).map(i.toString +: _) + }) case _ => Seq() } @@ -87,27 +86,38 @@ object Uniquify extends Transform with DependencyAPIMigration { def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = { // Recursive helper def recStmtToType(s: Statement): Seq[Field] = s match { - case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) + case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe)) case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe)) case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe)) - case sx: DefMemory => sx.dataType match { - case (_: UIntType | _: SIntType | _: FixedType) => - Seq(Field(sx.name, Default, memType(sx))) - case tpe: BundleType => - val newFields = tpe.fields map ( f => - DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency, - sx.readLatency, sx.readers, sx.writers, sx.readwriters) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - case tpe: VectorType => - val newFields = (0 until tpe.size) map ( i => - sx.copy(name = i.toString, dataType = tpe.tpe) - ) flatMap recStmtToType - Seq(Field(sx.name, Default, BundleType(newFields))) - } - case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) + case sx: DefMemory => + sx.dataType match { + case (_: UIntType | _: SIntType | _: FixedType) => + Seq(Field(sx.name, Default, memType(sx))) + case tpe: BundleType => + val newFields = tpe.fields + .map(f => + DefMemory( + sx.info, + f.name, + f.tpe, + sx.depth, + sx.writeLatency, + sx.readLatency, + sx.readers, + sx.writers, + sx.readwriters + ) + ) + .flatMap(recStmtToType) + Seq(Field(sx.name, Default, BundleType(newFields))) + case tpe: VectorType => + val newFields = + (0 until tpe.size).map(i => sx.copy(name = i.toString, dataType = tpe.tpe)).flatMap(recStmtToType) + Seq(Field(sx.name, Default, BundleType(newFields))) + } + case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe)) case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt) - case sx: Block => (sx.stmts map recStmtToType).flatten + case sx: Block => (sx.stmts.map(recStmtToType)).flatten case sx => Seq() } BundleType(recStmtToType(s)) @@ -116,40 +126,44 @@ object Uniquify extends Transform with DependencyAPIMigration { // Accepts a Type and an initial namespace // Returns new Type with uniquified names private def uniquifyNames( - t: BundleType, - namespace: collection.mutable.HashSet[String]) - (implicit sinfo: Info, mname: String): BundleType = { + t: BundleType, + namespace: collection.mutable.HashSet[String] + )( + implicit sinfo: Info, + mname: String + ): BundleType = { def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): (Type, Seq[String]) = t match { case tx: BundleType => // First add everything - val newFieldsAndElts = tx.fields map { f => + val newFieldsAndElts = tx.fields.map { f => val newName = findValidPrefix(f.name, Seq(""), namespace) namespace += newName Field(newName, f.flip, f.tpe) - } map { f => f.tpe match { - case _: GroundType => (f, Seq[String](f.name)) - case _ => - val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet()) - // Need leading _ for findValidPrefix, it doesn't add _ for checks - val eltsNames: Seq[String] = eltsx map (e => "_" + e) - val prefix = findValidPrefix(f.name, eltsNames, namespace) - // We added f.name in previous map, delete if we change it - if (prefix != f.name) { - namespace -= f.name - namespace += prefix - } - val newElts: Seq[String] = eltsx map (e => LowerTypes.loweredName(prefix +: Seq(e))) - namespace ++= newElts - (Field(prefix, f.flip, tpe), prefix +: newElts) + }.map { f => + f.tpe match { + case _: GroundType => (f, Seq[String](f.name)) + case _ => + val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet()) + // Need leading _ for findValidPrefix, it doesn't add _ for checks + val eltsNames: Seq[String] = eltsx.map(e => "_" + e) + val prefix = findValidPrefix(f.name, eltsNames, namespace) + // We added f.name in previous map, delete if we change it + if (prefix != f.name) { + namespace -= f.name + namespace += prefix + } + val newElts: Seq[String] = eltsx.map(e => LowerTypes.loweredName(prefix +: Seq(e))) + namespace ++= newElts + (Field(prefix, f.flip, tpe), prefix +: newElts) } } val (newFields, elts) = newFieldsAndElts.unzip (BundleType(newFields), elts.flatten) case tx: VectorType => val (tpe, elts) = recUniquifyNames(tx.tpe, namespace) - val newElts = ((0 until tx.size) map (i => i.toString)) ++ - ((0 until tx.size) flatMap { i => - elts map (e => LowerTypes.loweredName(Seq(i.toString, e))) + val newElts = ((0 until tx.size).map(i => i.toString)) ++ + ((0 until tx.size).flatMap { i => + elts.map(e => LowerTypes.loweredName(Seq(i.toString, e))) }) (VectorType(tpe, tx.size), newElts) case tx => (tx, Nil) @@ -164,19 +178,26 @@ object Uniquify extends Transform with DependencyAPIMigration { // Creates a mapping from flattened references to members of $from -> // flattened references to members of $to private def createNameMapping( - from: Type, - to: Type) - (implicit sinfo: Info, mname: String): Map[String, NameMapNode] = { + from: Type, + to: Type + )( + implicit sinfo: Info, + mname: String + ): Map[String, NameMapNode] = { (from, to) match { case (fromx: BundleType, tox: BundleType) => - (fromx.fields zip tox.fields flatMap { case (f, t) => - val eltsMap = createNameMapping(f.tpe, t.tpe) - if ((f.name != t.name) || eltsMap.nonEmpty) { - Map(f.name -> NameMapNode(t.name, eltsMap)) - } else { - Map[String, NameMapNode]() - } - }).toMap + (fromx.fields + .zip(tox.fields) + .flatMap { + case (f, t) => + val eltsMap = createNameMapping(f.tpe, t.tpe) + if ((f.name != t.name) || eltsMap.nonEmpty) { + Map(f.name -> NameMapNode(t.name, eltsMap)) + } else { + Map[String, NameMapNode]() + } + }) + .toMap case (fromx: VectorType, tox: VectorType) => createNameMapping(fromx.tpe, tox.tpe) case (fromx, tox) => @@ -187,18 +208,19 @@ object Uniquify extends Transform with DependencyAPIMigration { // Maps names in expression to new uniquified names private def uniquifyNamesExp( - exp: Expression, - map: Map[String, NameMapNode]) - (implicit sinfo: Info, mname: String): Expression = { + exp: Expression, + map: Map[String, NameMapNode] + )( + implicit sinfo: Info, + mname: String + ): Expression = { // Recursive Helper - def rec(exp: Expression, m: Map[String, NameMapNode]): - (Expression, Map[String, NameMapNode]) = exp match { + def rec(exp: Expression, m: Map[String, NameMapNode]): (Expression, Map[String, NameMapNode]) = exp match { case e: WRef => if (m.contains(e.name)) { val node = m(e.name) (WRef(node.name, e.tpe, e.kind, e.flow), node.elts) - } - else (e, Map()) + } else (e, Map()) case e: WSubField => val (subExp, subMap) = rec(e.expr, m) val (retName, retMap) = @@ -218,18 +240,21 @@ object Uniquify extends Transform with DependencyAPIMigration { (WSubAccess(subExp, index, e.tpe, e.flow), subMap) case (_: UIntLiteral | _: SIntLiteral) => (exp, m) case (_: Mux | _: ValidIf | _: DoPrim) => - (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m) + (exp.map((e: Expression) => uniquifyNamesExp(e, map)), m) } rec(exp, map)._1 } // Uses map to recursively rename fields of tpe private def uniquifyNamesType( - tpe: Type, - map: Map[String, NameMapNode]) - (implicit sinfo: Info, mname: String): Type = tpe match { + tpe: Type, + map: Map[String, NameMapNode] + )( + implicit sinfo: Info, + mname: String + ): Type = tpe match { case t: BundleType => - val newFields = t.fields map { f => + val newFields = t.fields.map { f => if (map.contains(f.name)) { val node = map(f.name) Field(node.name, f.flip, uniquifyNamesType(f.tpe, node.elts)) @@ -244,8 +269,11 @@ object Uniquify extends Transform with DependencyAPIMigration { } // Everything wrapped in run so that it's thread safe - @deprecated("The functionality of Uniquify is now part of LowerTypes." + - "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.", "Firrtl 1.4") + @deprecated( + "The functionality of Uniquify is now part of LowerTypes." + + "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.", + "Firrtl 1.4" + ) def execute(state: CircuitState): CircuitState = { val c = state.circuit val renames = RenameMap() @@ -263,22 +291,22 @@ object Uniquify extends Transform with DependencyAPIMigration { val nameMap = collection.mutable.HashMap[String, NameMapNode]() def uniquifyExp(e: Expression): Expression = e match { - case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess ) => + case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess) => uniquifyNamesExp(e, nameMap.toMap) - case e: Mux => e map uniquifyExp - case e: ValidIf => e map uniquifyExp + case e: Mux => e.map(uniquifyExp) + case e: ValidIf => e.map(uniquifyExp) case (_: UIntLiteral | _: SIntLiteral) => e - case e: DoPrim => e map uniquifyExp + case e: DoPrim => e.map(uniquifyExp) } def uniquifyStmt(s: Statement): Statement = { - s map uniquifyStmt map uniquifyExp match { + s.map(uniquifyStmt).map(uniquifyExp) match { case sx: DefWire => sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) val newType = uniquifyNamesType(sx.tpe, node.elts) - (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } DefWire(sx.info, node.name, newType) @@ -290,7 +318,7 @@ object Uniquify extends Transform with DependencyAPIMigration { if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) val newType = uniquifyNamesType(sx.tpe, node.elts) - (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init) @@ -302,7 +330,7 @@ object Uniquify extends Transform with DependencyAPIMigration { if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) val newType = portTypeMap(m.name) - (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } WDefInstance(sx.info, node.name, sx.module, newType) @@ -317,7 +345,7 @@ object Uniquify extends Transform with DependencyAPIMigration { val mem = sx.copy(name = node.name, dataType = dataType) // Create new mapping to handle references to memory data fields val uniqueMemMap = createNameMapping(memType(sx), memType(mem)) - (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach { + (Utils.create_exps(sx.name, memType(sx)).zip(Utils.create_exps(node.name, memType(mem)))).foreach { case (from, to) => renames.rename(from.serialize, to.serialize) } nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap) @@ -329,9 +357,12 @@ object Uniquify extends Transform with DependencyAPIMigration { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach { - case (from, to) => renames.rename(from.serialize, to.serialize) - } + (Utils + .create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) + .zip(Utils.create_exps(node.name, sx.value.tpe))) + .foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } DefNode(sx.info, node.name, sx.value) } else { sx @@ -354,19 +385,18 @@ object Uniquify extends Transform with DependencyAPIMigration { mname = m.name m match { case m: ExtModule => m - case m: Module => + case m: Module => // Adds port names to namespace and namemap nameMap ++= portNameMap(m.name) - namespace ++= create_exps("", portTypeMap(m.name)) map - LowerTypes.loweredName map (_.tail) - m.copy(body = uniquifyBody(m.body) ) + namespace ++= create_exps("", portTypeMap(m.name)).map(LowerTypes.loweredName).map(_.tail) + m.copy(body = uniquifyBody(m.body)) } } def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = { renames.setModule(m.name) def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { - val portsType = BundleType(ports map { + val portsType = BundleType(ports.map { case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) }) val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet()) @@ -374,11 +404,12 @@ object Uniquify extends Transform with DependencyAPIMigration { portNameMap += (m.name -> localMap) portTypeMap += (m.name -> uniquePortsType) - ports zip uniquePortsType.fields map { case (p, f) => - (Utils.create_exps(p.name, p.tpe) zip Utils.create_exps(f.name, f.tpe)) foreach { - case (from, to) => renames.rename(from.serialize, to.serialize) - } - Port(p.info, f.name, p.direction, f.tpe) + ports.zip(uniquePortsType.fields).map { + case (p, f) => + (Utils.create_exps(p.name, p.tpe).zip(Utils.create_exps(f.name, f.tpe))).foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + Port(p.info, f.name, p.direction, f.tpe) } } @@ -386,12 +417,12 @@ object Uniquify extends Transform with DependencyAPIMigration { mname = m.name m match { case m: ExtModule => m.copy(ports = uniquifyPorts(m.ports)) - case m: Module => m.copy(ports = uniquifyPorts(m.ports)) + case m: Module => m.copy(ports = uniquifyPorts(m.ports)) } } sinfo = c.info - val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main) + val result = Circuit(c.info, c.modules.map(uniquifyPorts(renames)).map(uniquifyModule(renames)), c.main) state.copy(circuit = result, renames = Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala index 36eff379..0b046a5f 100644 --- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala +++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala @@ -12,28 +12,30 @@ import firrtl.options.Dependency import scala.collection.mutable /** - * Verilog has the width of (a % b) = Max(W(a), W(b)) - * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense, - * but nevertheless is a problem when emitting verilog - * - * This pass finds every instance of (a % b) and: - * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b)) - * 2) replaces the reference to (a % b) with a bitslice of the temporary node - * to get back down to width Min(W(a), W(b)) - * - * This is technically incorrect firrtl, but allows the verilog emitter - * to emit correct verilog without needing to add temporary nodes - */ + * Verilog has the width of (a % b) = Max(W(a), W(b)) + * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense, + * but nevertheless is a problem when emitting verilog + * + * This pass finds every instance of (a % b) and: + * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b)) + * 2) replaces the reference to (a % b) with a bitslice of the temporary node + * to get back down to width Min(W(a), W(b)) + * + * This is technically incorrect firrtl, but allows the verilog emitter + * to emit correct verilog without needing to add temporary nodes + */ object VerilogModulusCleanup extends Pass { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -51,32 +53,35 @@ object VerilogModulusCleanup extends Pass { case t => UnknownWidth } - def maxWidth(ws: Seq[Width]): Width = ws reduceLeft { (x,y) => (x,y) match { - case (IntWidth(x), IntWidth(y)) => IntWidth(x max y) - case (x, y) => UnknownWidth - }} + def maxWidth(ws: Seq[Width]): Width = ws.reduceLeft { (x, y) => + (x, y) match { + case (IntWidth(x), IntWidth(y)) => IntWidth(x.max(y)) + case (x, y) => UnknownWidth + } + } def verilogRemWidth(e: DoPrim)(tpe: Type): Type = { val newWidth = maxWidth(e.args.map(exp => getWidth(exp))) - tpe mapWidth (w => newWidth) + tpe.mapWidth(w => newWidth) } def removeRem(e: Expression): Expression = e match { - case e: DoPrim => e.op match { - case Rem => - val name = namespace.newTemp - val newType = e mapType verilogRemWidth(e) - v += DefNode(get_info(s), name, e mapType verilogRemWidth(e)) - val remRef = WRef(name, newType.tpe, kind(e), flow(e)) - val remWidth = bitWidth(e.tpe) - DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe) - case _ => e - } + case e: DoPrim => + e.op match { + case Rem => + val name = namespace.newTemp + val newType = e.mapType(verilogRemWidth(e)) + v += DefNode(get_info(s), name, e.mapType(verilogRemWidth(e))) + val remRef = WRef(name, newType.tpe, kind(e), flow(e)) + val remWidth = bitWidth(e.tpe) + DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe) + case _ => e + } case _ => e } - s map removeRem match { - case x: Block => x map onStmt + s.map(removeRem) match { + case x: Block => x.map(onStmt) case EmptyStmt => EmptyStmt case x => v += x @@ -90,8 +95,8 @@ object VerilogModulusCleanup extends Pass { } def run(c: Circuit): Circuit = { - val modules = c.modules map { - case m: Module => onModule(m) + val modules = c.modules.map { + case m: Module => onModule(m) case m: ExtModule => m } Circuit(c.info, modules, c.main) diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala index 03d47cfc..eeb34fa9 100644 --- a/src/main/scala/firrtl/passes/VerilogPrep.scala +++ b/src/main/scala/firrtl/passes/VerilogPrep.scala @@ -21,15 +21,17 @@ import scala.collection.mutable object VerilogPrep extends Pass { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -46,9 +48,9 @@ object VerilogPrep extends Pass { val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression] lazy val namespace = Namespace(m) - def onStmt(stmt: Statement): Statement = stmt map onStmt match { + def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { case attach: Attach => - val wires = attach.exprs groupBy kind + val wires = attach.exprs.groupBy(kind) val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty) val instPorts = wires.getOrElse(InstanceKind, Seq.empty) // Sanity check (Should be caught by CheckTypes) @@ -71,14 +73,14 @@ object VerilogPrep extends Pass { case s => s } - (m map onStmt, sourceMap.toMap) + (m.map(onStmt), sourceMap.toMap) } def run(c: Circuit): Circuit = { def lowerE(e: Expression): Expression = e match { case (_: WRef | _: WSubField) if kind(e) == InstanceKind => WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e)) - case _ => e map lowerE + case _ => e.map(lowerE) } def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match { @@ -96,12 +98,12 @@ object VerilogPrep extends Pass { }.unzip val newInst = WDefInstanceConnector(info, name, module, tpe, portCons) Block(wires.flatten :+ newInst) - case other => other map lowerS(attachMap) map lowerE + case other => other.map(lowerS(attachMap)).map(lowerE) } - val modulesx = c.modules map { mod => + val modulesx = c.modules.map { mod => val (modx, attachMap) = collectAndRemoveAttach(mod) - modx map lowerS(attachMap) + modx.map(lowerS(attachMap)) } c.copy(modules = modulesx) } diff --git a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala index 39c127de..e61780a4 100644 --- a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala +++ b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala @@ -17,10 +17,7 @@ import firrtl.options.Dependency */ object ZeroLengthVecs extends Pass { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ResolveKinds), - Dependency(InferTypes), - Dependency(ExpandConnects) ) + Seq(Dependency(PullMuxes), Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ExpandConnects)) override def invalidates(a: Transform) = false @@ -28,8 +25,8 @@ object ZeroLengthVecs extends Pass { // interval type with the type alone unless you declare a component private def replaceWithDontCare(toReplace: Expression): Expression = { val default = toReplace.tpe match { - case UIntType(w) => UIntLiteral(0, w) - case SIntType(w) => SIntLiteral(0, w) + case UIntType(w) => UIntLiteral(0, w) + case SIntType(w) => SIntLiteral(0, w) case FixedType(w, p) => FixedLiteral(0, w, p) case it: IntervalType => val zeroType = IntervalType(Closed(0), Closed(0), IntWidth(0)) @@ -40,11 +37,11 @@ object ZeroLengthVecs extends Pass { } private def zeroLenDerivedRefLike(expr: Expression): Boolean = (expr, expr.tpe) match { - case (_, VectorType(_, 0)) => true - case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e) + case (_, VectorType(_, 0)) => true + case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e) case (WSubAccess(e, _, _, _), _) => zeroLenDerivedRefLike(e) - case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e) - case _ => false + case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e) + case _ => false } // The connects have all been lowered, so all aggregate-typed expressions are "grounded" by WSubField/WSubAccess/WSubIndex @@ -52,13 +49,13 @@ object ZeroLengthVecs extends Pass { private def dropZeroLenSubAccesses(expr: Expression): Expression = expr match { case _: WSubIndex | _: WSubAccess | _: WSubField => if (zeroLenDerivedRefLike(expr)) replaceWithDontCare(expr) else expr - case e => e map dropZeroLenSubAccesses + case e => e.map(dropZeroLenSubAccesses) } // Attach semantics: drop all zero-length-derived members of attach group, drop stmt if trivial private def onStmt(stmt: Statement): Statement = stmt match { case Connect(_, sink, _) if zeroLenDerivedRefLike(sink) => EmptyStmt - case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt + case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt case Attach(info, sinks) => val filtered = Attach(info, sinks.filterNot(zeroLenDerivedRefLike)) if (filtered.exprs.length < 2) EmptyStmt else filtered diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index 56d66ef0..82321f95 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -11,12 +11,14 @@ import firrtl.options.Dependency object ZeroWidth extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses), - Dependency(ExpandConnects), - Dependency(RemoveAccesses), - Dependency[ExpandWhensAndCheck], - Dependency(ConvertFixedToSInt) ) ++ firrtl.stage.Forms.Deduped + Seq( + Dependency(PullMuxes), + Dependency(ReplaceAccesses), + Dependency(ExpandConnects), + Dependency(RemoveAccesses), + Dependency[ExpandWhensAndCheck], + Dependency(ConvertFixedToSInt) + ) ++ firrtl.stage.Forms.Deduped override def invalidates(a: Transform): Boolean = a match { case InferTypes => true @@ -24,30 +26,41 @@ object ZeroWidth extends Transform with DependencyAPIMigration { } private def makeEmptyMemBundle(name: String): Field = - Field(name, Flip, BundleType(Seq( - Field("addr", Default, UIntType(IntWidth(0))), - Field("en", Default, UIntType(IntWidth(0))), - Field("clk", Default, UIntType(IntWidth(0))), - Field("data", Flip, UIntType(IntWidth(0))) - ))) + Field( + name, + Flip, + BundleType( + Seq( + Field("addr", Default, UIntType(IntWidth(0))), + Field("en", Default, UIntType(IntWidth(0))), + Field("clk", Default, UIntType(IntWidth(0))), + Field("data", Flip, UIntType(IntWidth(0))) + ) + ) + ) private def onEmptyMemStmt(s: Statement): Statement = s match { - case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => removeZero(tpe) match { - case None => - DefWire(info, name, BundleType( - rs.map(r => makeEmptyMemBundle(r)) ++ - ws.map(w => makeEmptyMemBundle(w)) ++ - rws.map(rw => makeEmptyMemBundle(rw)) - )) - case Some(_) => d - } - case sx => sx map onEmptyMemStmt + case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => + removeZero(tpe) match { + case None => + DefWire( + info, + name, + BundleType( + rs.map(r => makeEmptyMemBundle(r)) ++ + ws.map(w => makeEmptyMemBundle(w)) ++ + rws.map(rw => makeEmptyMemBundle(rw)) + ) + ) + case Some(_) => d + } + case sx => sx.map(onEmptyMemStmt) } private def onModuleEmptyMemStmt(m: DefModule): DefModule = { m match { case ext: ExtModule => ext - case in: Module => in.copy(body = onEmptyMemStmt(in.body)) + case in: Module => in.copy(body = onEmptyMemStmt(in.body)) } } @@ -59,20 +72,20 @@ object ZeroWidth extends Transform with DependencyAPIMigration { * This replaces memories with a DefWire() bundle that contains the address, en, * clk, and data fields implemented as zero width wires. Running the rest of the ZeroWidth * transform will remove these dangling references properly. - * */ def executeEmptyMemStmt(state: CircuitState): CircuitState = { val c = state.circuit - val result = c.copy(modules = c.modules map onModuleEmptyMemStmt) + val result = c.copy(modules = c.modules.map(onModuleEmptyMemStmt)) state.copy(circuit = result) } // This is slightly different and specialized version of create_exps, TODO unify? private def findRemovable(expr: => Expression, tpe: Type): Seq[Expression] = tpe match { - case GroundType(width) => width match { - case IntWidth(ZERO) => List(expr) - case _ => List.empty - } + case GroundType(width) => + width match { + case IntWidth(ZERO) => List(expr) + case _ => List.empty + } case BundleType(fields) => if (fields.isEmpty) List(expr) else fields.flatMap(f => findRemovable(WSubField(expr, f.name, f.tpe, SourceFlow), f.tpe)) @@ -95,7 +108,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { t } x match { - case s: Statement => s map onType(s.name) + case s: Statement => s.map(onType(s.name)) case Port(_, name, _, t) => onType(name)(t) } removedNames @@ -103,14 +116,14 @@ object ZeroWidth extends Transform with DependencyAPIMigration { private[passes] def removeZero(t: Type): Option[Type] = t match { case GroundType(IntWidth(ZERO)) => None case BundleType(fields) => - fields map (f => (f, removeZero(f.tpe))) collect { + fields.map(f => (f, removeZero(f.tpe))).collect { case (Field(name, flip, _), Some(t)) => Field(name, flip, t) } match { case Nil => None case seq => Some(BundleType(seq)) } - case VectorType(t, size) => removeZero(t) map (VectorType(_, size)) - case x => Some(x) + case VectorType(t, size) => removeZero(t).map(VectorType(_, size)) + case x => Some(x) } private def onExp(e: Expression): Expression = e match { case DoPrim(Cat, args, consts, tpe) => @@ -118,26 +131,27 @@ object ZeroWidth extends Transform with DependencyAPIMigration { x.tpe match { case UIntType(IntWidth(ZERO)) => Seq.empty[Expression] case SIntType(IntWidth(ZERO)) => Seq.empty[Expression] - case other => Seq(x) + case other => Seq(x) } } nonZeros match { - case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1))) + case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1))) case Seq(x) => x - case seq => DoPrim(Cat, seq, consts, tpe) map onExp + case seq => DoPrim(Cat, seq, consts, tpe).map(onExp) } case DoPrim(Andr, Seq(x), _, _) if (bitWidth(x.tpe) == 0) => UIntLiteral(1) // nothing false - case other => other.tpe match { - case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1))) - case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1))) - case _ => e map onExp - } + case other => + other.tpe match { + case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1))) + case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1))) + case _ => e.map(onExp) + } } private def onStmt(renames: RenameMap)(s: Statement): Statement = s match { case d @ DefWire(info, name, tpe) => renames.delete(getRemoved(d)) removeZero(tpe) match { - case None => EmptyStmt + case None => EmptyStmt case Some(t) => DefWire(info, name, t) } case d @ DefRegister(info, name, tpe, clock, reset, init) => @@ -145,7 +159,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { removeZero(tpe) match { case None => EmptyStmt case Some(t) => - DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init)) + DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init)) } case d: DefMemory => renames.delete(getRemoved(d)) @@ -154,25 +168,28 @@ object ZeroWidth extends Transform with DependencyAPIMigration { Utils.throwInternalError(s"private pass ZeroWidthMemRemove should have removed this memory: $d") case Some(t) => d.copy(dataType = t) } - case Connect(info, loc, exp) => removeZero(loc.tpe) match { - case None => EmptyStmt - case Some(t) => Connect(info, loc, onExp(exp)) - } - case IsInvalid(info, exp) => removeZero(exp.tpe) match { - case None => EmptyStmt - case Some(t) => IsInvalid(info, onExp(exp)) - } - case DefNode(info, name, value) => removeZero(value.tpe) match { - case None => EmptyStmt - case Some(t) => DefNode(info, name, onExp(value)) - } - case sx => sx map onStmt(renames) map onExp + case Connect(info, loc, exp) => + removeZero(loc.tpe) match { + case None => EmptyStmt + case Some(t) => Connect(info, loc, onExp(exp)) + } + case IsInvalid(info, exp) => + removeZero(exp.tpe) match { + case None => EmptyStmt + case Some(t) => IsInvalid(info, onExp(exp)) + } + case DefNode(info, name, value) => + removeZero(value.tpe) match { + case None => EmptyStmt + case Some(t) => DefNode(info, name, onExp(value)) + } + case sx => sx.map(onStmt(renames)).map(onExp) } private def onModule(renames: RenameMap)(m: DefModule): DefModule = { renames.setModule(m.name) // For each port, record deleted subcomponents - m.ports.foreach{p => renames.delete(getRemoved(p))} - val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap { + m.ports.foreach { p => renames.delete(getRemoved(p)) } + val ports = m.ports.map(p => (p, removeZero(p.tpe))).flatMap { case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t)) case (Port(_, name, _, _), None) => renames.delete(name) @@ -180,7 +197,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { } m match { case ext: ExtModule => ext.copy(ports = ports) - case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body)) + case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body)) } } def execute(state: CircuitState): CircuitState = { @@ -189,7 +206,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration { val c = InferTypes.run(executeEmptyMemStmt(state).circuit) val renames = RenameMap() renames.setCircuit(c.main) - val result = c.copy(modules = c.modules map onModule(renames)) + val result = c.copy(modules = c.modules.map(onModule(renames))) CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala index c2323d4c..bfc03b51 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala @@ -13,8 +13,8 @@ import Utils._ import memlib.AnalysisUtils._ /** Starting with a top module, determine the clock origins of each child instance. - * Write the result to writer. - */ + * Write the result to writer. + */ class ClockList(top: String, writer: Writer) extends Pass { def run(c: Circuit): Circuit = { // Build useful datastructures @@ -29,7 +29,7 @@ class ClockList(top: String, writer: Writer) extends Pass { // Clock sources must be blackbox outputs and top's clock val partialSourceList = getSourceList(moduleMap)(lineages) - val sourceList = partialSourceList ++ moduleMap(top).ports.collect{ case Port(i, n, Input, ClockType) => n } + val sourceList = partialSourceList ++ moduleMap(top).ports.collect { case Port(i, n, Input, ClockType) => n } writer.append(s"Sourcelist: $sourceList \n") // Remove everything from the circuit, unless it has a clock type @@ -37,8 +37,9 @@ class ClockList(top: String, writer: Writer) extends Pass { val onlyClockCircuit = RemoveAllButClocks.run(c) // Inline the clock-only circuit up to the specified top module - val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet - val inlineTransform = new InlineInstances{ override val inlineDelim = "$" } + val modulesToInline = + (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet + val inlineTransform = new InlineInstances { override val inlineDelim = "$" } val inlinedCircuit = inlineTransform.run(onlyClockCircuit, modulesToInline, Set(), Seq()).circuit val topModule = inlinedCircuit.modules.find(_.name == top).getOrElse(throwInternalError("no top module")) @@ -49,13 +50,14 @@ class ClockList(top: String, writer: Writer) extends Pass { val origins = getOrigins(connects, "", moduleMap)(lineages) // If the clock origin is contained in the source list, label good (otherwise bad) - origins.foreach { case (instance, origin) => - val sep = if(instance == "") "" else "." - if(!sourceList.contains(origin.replace('.','$'))){ - outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n") - } else { - outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n") - } + origins.foreach { + case (instance, origin) => + val sep = if (instance == "") "" else "." + if (!sourceList.contains(origin.replace('.', '$'))) { + outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n") + } else { + outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n") + } } // Write to output file diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala index e6617857..468ba905 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala @@ -12,8 +12,7 @@ import memlib._ import firrtl.options.{RegisteredTransform, ShellOption} import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} -case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends - SingleTargetAnnotation[ModuleName] { +case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = ClockListAnnotation(n, outputConfig) } @@ -44,7 +43,7 @@ Usage: ) passOptions.get(InputConfigFileName) match { case Some(x) => error("Unneeded input config file name!" + usage) - case None => + case None => } val target = ModuleName(passModule, CircuitName(passCircuit)) ClockListAnnotation(target, outputConfig) @@ -53,18 +52,20 @@ Usage: class ClockListTransform extends Transform with DependencyAPIMigration with RegisteredTransform { - override def prerequisites = Forms.LowForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Forms.LowEmitters + override def prerequisites = Forms.LowForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.LowEmitters val options = Seq( new ShellOption[String]( longOption = "list-clocks", - toAnnotationSeq = (a: String) => Seq( passes.clocklist.ClockListAnnotation.parse(a), - RunFirrtlTransformAnnotation(new ClockListTransform) ), + toAnnotationSeq = (a: String) => + Seq(passes.clocklist.ClockListAnnotation.parse(a), RunFirrtlTransformAnnotation(new ClockListTransform)), helpText = "List which signal drives each clock of every descendent of specified modules", shortOption = Some("clks"), - helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>") ) ) + helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>") + ) + ) def passSeq(top: String, writer: Writer): Seq[Pass] = Seq(new ClockList(top, writer)) diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala index b77629fc..00e07588 100644 --- a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala +++ b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala @@ -10,45 +10,56 @@ import Utils._ import memlib.AnalysisUtils._ object ClockListUtils { + /** Returns a list of clock outputs from instances of external modules - */ + */ def getSourceList(moduleMap: Map[String, DefModule])(lin: Lineage): Seq[String] = { - val s = lin.foldLeft(Seq[String]()){case (sL, (i, l)) => - val sLx = getSourceList(moduleMap)(l) - val sLxx = sLx map (i + "$" + _) - sL ++ sLxx + val s = lin.foldLeft(Seq[String]()) { + case (sL, (i, l)) => + val sLx = getSourceList(moduleMap)(l) + val sLxx = sLx.map(i + "$" + _) + sL ++ sLxx } val sourceList = moduleMap(lin.name) match { case ExtModule(i, n, ports, dn, p) => - val portExps = ports.flatMap{p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction)))} + val portExps = ports.flatMap { p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction))) } portExps.filter(e => (e.tpe == ClockType) && (flow(e) == SinkFlow)).map(_.serialize) case _ => Nil } val sx = sourceList ++ s sx } + /** Returns a map from instance name to its clock origin. - * Child instances are not included if they share the same clock as their parent - */ - def getOrigins(connects: Connects, me: String, moduleMap: Map[String, DefModule])(lin: Lineage): Map[String, String] = { - val sep = if(me == "") "" else "$" + * Child instances are not included if they share the same clock as their parent + */ + def getOrigins( + connects: Connects, + me: String, + moduleMap: Map[String, DefModule] + )(lin: Lineage + ): Map[String, String] = { + val sep = if (me == "") "" else "$" // Get origins from all children - val childrenOrigins = lin.foldLeft(Map[String, String]()){case (o, (i, l)) => - o ++ getOrigins(connects, me + sep + i, moduleMap)(l) + val childrenOrigins = lin.foldLeft(Map[String, String]()) { + case (o, (i, l)) => + o ++ getOrigins(connects, me + sep + i, moduleMap)(l) } // If I have a clock, get it val clockOpt = moduleMap(lin.name) match { - case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" } + case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" } case ExtModule(i, n, ports, dn, p) => None } // Return new origins with direct children removed, if they match my clock clockOpt match { case Some(clock) => val myOrigin = getOrigin(connects, clock).serialize - childrenOrigins.foldLeft(Map(me -> myOrigin)) { case (o, (childInstance, childOrigin)) => - val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance } - // If direct child shares my origin, omit it - if(childOrigin == myOrigin && childrenInstances.contains(childInstance)) o else o + (childInstance -> childOrigin) + childrenOrigins.foldLeft(Map(me -> myOrigin)) { + case (o, (childInstance, childOrigin)) => + val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance } + // If direct child shares my origin, omit it + if (childOrigin == myOrigin && childrenInstances.contains(childInstance)) o + else o + (childInstance -> childOrigin) } case None => childrenOrigins } diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala index 6eb8c138..d72bc293 100644 --- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala +++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala @@ -9,22 +9,22 @@ import Utils._ import Mappers._ /** Remove all statements and ports (except instances/whens/blocks) whose - * expressions do not relate to ground types. - */ + * expressions do not relate to ground types. + */ object RemoveAllButClocks extends Pass { - def onStmt(s: Statement): Statement = (s map onStmt) match { - case DefWire(i, n, ClockType) => s + def onStmt(s: Statement): Statement = (s.map(onStmt)) match { + case DefWire(i, n, ClockType) => s case DefNode(i, n, value) if value.tpe == ClockType => s - case Connect(i, l, r) if l.tpe == ClockType => s - case sx: WDefInstance => sx - case sx: DefInstance => sx - case sx: Block => sx + case Connect(i, l, r) if l.tpe == ClockType => s + case sx: WDefInstance => sx + case sx: DefInstance => sx + case sx: Block => sx case sx: Conditionally => sx case _ => EmptyStmt } def onModule(m: DefModule): DefModule = m match { - case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b))) + case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b))) case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps.filter(_.tpe == ClockType), dn, p) } - def run(c: Circuit): Circuit = c.copy(modules = c.modules map onModule) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(onModule)) } diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 14bd9e44..d237c36a 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -19,8 +19,9 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform import CustomYAMLProtocol._ val configs = r.parse[Config] val oldAnnos = state.annotations - val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) => - (annos, pins :+ config.pin.name) + val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { + case ((annos, pins), config) => + (annos, pins :+ config.pin.name) } state.copy(annotations = PinAnnotation(pins.toSeq) +: as) } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 4847a698..e290633e 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -10,12 +10,11 @@ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} import firrtl.options.{HasShellOptions, ShellOption} import MemPortUtils.memPortField -import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} +import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects} import WrappedExpression.weq import annotations._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} - case object InferReadWriteAnnotation extends NoTargetAnnotation // This pass examine the enable signals of the read & write ports of memories @@ -40,12 +39,13 @@ object InferReadWritePass extends Pass { getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval) // Visit each term of AND operation case DoPrim(op, args, consts, tpe) if op == And => - e +: (args flatMap getProductTerms(connects)) + e +: (args.flatMap(getProductTerms(connects))) // Visit connected nodes to references - case _: WRef | _: WSubField | _: WSubIndex => connects get e match { - case None => Seq(e) - case Some(ex) => e +: getProductTerms(connects)(ex) - } + case _: WRef | _: WSubField | _: WSubIndex => + connects.get(e) match { + case None => Seq(e) + case Some(ex) => e +: getProductTerms(connects)(ex) + } // Otherwise just return itself case _ => Seq(e) } @@ -58,96 +58,103 @@ object InferReadWritePass extends Pass { // b ?= Eq(a, 0) or b ?= Eq(0, a) case (_, DoPrim(Eq, args, _, _)) => weq(args.head, a) && weq(args(1), zero) || - weq(args(1), a) && weq(args.head, zero) + weq(args(1), a) && weq(args.head, zero) // a ?= Eq(b, 0) or b ?= Eq(0, a) case (DoPrim(Eq, args, _, _), _) => weq(args.head, b) && weq(args(1), zero) || - weq(args(1), b) && weq(args.head, zero) + weq(args(1), b) && weq(args.head, zero) case _ => false } - def replaceExp(repl: Netlist)(e: Expression): Expression = - e map replaceExp(repl) match { - case ex: WSubField => repl getOrElse (ex.serialize, ex) + e.map(replaceExp(repl)) match { + case ex: WSubField => repl.getOrElse(ex.serialize, ex) case ex => ex } def replaceStmt(repl: Netlist)(s: Statement): Statement = - s map replaceStmt(repl) map replaceExp(repl) match { + s.map(replaceStmt(repl)).map(replaceExp(repl)) match { case Connect(_, EmptyExpression, _) => EmptyStmt - case sx => sx + case sx => sx } - def inferReadWriteStmt(connects: Connects, - repl: Netlist, - stmts: Statements) - (s: Statement): Statement = s match { + def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match { // infer readwrite ports only for non combinational memories case mem: DefMemory if mem.readLatency > 0 => val readers = new PortSet val writers = new PortSet val readwriters = collection.mutable.ArrayBuffer[String]() val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) - for (w <- mem.writers ; r <- mem.readers) { + for { + w <- mem.writers + r <- mem.readers + } { val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en")) val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en")) - val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b))) + val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b))) val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) { - val rw = namespace newName "rw" + val rw = namespace.newName("rw") val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw readers += r writers += w - repl(memPortField(mem, r, "clk")) = EmptyExpression - repl(memPortField(mem, r, "en")) = EmptyExpression + repl(memPortField(mem, r, "clk")) = EmptyExpression + repl(memPortField(mem, r, "en")) = EmptyExpression repl(memPortField(mem, r, "addr")) = EmptyExpression repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") - repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = EmptyExpression + repl(memPortField(mem, w, "clk")) = EmptyExpression + repl(memPortField(mem, w, "en")) = EmptyExpression repl(memPortField(mem, w, "addr")) = EmptyExpression repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk) - stmts += Connect(NoInfo, WSubField(rwExp, "en"), - DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), - connects(memPortField(mem, w, "en"))), Nil, BoolType)) - stmts += Connect(NoInfo, WSubField(rwExp, "addr"), - Mux(connects(memPortField(mem, w, "en")), - connects(memPortField(mem, w, "addr")), - connects(memPortField(mem, r, "addr")), UnknownType)) + stmts += Connect( + NoInfo, + WSubField(rwExp, "en"), + DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType) + ) + stmts += Connect( + NoInfo, + WSubField(rwExp, "addr"), + Mux( + connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), + UnknownType + ) + ) } } - if (readwriters.isEmpty) mem else mem copy ( - readers = mem.readers filterNot readers, - writers = mem.writers filterNot writers, - readwriters = mem.readwriters ++ readwriters) - case sx => sx map inferReadWriteStmt(connects, repl, stmts) + if (readwriters.isEmpty) mem + else + mem.copy( + readers = mem.readers.filterNot(readers), + writers = mem.writers.filterNot(writers), + readwriters = mem.readwriters ++ readwriters + ) + case sx => sx.map(inferReadWriteStmt(connects, repl, stmts)) } def inferReadWrite(m: DefModule) = { val connects = getConnects(m) val repl = new Netlist val stmts = new Statements - (m map inferReadWriteStmt(connects, repl, stmts) - map replaceStmt(repl)) match { + (m.map(inferReadWriteStmt(connects, repl, stmts)) + .map(replaceStmt(repl))) match { case m: ExtModule => m - case m: Module => m copy (body = Block(m.body +: stmts.toSeq)) + case m: Module => m.copy(body = Block(m.body +: stmts.toSeq)) } } - def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) + def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite)) } // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform - with DependencyAPIMigration - with SeqTransformBased - with HasShellOptions { +class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions { override def prerequisites = Forms.MidForm override def optionalPrerequisites = Seq.empty @@ -159,7 +166,9 @@ class InferReadWrite extends Transform longOption = "infer-rw", toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)), helpText = "Enable read/write port inference for memories", - shortOption = Some("firw") ) ) + shortOption = Some("firw") + ) + ) def transforms = Seq( InferReadWritePass, diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala index 3809c47c..871a1093 100644 --- a/src/main/scala/firrtl/passes/memlib/MemConf.scala +++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala @@ -3,7 +3,6 @@ package firrtl.passes package memlib - sealed abstract class MemPort(val name: String) { override def toString = name } case object ReadPort extends MemPort("read") @@ -19,22 +18,27 @@ object MemPort { def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s) def fromString(s: String): Map[MemPort, Int] = { - s.split(",").toSeq.map(MemPort.apply).map(_ match { - case Some(x) => x - case _ => throw new Exception(s"Error parsing MemPort string : ${s}") - }).groupBy(identity).mapValues(_.size).toMap + s.split(",") + .toSeq + .map(MemPort.apply) + .map(_ match { + case Some(x) => x + case _ => throw new Exception(s"Error parsing MemPort string : ${s}") + }) + .groupBy(identity) + .mapValues(_.size) + .toMap } } case class MemConf( - name: String, - depth: BigInt, - width: Int, - ports: Map[MemPort, Int], - maskGranularity: Option[Int] -) { + name: String, + depth: BigInt, + width: Int, + ports: Map[MemPort, Int], + maskGranularity: Option[Int]) { - private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",") + private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",") private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("") // Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes @@ -49,21 +53,34 @@ object MemConf { val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r def fromString(s: String): Seq[MemConf] = { - s.split("\n").toSeq.map(_ match { - case MemConf.regex(name, depth, width, ports, maskGran) => Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))) - case "" => None - case _ => throw new Exception(s"Error parsing MemConf string : ${s}") - }).flatten + s.split("\n") + .toSeq + .map(_ match { + case MemConf.regex(name, depth, width, ports, maskGran) => + Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))) + case "" => None + case _ => throw new Exception(s"Error parsing MemConf string : ${s}") + }) + .flatten } - def apply(name: String, depth: BigInt, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = { + def apply( + name: String, + depth: BigInt, + width: Int, + readPorts: Int, + writePorts: Int, + readWritePorts: Int, + maskGranularity: Option[Int] + ): MemConf = { val ports: Seq[(MemPort, Int)] = (if (maskGranularity.isEmpty) { - (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++ - (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts)) - } else { - (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++ - (if (readWritePorts == 0) Seq() else Seq(MaskedReadWritePort -> readWritePorts)) - }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts)) + (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++ + (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts)) + } else { + (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++ + (if (readWritePorts == 0) Seq() + else Seq(MaskedReadWritePort -> readWritePorts)) + }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts)) new MemConf(name, depth, width, ports.toMap, maskGranularity) } } diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala index 3731ea86..c8cd3e8d 100644 --- a/src/main/scala/firrtl/passes/memlib/MemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -19,38 +19,38 @@ object DefAnnotatedMemory { m.readwriters, m.readUnderWrite, None, // mask granularity annotation - None // No reference yet to another memory + None // No reference yet to another memory ) } } case class DefAnnotatedMemory( - info: Info, - name: String, - dataType: Type, - depth: BigInt, - writeLatency: Int, - readLatency: Int, - readers: Seq[String], - writers: Seq[String], - readwriters: Seq[String], - readUnderWrite: ReadUnderWrite.Value, - maskGran: Option[BigInt], - memRef: Option[(String, String)] /* (Module, Mem) */ - //pins: Seq[Pin], - ) extends Statement with IsDeclaration { + info: Info, + name: String, + dataType: Type, + depth: BigInt, + writeLatency: Int, + readLatency: Int, + readers: Seq[String], + writers: Seq[String], + readwriters: Seq[String], + readUnderWrite: ReadUnderWrite.Value, + maskGran: Option[BigInt], + memRef: Option[(String, String)] /* (Module, Mem) */ + //pins: Seq[Pin], +) extends Statement + with IsDeclaration { override def serialize: String = this.toMem.serialize - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = this - def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def toMem = DefMemory(info, name, dataType, depth, - writeLatency, readLatency, readers, writers, - readwriters, readUnderWrite) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(dataType) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def toMem = + DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, readwriters, readUnderWrite) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala index f0c9ebf4..1db132f7 100644 --- a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala +++ b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala @@ -7,8 +7,7 @@ import firrtl.options.{RegisteredLibrary, ShellOption} class MemLibOptions extends RegisteredLibrary { val name: String = "MemLib Options" - val options: Seq[ShellOption[_]] = Seq( new InferReadWrite, - new ReplSeqMem ) + val options: Seq[ShellOption[_]] = Seq(new InferReadWrite, new ReplSeqMem) .flatMap(_.options) } diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala index b6a9a23d..f153fa2b 100644 --- a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala @@ -11,12 +11,12 @@ import MemPortUtils.{MemPortMap} object MemTransformUtils { /** Replaces references to old memory port names with new memory port names - */ + */ def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { //TODO(izraelevitz): check speed def updateRef(e: Expression): Expression = { - val ex = e map updateRef - repl getOrElse (ex.serialize, ex) + val ex = e.map(updateRef) + repl.getOrElse(ex.serialize, ex) } def hasEmptyExpr(stmt: Statement): Boolean = { @@ -24,16 +24,16 @@ object MemTransformUtils { def testEmptyExpr(e: Expression): Expression = { e match { case EmptyExpression => foundEmpty = true - case _ => + case _ => } - e map testEmptyExpr // map must return; no foreach + e.map(testEmptyExpr) // map must return; no foreach } - stmt map testEmptyExpr + stmt.map(testEmptyExpr) foundEmpty } def updateStmtRefs(s: Statement): Statement = - s map updateStmtRefs map updateRef match { + s.map(updateStmtRefs).map(updateRef) match { case c: Connect if hasEmptyExpr(c) => EmptyStmt case s => s } @@ -42,6 +42,6 @@ object MemTransformUtils { } def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem) - def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField = + def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField = MemPortUtils.memPortField(s.toMem, p, f) } diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala index 69c6b284..f325c0ba 100644 --- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala @@ -7,19 +7,19 @@ import firrtl.ir._ import firrtl.Utils._ /** Given a mask, return a bitmask corresponding to the desired datatype. - * Requirements: - * - The mask type and datatype must be equivalent, except any ground type in - * datatype must be matched by a 1-bit wide UIntType. - * - The mask must be a reference, subfield, or subindex - * The bitmask is a series of concatenations of the single mask bit over the - * length of the corresponding ground type, e.g.: - *{{{ - * wire mask: {x: UInt<1>, y: UInt<1>} - * wire data: {x: UInt<2>, y: SInt<2>} - * // this would return: - * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) - * }}} - */ + * Requirements: + * - The mask type and datatype must be equivalent, except any ground type in + * datatype must be matched by a 1-bit wide UIntType. + * - The mask must be a reference, subfield, or subindex + * The bitmask is a series of concatenations of the single mask bit over the + * length of the corresponding ground type, e.g.: + * {{{ + * wire mask: {x: UInt<1>, y: UInt<1>} + * wire data: {x: UInt<2>, y: SInt<2>} + * // this would return: + * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) + * }}} + */ object toBitMask { def apply(mask: Expression, dataType: Type): Expression = mask match { case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType) @@ -28,12 +28,13 @@ object toBitMask { private def hiermask(mask: Expression, dataType: Type): Expression = (mask.tpe, dataType) match { case (mt: VectorType, dt: VectorType) => - seqCat((0 until mt.size).reverse map { i => + seqCat((0 until mt.size).reverse.map { i => hiermask(WSubIndex(mask, i, mt.tpe, UnknownFlow), dt.tpe) }) case (mt: BundleType, dt: BundleType) => - seqCat((mt.fields zip dt.fields) map { case (mf, df) => - hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe) + seqCat((mt.fields.zip(dt.fields)).map { + case (mf, df) => + hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe) }) case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) => seqCat(List.fill(bitWidth(dt).intValue)(mask)) @@ -44,7 +45,7 @@ object toBitMask { object createMask { def apply(dt: Type): Type = dt match { case t: VectorType => VectorType(apply(t.tpe), t.size) - case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe)))) + case t: BundleType => BundleType(t.fields.map(f => f.copy(tpe = apply(f.tpe)))) case GroundType(w) if w == IntWidth(0) => UIntType(IntWidth(0)) case t: GroundType => BoolType } @@ -56,27 +57,33 @@ object MemPortUtils { type Modules = collection.mutable.ArrayBuffer[DefModule] def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( - Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))), + Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1).max(1)))), Field("en", Default, BoolType), Field("clk", Default, ClockType) ) // Todo: merge it with memToBundle def memType(mem: DefMemory): BundleType = { - val rType = BundleType(defaultPortSeq(mem) :+ - Field("data", Flip, mem.dataType)) - val wType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("data", Default, mem.dataType), - Field("mask", Default, createMask(mem.dataType)))) - val rwType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("rdata", Flip, mem.dataType), - Field("wmode", Default, BoolType), - Field("wdata", Default, mem.dataType), - Field("wmask", Default, createMask(mem.dataType)))) + val rType = BundleType( + defaultPortSeq(mem) :+ + Field("data", Flip, mem.dataType) + ) + val wType = BundleType( + defaultPortSeq(mem) ++ Seq(Field("data", Default, mem.dataType), Field("mask", Default, createMask(mem.dataType))) + ) + val rwType = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("rdata", Flip, mem.dataType), + Field("wmode", Default, BoolType), + Field("wdata", Default, mem.dataType), + Field("wmask", Default, createMask(mem.dataType)) + ) + ) BundleType( - (mem.readers map (Field(_, Flip, rType))) ++ - (mem.writers map (Field(_, Flip, wType))) ++ - (mem.readwriters map (Field(_, Flip, rwType)))) + (mem.readers.map(Field(_, Flip, rType))) ++ + (mem.writers.map(Field(_, Flip, wType))) ++ + (mem.readwriters.map(Field(_, Flip, rwType))) + ) } def memPortField(s: DefMemory, p: String, f: String): WSubField = { diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala index c51a0adc..30529119 100644 --- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -9,27 +9,27 @@ import firrtl.Mappers._ import MemPortUtils._ import MemTransformUtils._ - /** Changes memory port names to standard port names (i.e. RW0 instead T_408) - */ + */ object RenameAnnotatedMemoryPorts extends Pass { + /** Renames memory ports to a standard naming scheme: - * - R0, R1, ... for each read port - * - W0, W1, ... for each write port - * - RW0, RW1, ... for each readwrite port - */ + * - R0, R1, ... for each read port + * - W0, W1, ... for each write port + * - RW0, RW1, ... for each readwrite port + */ def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = { - val rports = m.readers.indices map (i => s"R$i") - val wports = m.writers.indices map (i => s"W$i") - val rwports = m.readwriters.indices map (i => s"RW$i") - m copy (readers = rports, writers = wports, readwriters = rwports) + val rports = m.readers.indices.map(i => s"R$i") + val wports = m.writers.indices.map(i => s"W$i") + val rwports = m.readwriters.indices.map(i => s"RW$i") + m.copy(readers = rports, writers = wports, readwriters = rwports) } /** Maps the serialized form of all memory port field names to the - * corresponding new memory port field Expression. - * E.g.: - * - ("m.read.addr") becomes (m.R0.addr) - */ + * corresponding new memory port field Expression. + * E.g.: + * - ("m.read.addr") becomes (m.R0.addr) + */ def getMemPortMap(m: DefAnnotatedMemory, memPortMap: MemPortMap): Unit = { val defaultFields = Seq("addr", "en", "clk") val rFields = defaultFields :+ "data" @@ -37,7 +37,10 @@ object RenameAnnotatedMemoryPorts extends Pass { val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask") def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit = - for ((p, i) <- ports.zipWithIndex; f <- fields) { + for { + (p, i) <- ports.zipWithIndex + f <- fields + } { val newPort = WSubField(WRef(m.name), newPortKind + i) val field = WSubField(newPort, f) memPortMap(s"${m.name}.$p.$f") = field @@ -55,16 +58,16 @@ object RenameAnnotatedMemoryPorts extends Pass { val updatedMem = createMemProto(m) getMemPortMap(m, memPortMap) updatedMem - case s => s map updateMemStmts(memPortMap) + case s => s.map(updateMemStmts(memPortMap)) } /** Replaces candidate memories and their references with standard port names - */ + */ def updateMemMods(m: DefModule) = { val memPortMap = new MemPortMap - (m map updateMemStmts(memPortMap) - map updateStmtRefs(memPortMap)) + (m.map(updateMemStmts(memPortMap)) + .map(updateStmtRefs(memPortMap))) } - def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) + def run(c: Circuit) = c.copy(modules = c.modules.map(updateMemMods)) } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index bfbc163a..fc381e88 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -13,7 +13,6 @@ import firrtl.annotations._ import firrtl.stage.Forms import wiring._ - /** Annotates the name of the pins to add for WiringTransform */ case class PinAnnotation(pins: Seq[String]) extends NoTargetAnnotation @@ -35,14 +34,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM /** Return true if mask granularity is per bit, false if per byte or unspecified */ private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match { - case None => false + case None => false case Some(v) => v == 1 } private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType) + ) private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)) + ) /** Catch incorrect memory instantiations when there are masked memories with unsupported aggregate types. * @@ -82,7 +83,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM ) private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match { - case None => Nil + case None => Nil case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType))) case Some(_) => { checkMaskDatatype(mem) @@ -111,7 +112,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM Field("wdata", Default, flattenType(mem.dataType)), Field("rdata", Flip, flattenType(mem.dataType)) ) ++ (mem.maskGran match { - case None => Nil + case None => Nil case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType))) case Some(_) => { checkMaskDatatype(mem) @@ -122,32 +123,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM def memToBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) + s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToBundle(s))) + ) def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) + s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))) + ) /** Creates a wrapper module and external module to replace a candidate memory - * The wrapper module has the same type as the memory it replaces - * The external module - */ + * The wrapper module has the same type as the memory it replaces + * The external module + */ def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = { assert(m.dataType != UnknownType) val wrapperIoType = memToBundle(m) - val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val wrapperIoPorts = wrapperIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) // Creates a type with the write/readwrite masks omitted if necessary val bbIoType = memToFlattenBundle(m) - val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbIoPorts = bbIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) val bbRef = WRef(m.name, bbIoType) val hasMask = m.maskGran.isDefined val fillMask = getFillWMask(m) def portRef(p: String) = WRef(p, field_type(wrapperIoType, p)) val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ - (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++ - (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++ - (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask))) + (m.readers.flatMap(r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++ + (m.writers.flatMap(w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++ + (m.readwriters.flatMap(rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask))) val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts)) val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty) // TODO: Annotate? -- use actual annotation map @@ -160,16 +163,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM // TODO(shunshou): get rid of copy pasta // Connects the clk, en, and addr fields from the wrapperPort to the bbPort def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = - Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) + Seq("clk", "en", "addr").map(f => connectFields(bbPort, f, wrapperPort, f)) // Generates mask bits (concatenates an aggregate to ground type) // depending on mask granularity (# bits = data width / mask granularity) def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = if (fillMask) toBitMask(mask, dataType) else toBits(mask) - def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = + def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = defaultConnects(wrapperPort, bbPort) :+ - fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) + fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { val wrapperData = WSubField(wrapperPort, "data") @@ -177,11 +180,12 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData)) hasMask match { case false => defaultSeq - case true => defaultSeq :+ Connect( - NoInfo, - WSubField(bbPort, "mask"), - maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) - ) + case true => + defaultSeq :+ Connect( + NoInfo, + WSubField(bbPort, "mask"), + maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) + ) } } @@ -190,61 +194,67 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")), connectFields(bbPort, "wmode", wrapperPort, "wmode"), - Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData))) + Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData)) + ) hasMask match { case false => defaultSeq - case true => defaultSeq :+ Connect( - NoInfo, - WSubField(bbPort, "wmask"), - maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) - ) + case true => + defaultSeq :+ Connect( + NoInfo, + WSubField(bbPort, "wmask"), + maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) + ) } } /** Mapping from (module, memory name) pairs to blackbox names */ private type NameMap = collection.mutable.HashMap[(String, String), String] + /** Construct NameMap by assigning unique names for each memory blackbox */ def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = { s match { - case m: DefAnnotatedMemory => m.memRef match { - case None => nameMap(mname -> m.name) = namespace newName m.name - case Some(_) => - } + case m: DefAnnotatedMemory => + m.memRef match { + case None => nameMap(mname -> m.name) = namespace.newName(m.name) + case Some(_) => + } case _ => } - s map constructNameMap(namespace, nameMap, mname) + s.map(constructNameMap(namespace, nameMap, mname)) } - def updateMemStmts(namespace: Namespace, - nameMap: NameMap, - mname: String, - memPortMap: MemPortMap, - memMods: Modules) - (s: Statement): Statement = s match { + def updateMemStmts( + namespace: Namespace, + nameMap: NameMap, + mname: String, + memPortMap: MemPortMap, + memMods: Modules + )(s: Statement + ): Statement = s match { case m: DefAnnotatedMemory => if (m.maskGran.isEmpty) { - m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } - m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } + m.writers.foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } + m.readwriters.foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } } m.memRef match { case None => // prototype mem val newWrapperName = nameMap(mname -> m.name) - val newMemBBName = namespace newName s"${newWrapperName}_ext" - val newMem = m copy (name = newMemBBName) + val newMemBBName = namespace.newName(s"${newWrapperName}_ext") + val newMem = m.copy(name = newMemBBName) memMods ++= createMemModule(newMem, newWrapperName) WDefInstance(m.info, m.name, newWrapperName, UnknownType) case Some((module, mem)) => WDefInstance(m.info, m.name, nameMap(module -> mem), UnknownType) } - case sx => sx map updateMemStmts(namespace, nameMap, mname, memPortMap, memMods) + case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods)) } def updateMemMods(namespace: Namespace, nameMap: NameMap, memMods: Modules)(m: DefModule) = { val memPortMap = new MemPortMap - (m map updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods) - map updateStmtRefs(memPortMap)) + (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods)) + .map(updateStmtRefs(memPortMap))) } def execute(state: CircuitState): CircuitState = { @@ -252,15 +262,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap - c.modules map (m => m map constructNameMap(namespace, nameMap, m.name)) - val modules = c.modules map updateMemMods(namespace, nameMap, memMods) + c.modules.map(m => m.map(constructNameMap(namespace, nameMap, m.name))) + val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods)) // print conf writer.serialize() val pannos = state.annotations.collect { case a: PinAnnotation => a } val pins = pannos match { - case Seq() => Nil + case Seq() => Nil case Seq(PinAnnotation(pins)) => pins - case _ => throwInternalError("Something went wrong") + case _ => throwInternalError("Something went wrong") } val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) => seq ++ memMods.collect { diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 87321ea0..79e07640 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -7,7 +7,7 @@ import firrtl._ import firrtl.annotations._ import firrtl.options.{HasShellOptions, ShellOption} import Utils.error -import java.io.{File, CharArrayWriter, PrintWriter} +import java.io.{CharArrayWriter, File, PrintWriter} import wiring._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} @@ -50,7 +50,15 @@ class ConfWriter(filename: String) { // assert that we don't overflow going from BigInt to Int conversion require(bitWidth(m.dataType) <= Int.MaxValue) m.maskGran.foreach { case x => require(x <= Int.MaxValue) } - val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt)) + val conf = MemConf( + m.name, + m.depth, + bitWidth(m.dataType).toInt, + m.readers.length, + m.writers.length, + m.readwriters.length, + m.maskGran.map(_.toInt) + ) outputBuffer.append(conf.toString) } def serialize() = { @@ -113,27 +121,31 @@ class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigrat val options = Seq( new ShellOption[String]( longOption = "repl-seq-mem", - toAnnotationSeq = (a: String) => Seq( passes.memlib.ReplSeqMemAnnotation.parse(a), - RunFirrtlTransformAnnotation(new ReplSeqMem) ), + toAnnotationSeq = + (a: String) => Seq(passes.memlib.ReplSeqMemAnnotation.parse(a), RunFirrtlTransformAnnotation(new ReplSeqMem)), helpText = "Blackbox and emit a configuration file for each sequential memory", shortOption = Some("frsq"), - helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") ) ) + helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") + ) + ) def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = - Seq(new SimpleMidTransform(Legalize), - new SimpleMidTransform(ToMemIR), - new SimpleMidTransform(ResolveMaskGranularity), - new SimpleMidTransform(RenameAnnotatedMemoryPorts), - new ResolveMemoryReference, - new CreateMemoryAnnotations(inConfigFile), - new ReplaceMemMacros(outConfigFile), - new WiringTransform, - new SimpleMidTransform(RemoveEmpty), - new SimpleMidTransform(CheckInitialization), - new SimpleMidTransform(InferTypes), - Uniquify, - new SimpleMidTransform(ResolveKinds), - new SimpleMidTransform(ResolveFlows)) + Seq( + new SimpleMidTransform(Legalize), + new SimpleMidTransform(ToMemIR), + new SimpleMidTransform(ResolveMaskGranularity), + new SimpleMidTransform(RenameAnnotatedMemoryPorts), + new ResolveMemoryReference, + new CreateMemoryAnnotations(inConfigFile), + new ReplaceMemMacros(outConfigFile), + new WiringTransform, + new SimpleMidTransform(RemoveEmpty), + new SimpleMidTransform(CheckInitialization), + new SimpleMidTransform(InferTypes), + Uniquify, + new SimpleMidTransform(ResolveKinds), + new SimpleMidTransform(ResolveFlows) + ) def execute(state: CircuitState): CircuitState = { val annos = state.annotations.collect { case a: ReplSeqMemAnnotation => a } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 41c47dce..434c7602 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -28,10 +28,10 @@ object AnalysisUtils { connects(value.serialize) = WInvalid case _ => // do nothing } - s map getConnects(connects) + s.map(getConnects(connects)) } val connects = new Connects - m map getConnects(connects) + m.map(getConnects(connects)) connects } @@ -56,8 +56,8 @@ object AnalysisUtils { else if (weq(tvOrigin, fvOrigin)) tvOrigin else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin else e - case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one - case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero + case DoPrim(PrimOps.Or, args, consts, tpe) if args.exists(weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args.exists(weq(_, zero)) => zero case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => val extractionWidth = (msb - lsb) + 1 val nodeWidth = bitWidth(args.head.tpe) @@ -69,10 +69,10 @@ object AnalysisUtils { case ValidIf(cond, value, _) => getOrigin(connects)(value) // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => - connects get e.serialize match { - case Some(ex) => getOrigin(connects)(ex) - case None => e - } + connects.get(e.serialize) match { + case Some(ex) => getOrigin(connects)(ex) + case None => e + } case _ => e } } @@ -90,10 +90,9 @@ object ResolveMaskGranularity extends Pass { */ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { val wenOrigin = getOrigin(connects)(wen) - val wmaskOrigin = connects.keys filter - (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)} + val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) } // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) - val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one)) if (redundantMask) None else Some(wmaskOrigin.size) } @@ -103,18 +102,17 @@ object ResolveMaskGranularity extends Pass { def updateStmts(connects: Connects)(s: Statement): Statement = s match { case m: DefAnnotatedMemory => val dataBits = bitWidth(m.dataType) - val rwMasks = m.readwriters map (rw => - getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) - val wMasks = m.writers map (w => - getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val rwMasks = + m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) val maskGran = (rwMasks ++ wMasks).head match { - case None => None + case None => None case Some(maskBits) => Some(dataBits / maskBits) } m.copy(maskGran = maskGran) - case sx => sx map updateStmts(connects) + case sx => sx.map(updateStmts(connects)) } - def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m)) - def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems) + def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m))) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems)) } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index b5ff10c6..e80e0c4a 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -14,7 +14,7 @@ case class NoDedupMemAnnotation(target: ComponentName) extends SingleTargetAnnot } /** Resolves annotation ref to memories that exactly match (except name) another memory - */ + */ class ResolveMemoryReference extends Transform with DependencyAPIMigration { override def prerequisites = Forms.MidForm @@ -45,10 +45,12 @@ class ResolveMemoryReference extends Transform with DependencyAPIMigration { /** If a candidate memory is identical except for name to another, add an * annotation that references the name of the other memory. */ - def updateMemStmts(mname: String, - existingMems: AnnotatedMemories, - noDedupMap: Map[String, Set[String]]) - (s: Statement): Statement = s match { + def updateMemStmts( + mname: String, + existingMems: AnnotatedMemories, + noDedupMap: Map[String, Set[String]] + )(s: Statement + ): Statement = s match { // If not dedupable, no need to add to existing (since nothing can dedup with it) // We just return the DefAnnotatedMemory as is in the default case below case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) => diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala index 554a3572..9fe7f852 100644 --- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -14,16 +14,17 @@ import firrtl.ir._ * - undefined read-under-write behavior */ object ToMemIR extends Pass { + /** Only annotate memories that are candidates for memory macro replacements * i.e. rw, w + r (read, write 1 cycle delay) and read-under-write "undefined." */ import ReadUnderWrite._ def updateStmts(s: Statement): Statement = s match { - case m @ DefMemory(_,_,_,_,1,1,r,w,rw,Undefined) if (w.length + rw.length) == 1 && r.length <= 1 => + case m @ DefMemory(_, _, _, _, 1, 1, r, w, rw, Undefined) if (w.length + rw.length) == 1 && r.length <= 1 => DefAnnotatedMemory(m) - case sx => sx map updateStmts + case sx => sx.map(updateStmts) } - def annotateModMems(m: DefModule) = m map updateStmts - def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) + def annotateModMems(m: DefModule) = m.map(updateStmts) + def run(c: Circuit) = c.copy(modules = c.modules.map(annotateModMems)) } diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index dd644323..a2b14343 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -24,19 +24,19 @@ object MemDelayAndReadwriteTransformer { case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect]) // Utilities for generating hardware - def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) - def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) - def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) - def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) + def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) + def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) + def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) + def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) // Utilities for working with WithValid groups def connect(l: WithValid, r: WithValid): Seq[Connect] = { - val paired = (l.valid +: l.payload) zip (r.valid +: r.payload) + val paired = (l.valid +: l.payload).zip(r.valid +: r.payload) paired.map { case (le, re) => connect(le, re) } } def condConnect(l: WithValid, r: WithValid): Seq[Connect] = { - connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) } + connect(l.valid, r.valid) +: (l.payload.zip(r.payload)).map { case (le, re) => condConnect(r.valid)(le, re) } } // Internal representation of a pipeline stage with an associated valid signal @@ -47,20 +47,23 @@ object MemDelayAndReadwriteTransformer { private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_") // Pipeline a group of signals with an associated valid signal. Gate registers when possible. - def pipelineWithValid(ns: Namespace)( - clock: Expression, - depth: Int, - src: WithValid, - nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = { + def pipelineWithValid( + ns: Namespace + )(clock: Expression, + depth: Int, + src: WithValid, + nameTemplate: Option[WithValid] = None + ): (WithValid, Seq[Statement], Seq[Connect]) = { def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e) val template = nameTemplate.getOrElse(src) - val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev => - def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) - val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) - val regs = (ref.valid +: ref.payload).map(asReg) - PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) + val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { + case prev => + def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) + val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) + val regs = (ref.valid +: ref.payload).map(asReg) + PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) } (stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns)) } @@ -84,10 +87,10 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { private def findMemConns(s: Statement): Unit = s match { case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr - case _ => s.foreach(findMemConns) + case _ => s.foreach(findMemConns) } - private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match { + private def swapMemRefs(e: Expression): Expression = e.map(swapMemRefs) match { case sf: WSubField => exprReplacements.getOrElse(we(sf), sf) case ex => ex } @@ -105,51 +108,57 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { val rRespDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) mem.readLatency else 0 val wCmdDelay = mem.writeLatency - 1 - val readStmts = (mem.readers ++ mem.readwriters).map { case r => - def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) - def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f) - val clk = oldDriver("clk") - - // Pack sources of read command inputs into WithValid object -> different for readwriter - val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") - val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) - val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) - val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) - val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) - - // Pipeline read response using *last* command pipe stage enable as the valid signal - val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) - val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names - val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) - - // Make sure references to the read data get appropriately substituted - val oldRDataName = if (rMap.contains(r)) "rdata" else "data" - exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head - - // Return all statements; they're separated so connects can go after all declarations - SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) + val readStmts = (mem.readers ++ mem.readwriters).map { + case r => + def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) + def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f) + val clk = oldDriver("clk") + + // Pack sources of read command inputs into WithValid object -> different for readwriter + val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") + val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) + val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) + val (cmdPiped, cmdDecls, cmdConns) = + pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Pipeline read response using *last* command pipe stage enable as the valid signal + val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) + val respPipeNameTemplate = + Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names + val (respPiped, respDecls, respConns) = + pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) + + // Make sure references to the read data get appropriately substituted + val oldRDataName = if (rMap.contains(r)) "rdata" else "data" + exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) } - val writeStmts = (mem.writers ++ mem.readwriters).map { case w => - def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) - def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f) - val clk = oldDriver("clk") - - // Pack sources of write command inputs into WithValid object -> different for readwriter - val cmdSrc = if (wMap.contains(w)) { - val en = AND(oldDriver("en"), oldDriver("wmode")) - WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) - } else { - WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) - } - - // Pipeline write command, connect to memory - val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) - val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) - val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) - - // Return all statements; they're separated so connects can go after all declarations - SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) + val writeStmts = (mem.writers ++ mem.readwriters).map { + case w => + def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) + def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f) + val clk = oldDriver("clk") + + // Pack sources of write command inputs into WithValid object -> different for readwriter + val cmdSrc = if (wMap.contains(w)) { + val en = AND(oldDriver("en"), oldDriver("wmode")) + WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) + } else { + WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) + } + + // Pipeline write command, connect to memory + val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) + val (cmdPiped, cmdDecls, cmdConns) = + pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) } newConns ++= (readStmts ++ writeStmts).flatMap(_.conns) @@ -171,8 +180,7 @@ object VerilogMemDelays extends Pass { override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf) override val optionalPrerequisiteOf = - Seq( Dependency[VerilogEmitter], - Dependency[SystemVerilogEmitter] ) + Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { case _: transforms.ConstantPropagation | ResolveFlows => true @@ -180,5 +188,5 @@ object VerilogMemDelays extends Pass { } def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed - def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) } diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala index a43adfe2..b5f91e7b 100644 --- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala @@ -6,7 +6,6 @@ import net.jcazevedo.moultingyaml._ import java.io.{CharArrayWriter, File, PrintWriter} import firrtl.FileUtils - object CustomYAMLProtocol extends DefaultYamlProtocol { // bottom depends on top implicit val _pin = yamlFormat1(Pin) @@ -20,17 +19,15 @@ case class Source(name: String, module: String) case class Top(name: String) case class Config(pin: Pin, source: Source, top: Top) - class YamlFileReader(file: String) { - def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { + def parse[A](implicit reader: YamlReader[A]): Seq[A] = { if (new File(file).exists) { val yamlString = FileUtils.getText(file) - yamlString.parseYamls flatMap (x => - try Some(reader read x) + yamlString.parseYamls.flatMap(x => + try Some(reader.read(x)) catch { case e: Exception => None } ) - } - else sys.error("Yaml file doesn't exist!") + } else sys.error("Yaml file doesn't exist!") } } @@ -38,11 +35,11 @@ class YamlFileWriter(file: String) { val outputBuffer = new CharArrayWriter val separator = "--- \n" def append(in: YamlValue): Unit = { - outputBuffer append s"$separator${in.prettyPrint}" + outputBuffer.append(s"$separator${in.prettyPrint}") } def dump(): Unit = { val outputFile = new PrintWriter(file) - outputFile write outputBuffer.toString + outputFile.write(outputBuffer.toString) outputFile.close() } } diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala index 3f74e5d2..a69b7797 100644 --- a/src/main/scala/firrtl/passes/wiring/Wiring.scala +++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala @@ -18,8 +18,7 @@ import firrtl.graph.EulerTour case class WiringInfo(source: ComponentName, sinks: Seq[Named], pin: String) /** A data store of wiring names */ -case class WiringNames(compName: String, source: String, sinks: Seq[Named], - pin: String) +case class WiringNames(compName: String, source: String, sinks: Seq[Named], pin: String) /** Pass that computes and applies a sequence of wiring modifications * @@ -28,31 +27,39 @@ case class WiringNames(compName: String, source: String, sinks: Seq[Named], */ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { def run(c: Circuit): Circuit = analyze(c) - .foldLeft(c){ - case (cx, (tpe, modsMap)) => cx.copy( - modules = cx.modules map onModule(tpe, modsMap)) } + .foldLeft(c) { + case (cx, (tpe, modsMap)) => cx.copy(modules = cx.modules.map(onModule(tpe, modsMap))) + } /** Converts multiple units of wiring information to module modifications */ private def analyze(c: Circuit): Seq[(Type, Map[String, Modifications])] = { val names = wiSeq - .map ( wi => (wi.source, wi.sinks, wi.pin) match { - case (ComponentName(comp, ModuleName(source,_)), sinks, pin) => - WiringNames(comp, source, sinks, pin) }) + .map(wi => + (wi.source, wi.sinks, wi.pin) match { + case (ComponentName(comp, ModuleName(source, _)), sinks, pin) => + WiringNames(comp, source, sinks, pin) + } + ) val portNames = mutable.Seq.fill(names.size)(Map[String, String]()) - c.modules.foreach{ m => + c.modules.foreach { m => val ns = Namespace(m) - names.zipWithIndex.foreach{ case (WiringNames(c, so, si, p), i) => - portNames(i) = portNames(i) + - ( m.name -> { - if (si.exists(getModuleName(_) == m.name)) ns.newName(p) - else ns.newName(tokenize(c) filterNot ("[]." contains _) mkString "_") - })}} + names.zipWithIndex.foreach { + case (WiringNames(c, so, si, p), i) => + portNames(i) = portNames(i) + + (m.name -> { + if (si.exists(getModuleName(_) == m.name)) ns.newName(p) + else ns.newName(tokenize(c).filterNot("[]." contains _).mkString("_")) + }) + } + } val iGraph = InstanceKeyGraph(c) - names.zip(portNames).map{ case(WiringNames(comp, so, si, _), pn) => - computeModifications(c, iGraph, comp, so, si, pn) } + names.zip(portNames).map { + case (WiringNames(comp, so, si, _), pn) => + computeModifications(c, iGraph, comp, so, si, pn) + } } /** Converts a single unit of wiring information to module modifications @@ -69,19 +76,20 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { * @return a tuple of the component type and a map of module names * to pending modifications */ - private def computeModifications(c: Circuit, - iGraph: InstanceKeyGraph, - compName: String, - source: String, - sinks: Seq[Named], - portNames: Map[String, String]): - (Type, Map[String, Modifications]) = { + private def computeModifications( + c: Circuit, + iGraph: InstanceKeyGraph, + compName: String, + source: String, + sinks: Seq[Named], + portNames: Map[String, String] + ): (Type, Map[String, Modifications]) = { val sourceComponentType = getType(c, source, compName) - val sinkComponents: Map[String, Seq[String]] = sinks - .collect{ case ComponentName(c, ModuleName(m, _)) => (c, m) } - .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]){ - case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) } + val sinkComponents: Map[String, Seq[String]] = sinks.collect { case ComponentName(c, ModuleName(m, _)) => (c, m) } + .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]) { + case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil))) + } // Determine "ownership" of sources to sinks via minimum distance val owners = sinksToSourcesSeq(sinks, source, iGraph) @@ -95,86 +103,88 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { def makeWire(m: Modifications, portName: String): Modifications = m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire)))) def makeWireC(m: Modifications, portName: String, c: (String, String)): Modifications = - m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct ) + m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct) val tour = EulerTour(iGraph.graph, iGraph.top) // Finds the lowest common ancestor instances for two module names in a design def lowestCommonAncestor(moduleA: Seq[InstanceKey], moduleB: Seq[InstanceKey]): Seq[InstanceKey] = tour.rmq(moduleA, moduleB) - owners.foreach { case (sink, source) => - val lca = lowestCommonAncestor(sink, source) - - // Compute metadata along Sink to LCA paths. - sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach { - case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) => - val to = s"$ci.${portNames(cm)}" - val from = s"${portNames(pm)}" - meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) - meta(cm) = meta(cm).copy( - addPortOrWire = Some((portNames(cm), DecInput)) - ) - // Case where the sink is the LCA - case Seq(InstanceKey(_,pm)) => - // Case where the source is also the LCA - if (source.drop(lca.size).isEmpty) { - meta(pm) = makeWire(meta(pm), portNames(pm)) - } else { - val InstanceKey(ci,cm) = source.drop(lca.size).head - val to = s"${portNames(pm)}" - val from = s"$ci.${portNames(cm)}" - meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) - } - } + owners.foreach { + case (sink, source) => + val lca = lowestCommonAncestor(sink, source) - // Compute metadata for the Sink - sink.last match { case InstanceKey( _, m) => - if (sinkComponents.contains(m)) { - val from = s"${portNames(m)}" - sinkComponents(m).foreach( to => - meta(m) = meta(m).copy( - cons = (meta(m).cons :+( (to, from) )).distinct + // Compute metadata along Sink to LCA paths. + sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach { + case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) => + val to = s"$ci.${portNames(cm)}" + val from = s"${portNames(pm)}" + meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) + meta(cm) = meta(cm).copy( + addPortOrWire = Some((portNames(cm), DecInput)) ) - ) + // Case where the sink is the LCA + case Seq(InstanceKey(_, pm)) => + // Case where the source is also the LCA + if (source.drop(lca.size).isEmpty) { + meta(pm) = makeWire(meta(pm), portNames(pm)) + } else { + val InstanceKey(ci, cm) = source.drop(lca.size).head + val to = s"${portNames(pm)}" + val from = s"$ci.${portNames(cm)}" + meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from)) + } } - } - // Compute metadata for the Source - source.last match { case InstanceKey( _, m) => - val to = s"${portNames(m)}" - val from = compName - meta(m) = meta(m).copy( - cons = (meta(m).cons :+( (to, from) )).distinct - ) - } + // Compute metadata for the Sink + sink.last match { + case InstanceKey(_, m) => + if (sinkComponents.contains(m)) { + val from = s"${portNames(m)}" + sinkComponents(m).foreach(to => + meta(m) = meta(m).copy( + cons = (meta(m).cons :+ ((to, from))).distinct + ) + ) + } + } - // Compute metadata along Source to LCA path - source.drop(lca.size - 1).sliding(2).toList.reverse.map { - case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) => { - val to = s"${portNames(pm)}" - val from = s"$ci.${portNames(cm)}" - meta(pm) = meta(pm).copy( - cons = (meta(pm).cons :+( (to, from) )).distinct - ) - meta(cm) = meta(cm).copy( - addPortOrWire = Some((portNames(cm), DecOutput)) - ) + // Compute metadata for the Source + source.last match { + case InstanceKey(_, m) => + val to = s"${portNames(m)}" + val from = compName + meta(m) = meta(m).copy( + cons = (meta(m).cons :+ ((to, from))).distinct + ) } - // Case where the source is the LCA - case Seq(InstanceKey(_,pm)) => { - // Case where the sink is also the LCA. We do nothing here, - // as we've created the connecting wire above - if (sink.drop(lca.size).isEmpty) { - } else { - val InstanceKey(ci,cm) = sink.drop(lca.size).head - val to = s"$ci.${portNames(cm)}" - val from = s"${portNames(pm)}" + + // Compute metadata along Source to LCA path + source.drop(lca.size - 1).sliding(2).toList.reverse.map { + case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) => { + val to = s"${portNames(pm)}" + val from = s"$ci.${portNames(cm)}" meta(pm) = meta(pm).copy( - cons = (meta(pm).cons :+( (to, from) )).distinct + cons = (meta(pm).cons :+ ((to, from))).distinct ) + meta(cm) = meta(cm).copy( + addPortOrWire = Some((portNames(cm), DecOutput)) + ) + } + // Case where the source is the LCA + case Seq(InstanceKey(_, pm)) => { + // Case where the sink is also the LCA. We do nothing here, + // as we've created the connecting wire above + if (sink.drop(lca.size).isEmpty) {} else { + val InstanceKey(ci, cm) = sink.drop(lca.size).head + val to = s"$ci.${portNames(cm)}" + val from = s"${portNames(pm)}" + meta(pm) = meta(pm).copy( + cons = (meta(pm).cons :+ ((to, from))).distinct + ) + } } } - } } (sourceComponentType, meta.toMap) } @@ -189,20 +199,22 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass { val ports = mutable.ArrayBuffer[Port]() l.addPortOrWire match { case None => - case Some((s, dt)) => dt match { - case DecInput => ports += Port(NoInfo, s, Input, t) - case DecOutput => ports += Port(NoInfo, s, Output, t) - case DecWire => defines += DefWire(NoInfo, s, t) - } + case Some((s, dt)) => + dt match { + case DecInput => ports += Port(NoInfo, s, Input, t) + case DecOutput => ports += Port(NoInfo, s, Output, t) + case DecWire => defines += DefWire(NoInfo, s, t) + } } - connects ++= (l.cons map { case ((l, r)) => - Connect(NoInfo, toExp(l), toExp(r)) + connects ++= (l.cons.map { + case ((l, r)) => + Connect(NoInfo, toExp(l), toExp(r)) }) m match { case Module(i, n, ps, body) => val stmts = body match { case Block(sx) => sx - case s => Seq(s) + case s => Seq(s) } Module(i, n, ps ++ ports, Block(List() ++ defines ++ stmts ++ connects)) case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p) diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala index 20fb1215..d6658f16 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala @@ -14,14 +14,12 @@ import firrtl.stage.Forms case class WiringException(msg: String) extends PassException(msg) /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ -case class SourceAnnotation(target: ComponentName, pin: String) extends - SingleTargetAnnotation[ComponentName] { +case class SourceAnnotation(target: ComponentName, pin: String) extends SingleTargetAnnotation[ComponentName] { def duplicate(n: ComponentName) = this.copy(target = n) } /** A module, e.g. ExtModule etc., that should add the input pin */ -case class SinkAnnotation(target: Named, pin: String) extends - SingleTargetAnnotation[Named] { +case class SinkAnnotation(target: Named, pin: String) extends SingleTargetAnnotation[Named] { def duplicate(n: Named) = this.copy(target = n) } @@ -76,8 +74,9 @@ class WiringTransform extends Transform with DependencyAPIMigration { (sources.size, sinks.size) match { case (0, p) => state case (s, p) if (p > 0) => - val wis = sources.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, source)) => - seq :+ WiringInfo(source, sinks(pin), pin) + val wis = sources.foldLeft(Seq[WiringInfo]()) { + case (seq, (pin, source)) => + seq :+ WiringInfo(source, sinks(pin), pin) } val annosx = state.annotations.filterNot(annos.toSet.contains) transforms(wis) diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala index c220692a..5e8f8616 100644 --- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala +++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala @@ -25,54 +25,54 @@ case object DecWire extends DecKind /** Store of pending wiring information for a Module */ case class Modifications( addPortOrWire: Option[(String, DecKind)] = None, - cons: Seq[(String, String)] = Seq.empty) { + cons: Seq[(String, String)] = Seq.empty) { override def toString: String = serialize("") def serialize(tab: String): String = s""" - |$tab addPortOrWire: $addPortOrWire - |$tab cons: $cons - |""".stripMargin + |$tab addPortOrWire: $addPortOrWire + |$tab cons: $cons + |""".stripMargin } /** A lineage tree representing the instance hierarchy in a design */ @deprecated("Use DiGraph/InstanceGraph", "1.1.1") case class Lineage( - name: String, - children: Seq[(String, Lineage)] = Seq.empty, - source: Boolean = false, - sink: Boolean = false, - sourceParent: Boolean = false, - sinkParent: Boolean = false, - sharedParent: Boolean = false, - addPort: Option[(String, DecKind)] = None, - cons: Seq[(String, String)] = Seq.empty) { + name: String, + children: Seq[(String, Lineage)] = Seq.empty, + source: Boolean = false, + sink: Boolean = false, + sourceParent: Boolean = false, + sinkParent: Boolean = false, + sharedParent: Boolean = false, + addPort: Option[(String, DecKind)] = None, + cons: Seq[(String, String)] = Seq.empty) { def map(f: Lineage => Lineage): Lineage = - this.copy(children = children.map{ case (i, m) => (i, f(m)) }) + this.copy(children = children.map { case (i, m) => (i, f(m)) }) override def toString: String = shortSerialize("") def shortSerialize(tab: String): String = s""" - |$tab name: $name, - |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))} - |""".stripMargin + |$tab name: $name, + |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))} + |""".stripMargin def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B = this.children.foldLeft(z)(op) def serialize(tab: String): String = s""" - |$tab name: $name, - |$tab source: $source, - |$tab sink: $sink, - |$tab sourceParent: $sourceParent, - |$tab sinkParent: $sinkParent, - |$tab sharedParent: $sharedParent, - |$tab addPort: $addPort - |$tab cons: $cons - |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))} - |""".stripMargin + |$tab name: $name, + |$tab source: $source, + |$tab sink: $sink, + |$tab sourceParent: $sourceParent, + |$tab sinkParent: $sinkParent, + |$tab sharedParent: $sharedParent, + |$tab addPort: $addPort + |$tab cons: $cons + |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))} + |""".stripMargin } object WiringUtils { @@ -87,12 +87,12 @@ object WiringUtils { val childrenMap = new ChildrenMap() def getChildren(mname: String)(s: Statement): Unit = s match { case s: WDefInstance => - childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) ) + childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module)) case s: DefInstance => - childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) ) + childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module)) case s => s.foreach(getChildren(mname)) } - c.modules.foreach{ m => + c.modules.foreach { m => childrenMap(m.name) = Nil m.foreach(getChildren(m.name)) } @@ -103,7 +103,7 @@ object WiringUtils { */ @deprecated("Use DiGraph/InstanceGraph", "1.1.1") def getLineage(childrenMap: ChildrenMap, module: String): Lineage = - Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } ) + Lineage(module, childrenMap(module).map { case (i, m) => (i, getLineage(childrenMap, m)) }) /** Return a map of sink instances to source instances that minimizes * distance @@ -114,22 +114,25 @@ object WiringUtils { * @return a map of sink instance names to source instance names * @throws WiringException if a sink is equidistant to two sources */ - @deprecated("This method can lead to non-determinism in your compiler pass and exposes internal details." + - " Please file an issue with firrtl if you have a use case!", "Firrtl 1.4") + @deprecated( + "This method can lead to non-determinism in your compiler pass and exposes internal details." + + " Please file an issue with firrtl if you have a use case!", + "Firrtl 1.4" + ) def sinksToSources(sinks: Seq[Named], source: String, i: InstanceGraph): Map[Seq[WDefInstance], Seq[WDefInstance]] = { // The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap. val owners = new mutable.LinkedHashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]] val queue = new mutable.Queue[Seq[WDefInstance]] val visited = new mutable.HashMap[Seq[WDefInstance], Boolean].withDefaultValue(false) - val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v } + val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v } sourcePaths.flatten.foreach { l => queue.enqueue(l) owners(l) = Vector(l) } val sinkModuleNames = sinks.map(getModuleName).toSet - val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v } + val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v } // sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten @@ -156,8 +159,8 @@ object WiringUtils { // [todo] This is the critical section edges - .filter( e => !visited(e) && e.nonEmpty ) - .foreach{ v => + .filter(e => !visited(e) && e.nonEmpty) + .foreach { v => owners(v) = owners.getOrElse(v, Vector()) ++ owners(u) queue.enqueue(v) } @@ -167,8 +170,8 @@ object WiringUtils { // this should fail is if a sink is equidistant to two sources. sinkInsts.foreach { s => if (!owners.contains(s) || owners(s).size > 1) { - throw new WiringException( - s"Unable to determine source mapping for sink '${s.map(_.name)}'") } + throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'") + } } } @@ -184,21 +187,24 @@ object WiringUtils { * @return a map of sink instance names to source instance names * @throws WiringException if a sink is equidistant to two sources */ - private[firrtl] def sinksToSourcesSeq(sinks: Seq[Named], source: String, i: InstanceKeyGraph): - Seq[(Seq[InstanceKey], Seq[InstanceKey])] = { + private[firrtl] def sinksToSourcesSeq( + sinks: Seq[Named], + source: String, + i: InstanceKeyGraph + ): Seq[(Seq[InstanceKey], Seq[InstanceKey])] = { // The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap. val owners = new mutable.LinkedHashMap[Seq[InstanceKey], Vector[Seq[InstanceKey]]] val queue = new mutable.Queue[Seq[InstanceKey]] val visited = new mutable.HashMap[Seq[InstanceKey], Boolean].withDefaultValue(false) - val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v } + val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v } sourcePaths.flatten.foreach { l => queue.enqueue(l) owners(l) = Vector(l) } val sinkModuleNames = sinks.map(getModuleName).toSet - val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v } + val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v } // sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten @@ -225,8 +231,8 @@ object WiringUtils { // [todo] This is the critical section edges - .filter( e => !visited(e) && e.nonEmpty ) - .foreach{ v => + .filter(e => !visited(e) && e.nonEmpty) + .foreach { v => owners(v) = owners.getOrElse(v, Vector()) ++ owners(u) queue.enqueue(v) } @@ -236,8 +242,8 @@ object WiringUtils { // this should fail is if a sink is equidistant to two sources. sinkInsts.foreach { s => if (!owners.contains(s) || owners(s).size > 1) { - throw new WiringException( - s"Unable to determine source mapping for sink '${s.map(_.name)}'") } + throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'") + } } } @@ -249,8 +255,7 @@ object WiringUtils { n match { case ModuleName(m, _) => m case ComponentName(_, ModuleName(m, _)) => m - case _ => throw new WiringException( - "Only Components or Modules have an associated Module name") + case _ => throw new WiringException("Only Components or Modules have an associated Module name") } } @@ -266,9 +271,9 @@ object WiringUtils { def getType(c: Circuit, module: String, comp: String): Type = { def getRoot(e: Expression): String = e match { case r: Reference => r.name - case i: SubIndex => getRoot(i.expr) + case i: SubIndex => getRoot(i.expr) case a: SubAccess => getRoot(a.expr) - case f: SubField => getRoot(f.expr) + case f: SubField => getRoot(f.expr) } val eComp = toExp(comp) val root = getRoot(eComp) @@ -289,11 +294,12 @@ object WiringUtils { case sx: DefMemory if sx.name == root => tpe = Some(MemPortUtils.memType(sx)) sx - case sx => sx map getType + case sx => sx.map(getType) + } + val m = c.modules.find(_.name == module).getOrElse { + throw new WiringException(s"Must have a module named $module") } - val m = c.modules find (_.name == module) getOrElse { - throw new WiringException(s"Must have a module named $module") } - tpe = m.ports find (_.name == root) map (_.tpe) + tpe = m.ports.find(_.name == root).map(_.tpe) m match { case Module(i, n, ps, b) => getType(b) case e: ExtModule => @@ -301,10 +307,10 @@ object WiringUtils { tpe match { case None => throw new WiringException(s"Didn't find $comp in $module!") case Some(t) => - def setType(e: Expression): Expression = e map setType match { + def setType(e: Expression): Expression = e.map(setType) match { case ex: Reference => ex.copy(tpe = t) - case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name)) - case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe)) + case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name)) + case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe)) case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe)) } setType(eComp).tpe diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala index 41a7e1de..5b9dd371 100644 --- a/src/main/scala/firrtl/proto/FromProto.scala +++ b/src/main/scala/firrtl/proto/FromProto.scala @@ -35,9 +35,9 @@ object FromProto { // Convert from ProtoBuf message repeated Statements to FIRRRTL Block private def compressStmts(stmts: scala.collection.Seq[ir.Statement]): ir.Statement = stmts match { - case scala.collection.Seq() => ir.EmptyStmt + case scala.collection.Seq() => ir.EmptyStmt case scala.collection.Seq(stmt) => stmt - case multiple => ir.Block(multiple.toSeq) + case multiple => ir.Block(multiple.toSeq) } def convert(info: Firrtl.SourceInfo): ir.Info = @@ -100,16 +100,16 @@ object FromProto { def convert(expr: Firrtl.Expression): ir.Expression = { import Firrtl.Expression._ expr.getExpressionCase.getNumber match { - case REFERENCE_FIELD_NUMBER => ir.Reference(expr.getReference.getId, ir.UnknownType) - case SUB_FIELD_FIELD_NUMBER => convert(expr.getSubField) - case SUB_INDEX_FIELD_NUMBER => convert(expr.getSubIndex) - case SUB_ACCESS_FIELD_NUMBER => convert(expr.getSubAccess) - case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral) - case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral) + case REFERENCE_FIELD_NUMBER => ir.Reference(expr.getReference.getId, ir.UnknownType) + case SUB_FIELD_FIELD_NUMBER => convert(expr.getSubField) + case SUB_INDEX_FIELD_NUMBER => convert(expr.getSubIndex) + case SUB_ACCESS_FIELD_NUMBER => convert(expr.getSubAccess) + case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral) + case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral) case FIXED_LITERAL_FIELD_NUMBER => convert(expr.getFixedLiteral) - case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp) - case MUX_FIELD_NUMBER => convert(expr.getMux) - case VALID_IF_FIELD_NUMBER => convert(expr.getValidIf) + case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp) + case MUX_FIELD_NUMBER => convert(expr.getMux) + case VALID_IF_FIELD_NUMBER => convert(expr.getValidIf) } } @@ -123,8 +123,14 @@ object FromProto { ir.DefWire(convert(info), wire.getId, convert(wire.getType)) def convert(reg: Firrtl.Statement.Register, info: Firrtl.SourceInfo): ir.DefRegister = - ir.DefRegister(convert(info), reg.getId, convert(reg.getType), convert(reg.getClock), - convert(reg.getReset), convert(reg.getInit)) + ir.DefRegister( + convert(info), + reg.getId, + convert(reg.getType), + convert(reg.getClock), + convert(reg.getReset), + convert(reg.getInit) + ) def convert(node: Firrtl.Statement.Node, info: Firrtl.SourceInfo): ir.DefNode = ir.DefNode(convert(info), node.getId, convert(node.getExpression)) @@ -140,8 +146,8 @@ object FromProto { def convert(ruw: ReadUnderWrite): ir.ReadUnderWrite.Value = ruw match { case ReadUnderWrite.UNDEFINED => ir.ReadUnderWrite.Undefined - case ReadUnderWrite.OLD => ir.ReadUnderWrite.Old - case ReadUnderWrite.NEW => ir.ReadUnderWrite.New + case ReadUnderWrite.OLD => ir.ReadUnderWrite.Old + case ReadUnderWrite.NEW => ir.ReadUnderWrite.New } def convert(dt: Firrtl.Statement.CMemory.TypeAndDepth): (ir.Type, BigInt) = @@ -161,9 +167,9 @@ object FromProto { import Firrtl.Statement.MemoryPort.Direction._ def convert(mportdir: Firrtl.Statement.MemoryPort.Direction): MPortDir = mportdir match { - case MEMORY_PORT_DIRECTION_INFER => MInfer - case MEMORY_PORT_DIRECTION_READ => MRead - case MEMORY_PORT_DIRECTION_WRITE => MWrite + case MEMORY_PORT_DIRECTION_INFER => MInfer + case MEMORY_PORT_DIRECTION_READ => MRead + case MEMORY_PORT_DIRECTION_WRITE => MWrite case MEMORY_PORT_DIRECTION_READ_WRITE => MReadWrite } @@ -184,12 +190,18 @@ object FromProto { def convert(formal: Formal): ir.Formal.Value = formal match { case Formal.ASSERT => ir.Formal.Assert case Formal.ASSUME => ir.Formal.Assume - case Formal.COVER => ir.Formal.Cover + case Formal.COVER => ir.Formal.Cover } def convert(ver: Firrtl.Statement.Verification, info: Firrtl.SourceInfo): ir.Verification = - ir.Verification(convert(ver.getOp), convert(info), convert(ver.getClk), - convert(ver.getCond), convert(ver.getEn), ir.StringLit(ver.getMsg)) + ir.Verification( + convert(ver.getOp), + convert(info), + convert(ver.getClk), + convert(ver.getCond), + convert(ver.getEn), + ir.StringLit(ver.getMsg) + ) def convert(mem: Firrtl.Statement.Memory, info: Firrtl.SourceInfo): ir.DefMemory = { val dtype = convert(mem.getType) @@ -198,11 +210,21 @@ object FromProto { val rws = mem.getReadwriterIdList.asScala.toSeq import Firrtl.Statement.Memory._ val depth = mem.getDepthCase.getNumber match { - case UINT_DEPTH_FIELD_NUMBER => BigInt(mem.getUintDepth) + case UINT_DEPTH_FIELD_NUMBER => BigInt(mem.getUintDepth) case BIGINT_DEPTH_FIELD_NUMBER => convert(mem.getBigintDepth) } - ir.DefMemory(convert(info), mem.getId, dtype, depth, mem.getWriteLatency, mem.getReadLatency, - rs, ws, rws, convert(mem.getReadUnderWrite)) + ir.DefMemory( + convert(info), + mem.getId, + dtype, + depth, + mem.getWriteLatency, + mem.getReadLatency, + rs, + ws, + rws, + convert(mem.getReadUnderWrite) + ) } def convert(attach: Firrtl.Statement.Attach, info: Firrtl.SourceInfo): ir.Attach = { @@ -214,21 +236,21 @@ object FromProto { import Firrtl.Statement._ val info = stmt.getSourceInfo stmt.getStatementCase.getNumber match { - case NODE_FIELD_NUMBER => convert(stmt.getNode, info) - case CONNECT_FIELD_NUMBER => convert(stmt.getConnect, info) + case NODE_FIELD_NUMBER => convert(stmt.getNode, info) + case CONNECT_FIELD_NUMBER => convert(stmt.getConnect, info) case PARTIAL_CONNECT_FIELD_NUMBER => convert(stmt.getPartialConnect, info) - case WIRE_FIELD_NUMBER => convert(stmt.getWire, info) - case REGISTER_FIELD_NUMBER => convert(stmt.getRegister, info) - case WHEN_FIELD_NUMBER => convert(stmt.getWhen, info) - case INSTANCE_FIELD_NUMBER => convert(stmt.getInstance, info) - case PRINTF_FIELD_NUMBER => convert(stmt.getPrintf, info) - case STOP_FIELD_NUMBER => convert(stmt.getStop, info) - case MEMORY_FIELD_NUMBER => convert(stmt.getMemory, info) + case WIRE_FIELD_NUMBER => convert(stmt.getWire, info) + case REGISTER_FIELD_NUMBER => convert(stmt.getRegister, info) + case WHEN_FIELD_NUMBER => convert(stmt.getWhen, info) + case INSTANCE_FIELD_NUMBER => convert(stmt.getInstance, info) + case PRINTF_FIELD_NUMBER => convert(stmt.getPrintf, info) + case STOP_FIELD_NUMBER => convert(stmt.getStop, info) + case MEMORY_FIELD_NUMBER => convert(stmt.getMemory, info) case IS_INVALID_FIELD_NUMBER => ir.IsInvalid(convert(info), convert(stmt.getIsInvalid.getExpression)) - case CMEMORY_FIELD_NUMBER => convert(stmt.getCmemory, info) + case CMEMORY_FIELD_NUMBER => convert(stmt.getCmemory, info) case MEMORY_PORT_FIELD_NUMBER => convert(stmt.getMemoryPort, info) - case ATTACH_FIELD_NUMBER => convert(stmt.getAttach, info) + case ATTACH_FIELD_NUMBER => convert(stmt.getAttach, info) } } @@ -244,7 +266,7 @@ object FromProto { val w = if (ut.hasWidth) convert(ut.getWidth) else ir.UnknownWidth ir.UIntType(w) } - + def convert(st: Firrtl.Type.SIntType): ir.SIntType = { val w = if (st.hasWidth) convert(st.getWidth) else ir.UnknownWidth ir.SIntType(w) @@ -272,13 +294,13 @@ object FromProto { def convert(tpe: Firrtl.Type): ir.Type = { import Firrtl.Type._ tpe.getTypeCase.getNumber match { - case UINT_TYPE_FIELD_NUMBER => convert(tpe.getUintType) - case SINT_TYPE_FIELD_NUMBER => convert(tpe.getSintType) - case FIXED_TYPE_FIELD_NUMBER => convert(tpe.getFixedType) - case CLOCK_TYPE_FIELD_NUMBER => ir.ClockType + case UINT_TYPE_FIELD_NUMBER => convert(tpe.getUintType) + case SINT_TYPE_FIELD_NUMBER => convert(tpe.getSintType) + case FIXED_TYPE_FIELD_NUMBER => convert(tpe.getFixedType) + case CLOCK_TYPE_FIELD_NUMBER => ir.ClockType case ASYNC_RESET_TYPE_FIELD_NUMBER => ir.AsyncResetType - case RESET_TYPE_FIELD_NUMBER => ir.ResetType - case ANALOG_TYPE_FIELD_NUMBER => convert(tpe.getAnalogType) + case RESET_TYPE_FIELD_NUMBER => ir.ResetType + case ANALOG_TYPE_FIELD_NUMBER => convert(tpe.getAnalogType) case BUNDLE_TYPE_FIELD_NUMBER => ir.BundleType(tpe.getBundleType.getFieldList.asScala.map(convert(_)).toSeq) case VECTOR_TYPE_FIELD_NUMBER => convert(tpe.getVectorType) @@ -287,7 +309,7 @@ object FromProto { def convert(dir: Firrtl.Port.Direction): ir.Direction = { dir match { - case Firrtl.Port.Direction.PORT_DIRECTION_IN => ir.Input + case Firrtl.Port.Direction.PORT_DIRECTION_IN => ir.Input case Firrtl.Port.Direction.PORT_DIRECTION_OUT => ir.Output } } @@ -302,9 +324,9 @@ object FromProto { import Firrtl.Module.ExternalModule.Parameter._ val name = param.getId param.getValueCase.getNumber match { - case INTEGER_FIELD_NUMBER => ir.IntParam(name, convert(param.getInteger)) - case DOUBLE_FIELD_NUMBER => ir.DoubleParam(name, param.getDouble) - case STRING_FIELD_NUMBER => ir.StringParam(name, ir.StringLit(param.getString)) + case INTEGER_FIELD_NUMBER => ir.IntParam(name, convert(param.getInteger)) + case DOUBLE_FIELD_NUMBER => ir.DoubleParam(name, param.getDouble) + case STRING_FIELD_NUMBER => ir.StringParam(name, ir.StringLit(param.getString)) case RAW_STRING_FIELD_NUMBER => ir.RawStringParam(name, param.getRawString) } } diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala index 47fb3cec..78b95582 100644 --- a/src/main/scala/firrtl/proto/ToProto.scala +++ b/src/main/scala/firrtl/proto/ToProto.scala @@ -6,7 +6,7 @@ package proto import java.io.OutputStream import FirrtlProtos._ -import Firrtl.Statement.{ReadUnderWrite, Formal} +import Firrtl.Statement.{Formal, ReadUnderWrite} import Firrtl.Expression.PrimOp.Op import com.google.protobuf.{CodedOutputStream, WireFormat} import firrtl.PrimOps._ @@ -15,7 +15,6 @@ import scala.collection.JavaConverters._ object ToProto { - /** Serialize a FIRRTL Circuit to an Output Stream as a ProtoBuf message * * @param ostream Output stream that will be written @@ -38,9 +37,9 @@ object ToProto { // Note this function is sensitive to changes to the Firrtl and Circuit protobuf message definitions def writeToStreamFast( ostream: OutputStream, - info: ir.Info, + info: ir.Info, modules: Seq[() => ir.DefModule], - main: String + main: String ): Unit = { val costream = CodedOutputStream.newInstance(ostream) @@ -110,23 +109,25 @@ object ToProto { def convert(ruw: ir.ReadUnderWrite.Value): ReadUnderWrite = ruw match { case ir.ReadUnderWrite.Undefined => ReadUnderWrite.UNDEFINED - case ir.ReadUnderWrite.Old => ReadUnderWrite.OLD - case ir.ReadUnderWrite.New => ReadUnderWrite.NEW + case ir.ReadUnderWrite.Old => ReadUnderWrite.OLD + case ir.ReadUnderWrite.New => ReadUnderWrite.NEW } def convert(formal: ir.Formal.Value): Formal = formal match { case ir.Formal.Assert => Formal.ASSERT case ir.Formal.Assume => Formal.ASSUME - case ir.Formal.Cover => Formal.COVER + case ir.Formal.Cover => Formal.COVER } def convertToIntegerLiteral(value: BigInt): Firrtl.Expression.IntegerLiteral.Builder = { - Firrtl.Expression.IntegerLiteral.newBuilder() + Firrtl.Expression.IntegerLiteral + .newBuilder() .setValue(value.toString) } def convertToBigInt(value: BigInt): Firrtl.BigInt.Builder = { - Firrtl.BigInt.newBuilder() + Firrtl.BigInt + .newBuilder() .setValue(com.google.protobuf.ByteString.copyFrom(value.toByteArray)) } @@ -135,7 +136,7 @@ object ToProto { info match { case ir.NoInfo => ib.setNone(Firrtl.SourceInfo.None.newBuilder) - case f : ir.FileInfo => + case f: ir.FileInfo => ib.setText(f.unescaped) // TODO properly implement MultiInfo case ir.MultiInfo(infos) => @@ -148,54 +149,64 @@ object ToProto { val eb = Firrtl.Expression.newBuilder() expr match { case ir.Reference(name, _, _, _) => - val rb = Firrtl.Expression.Reference.newBuilder() + val rb = Firrtl.Expression.Reference + .newBuilder() .setId(name) eb.setReference(rb) case ir.SubField(e, name, _, _) => - val sb = Firrtl.Expression.SubField.newBuilder() + val sb = Firrtl.Expression.SubField + .newBuilder() .setExpression(convert(e)) .setField(name) eb.setSubField(sb) case ir.SubIndex(e, value, _, _) => - val sb = Firrtl.Expression.SubIndex.newBuilder() + val sb = Firrtl.Expression.SubIndex + .newBuilder() .setExpression(convert(e)) .setIndex(convertToIntegerLiteral(value)) eb.setSubIndex(sb) case ir.SubAccess(e, index, _, _) => - val sb = Firrtl.Expression.SubAccess.newBuilder() + val sb = Firrtl.Expression.SubAccess + .newBuilder() .setExpression(convert(e)) .setIndex(convert(index)) eb.setSubAccess(sb) case ir.UIntLiteral(value, width) => - val ub = Firrtl.Expression.UIntLiteral.newBuilder() + val ub = Firrtl.Expression.UIntLiteral + .newBuilder() .setValue(convertToIntegerLiteral(value)) convert(width).foreach(ub.setWidth) eb.setUintLiteral(ub) case ir.SIntLiteral(value, width) => - val sb = Firrtl.Expression.SIntLiteral.newBuilder() + val sb = Firrtl.Expression.SIntLiteral + .newBuilder() .setValue(convertToIntegerLiteral(value)) convert(width).foreach(sb.setWidth) eb.setSintLiteral(sb) case ir.FixedLiteral(value, width, point) => - val fb = Firrtl.Expression.FixedLiteral.newBuilder() + val fb = Firrtl.Expression.FixedLiteral + .newBuilder() .setValue(convertToBigInt(value)) convert(width).foreach(fb.setWidth) convert(point).foreach(fb.setPoint) eb.setFixedLiteral(fb) case ir.DoPrim(op, args, consts, _) => - val db = Firrtl.Expression.PrimOp.newBuilder() + val db = Firrtl.Expression.PrimOp + .newBuilder() .setOp(convert(op)) consts.foreach(c => db.addConst(convertToIntegerLiteral(c))) args.foreach(a => db.addArg(convert(a))) eb.setPrimOp(db) case ir.Mux(cond, tval, fval, _) => - val mb = Firrtl.Expression.Mux.newBuilder() + val mb = Firrtl.Expression.Mux + .newBuilder() .setCondition(convert(cond)) .setTValue(convert(tval)) .setFValue(convert(fval)) eb.setMux(mb) case ir.ValidIf(cond, value, _) => - val vb = Firrtl.Expression.ValidIf.newBuilder() + val vb = Firrtl.Expression.ValidIf + .newBuilder() .setCondition(convert(cond)) .setValue(convert(value)) eb.setValidIf(vb) @@ -205,37 +216,41 @@ object ToProto { def convert(dir: MPortDir): Firrtl.Statement.MemoryPort.Direction = { import Firrtl.Statement.MemoryPort.Direction._ dir match { - case MInfer => MEMORY_PORT_DIRECTION_INFER - case MRead => MEMORY_PORT_DIRECTION_READ - case MWrite => MEMORY_PORT_DIRECTION_WRITE + case MInfer => MEMORY_PORT_DIRECTION_INFER + case MRead => MEMORY_PORT_DIRECTION_READ + case MWrite => MEMORY_PORT_DIRECTION_WRITE case MReadWrite => MEMORY_PORT_DIRECTION_READ_WRITE } } def convert(tpe: ir.Type, depth: BigInt): Firrtl.Statement.CMemory.TypeAndDepth.Builder = - Firrtl.Statement.CMemory.TypeAndDepth.newBuilder() + Firrtl.Statement.CMemory.TypeAndDepth + .newBuilder() .setDataType(convert(tpe)) .setDepth(convertToBigInt(depth)) def convert(stmt: ir.Statement): Seq[Firrtl.Statement.Builder] = { stmt match { case ir.Block(stmts) => stmts.flatMap(convert(_)) - case ir.EmptyStmt => Seq.empty + case ir.EmptyStmt => Seq.empty case other => val sb = Firrtl.Statement.newBuilder() other match { case ir.DefNode(_, name, expr) => - val nb = Firrtl.Statement.Node.newBuilder() + val nb = Firrtl.Statement.Node + .newBuilder() .setId(name) .setExpression(convert(expr)) sb.setNode(nb) case ir.DefWire(_, name, tpe) => - val wb = Firrtl.Statement.Wire.newBuilder() + val wb = Firrtl.Statement.Wire + .newBuilder() .setId(name) .setType(convert(tpe)) sb.setWire(wb) case ir.DefRegister(_, name, tpe, clock, reset, init) => - val rb = Firrtl.Statement.Register.newBuilder() + val rb = Firrtl.Statement.Register + .newBuilder() .setId(name) .setType(convert(tpe)) .setClock(convert(clock)) @@ -243,54 +258,63 @@ object ToProto { .setInit(convert(init)) sb.setRegister(rb) case ir.DefInstance(_, name, module, _) => - val ib = Firrtl.Statement.Instance.newBuilder() + val ib = Firrtl.Statement.Instance + .newBuilder() .setId(name) .setModuleId(module) sb.setInstance(ib) case ir.Connect(_, loc, expr) => - val cb = Firrtl.Statement.Connect.newBuilder() + val cb = Firrtl.Statement.Connect + .newBuilder() .setLocation(convert(loc)) .setExpression(convert(expr)) sb.setConnect(cb) case ir.PartialConnect(_, loc, expr) => - val cb = Firrtl.Statement.PartialConnect.newBuilder() + val cb = Firrtl.Statement.PartialConnect + .newBuilder() .setLocation(convert(loc)) .setExpression(convert(expr)) sb.setPartialConnect(cb) case ir.Conditionally(_, pred, conseq, alt) => val cs = convert(conseq) val as = convert(alt) - val wb = Firrtl.Statement.When.newBuilder() + val wb = Firrtl.Statement.When + .newBuilder() .setPredicate(convert(pred)) cs.foreach(wb.addConsequent) as.foreach(wb.addOtherwise) sb.setWhen(wb) case ir.Print(_, string, args, clk, en) => - val pb = Firrtl.Statement.Printf.newBuilder() + val pb = Firrtl.Statement.Printf + .newBuilder() .setValue(string.string) .setClk(convert(clk)) .setEn(convert(en)) args.foreach(a => pb.addArg(convert(a))) sb.setPrintf(pb) case ir.Stop(_, ret, clk, en) => - val stopb = Firrtl.Statement.Stop.newBuilder() + val stopb = Firrtl.Statement.Stop + .newBuilder() .setReturnValue(ret) .setClk(convert(clk)) .setEn(convert(en)) sb.setStop(stopb) case ir.Verification(op, _, clk, cond, en, msg) => - val vb = Firrtl.Statement.Verification.newBuilder() + val vb = Firrtl.Statement.Verification + .newBuilder() .setOp(convert(op)) .setClk(convert(clk)) .setCond(convert(cond)) .setEn(convert(en)) .setMsg(msg.string) case ir.IsInvalid(_, expr) => - val ib = Firrtl.Statement.IsInvalid.newBuilder() + val ib = Firrtl.Statement.IsInvalid + .newBuilder() .setExpression(convert(expr)) sb.setIsInvalid(ib) case ir.DefMemory(_, name, dtype, depth, wlat, rlat, rs, ws, rws, ruw) => - val mem = Firrtl.Statement.Memory.newBuilder() + val mem = Firrtl.Statement.Memory + .newBuilder() .setId(name) .setType(convert(dtype)) .setBigintDepth(convertToBigInt(depth)) @@ -302,14 +326,16 @@ object ToProto { mem.addAllReadwriterId(rws.asJava) sb.setMemory(mem) case CDefMemory(_, name, tpe, size, seq, ruw) => - val mb = Firrtl.Statement.CMemory.newBuilder() + val mb = Firrtl.Statement.CMemory + .newBuilder() .setId(name) .setTypeAndDepth(convert(tpe, size)) .setSyncRead(seq) .setReadUnderWrite(convert(ruw)) sb.setCmemory(mb) case CDefMPort(_, name, _, mem, exprs, dir) => - val pb = Firrtl.Statement.MemoryPort.newBuilder() + val pb = Firrtl.Statement.MemoryPort + .newBuilder() .setId(name) .setMemoryId(mem) .setMemoryIndex(convert(exprs.head)) @@ -330,7 +356,8 @@ object ToProto { } def convert(field: ir.Field): Firrtl.Type.BundleType.Field.Builder = { - val b = Firrtl.Type.BundleType.Field.newBuilder() + val b = Firrtl.Type.BundleType.Field + .newBuilder() .setId(field.name) .setIsFlipped(field.flip == ir.Flip) .setType(convert(field.tpe)) @@ -343,12 +370,13 @@ object ToProto { * @return Option width where None means the width field should be cleared in the parent object */ def convert(width: ir.Width): Option[Firrtl.Width.Builder] = width match { - case ir.IntWidth(w) => Some(Firrtl.Width.newBuilder().setValue(w.toInt)) + case ir.IntWidth(w) => Some(Firrtl.Width.newBuilder().setValue(w.toInt)) case ir.UnknownWidth => None } def convert(vtpe: ir.VectorType): Firrtl.Type.VectorType.Builder = - Firrtl.Type.VectorType.newBuilder() + Firrtl.Type.VectorType + .newBuilder() .setType(convert(vtpe.tpe)) .setSize(vtpe.size) @@ -379,7 +407,7 @@ object ToProto { tb.setResetType(rt) case ir.AnalogType(width) => val at = Firrtl.Type.AnalogType.newBuilder() - convert(width).foreach(at.setWidth) + convert(width).foreach(at.setWidth) tb.setAnalogType(at) case ir.BundleType(fields) => val bt = Firrtl.Type.BundleType.newBuilder() @@ -392,12 +420,13 @@ object ToProto { } def convert(direction: ir.Direction): Firrtl.Port.Direction = direction match { - case ir.Input => Firrtl.Port.Direction.PORT_DIRECTION_IN + case ir.Input => Firrtl.Port.Direction.PORT_DIRECTION_IN case ir.Output => Firrtl.Port.Direction.PORT_DIRECTION_OUT } def convert(port: ir.Port): Firrtl.Port.Builder = { - Firrtl.Port.newBuilder() + Firrtl.Port + .newBuilder() .setId(port.name) .setDirection(convert(port.direction)) .setType(convert(port.tpe)) @@ -405,7 +434,8 @@ object ToProto { def convert(param: ir.Param): Firrtl.Module.ExternalModule.Parameter.Builder = { import Firrtl.Module.ExternalModule._ - val pb = Parameter.newBuilder() + val pb = Parameter + .newBuilder() .setId(param.name) param match { case ir.IntParam(_, value) => @@ -425,13 +455,15 @@ object ToProto { module match { case mod: ir.Module => val stmts = convert(mod.body) - val mb = Firrtl.Module.UserModule.newBuilder() + val mb = Firrtl.Module.UserModule + .newBuilder() .setId(mod.name) ports.foreach(mb.addPort) stmts.foreach(mb.addStatement) b.setUserModule(mb) case ext: ir.ExtModule => - val eb = Firrtl.Module.ExternalModule.newBuilder() + val eb = Firrtl.Module.ExternalModule + .newBuilder() .setId(ext.name) .setDefinedName(ext.defname) ports.foreach(eb.addPort) @@ -448,7 +480,8 @@ object ToProto { for (m <- moduleBuilders) { cb.addModule(m) } - Firrtl.newBuilder() + Firrtl + .newBuilder() .addCircuit(cb.build()) .build() } diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala index d587fd8c..d37d2881 100644 --- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala +++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala @@ -35,14 +35,16 @@ sealed trait CircuitOption extends Unserializable { this: Annotation => case class FirrtlFileAnnotation(file: String) extends NoTargetAnnotation with CircuitOption { def toCircuit(info: Parser.InfoMode): FirrtlCircuitAnnotation = { - val circuit = try { - FirrtlStageUtils.getFileExtension(file) match { - case ProtoBufFile => proto.FromProto.fromFile(file) - case FirrtlFile => Parser.parseFile(file, info) } - } catch { - case a @ (_: FileNotFoundException | _: NoSuchFileException) => - throw new OptionsException(s"Input file '$file' not found! (Did you misspell it?)", a) - } + val circuit = + try { + FirrtlStageUtils.getFileExtension(file) match { + case ProtoBufFile => proto.FromProto.fromFile(file) + case FirrtlFile => Parser.parseFile(file, info) + } + } catch { + case a @ (_: FileNotFoundException | _: NoSuchFileException) => + throw new OptionsException(s"Input file '$file' not found! (Did you misspell it?)", a) + } FirrtlCircuitAnnotation(circuit) } @@ -52,11 +54,13 @@ object FirrtlFileAnnotation extends HasShellOptions { val options = Seq( new ShellOption[String]( - longOption = "input-file", + longOption = "input-file", toAnnotationSeq = a => Seq(FirrtlFileAnnotation(a)), - helpText = "An input FIRRTL file", - shortOption = Some("i"), - helpValueName = Some("<file>") ) ) + helpText = "An input FIRRTL file", + shortOption = Some("i"), + helpValueName = Some("<file>") + ) + ) } @@ -70,11 +74,13 @@ object OutputFileAnnotation extends HasShellOptions { val options = Seq( new ShellOption[String]( - longOption = "output-file", + longOption = "output-file", toAnnotationSeq = a => Seq(OutputFileAnnotation(a)), - helpText = "The output FIRRTL file", - shortOption = Some("o"), - helpValueName = Some("<file>") ) ) + helpText = "The output FIRRTL file", + shortOption = Some("o"), + helpValueName = Some("<file>") + ) + ) } @@ -84,8 +90,10 @@ object OutputFileAnnotation extends HasShellOptions { * @note This cannote be directly converted to [[Parser.InfoMode]] as that depends on an optional [[FirrtlFileAnnotation]] */ case class InfoModeAnnotation(modeName: String = "use") extends NoTargetAnnotation with FirrtlOption { - require(modeName match { case "use" | "ignore" | "gen" | "append" => true; case _ => false }, - s"Unknown info mode '$modeName'! (Did you misspell it?)") + require( + modeName match { case "use" | "ignore" | "gen" | "append" => true; case _ => false }, + s"Unknown info mode '$modeName'! (Did you misspell it?)" + ) /** Return the [[Parser.InfoMode]] equivalent for this [[firrtl.annotations.Annotation Annotation]] * @param infoSource the name of a file to use for "gen" or "append" info modes @@ -93,7 +101,7 @@ case class InfoModeAnnotation(modeName: String = "use") extends NoTargetAnnotati def toInfoMode(infoSource: Option[String] = None): Parser.InfoMode = modeName match { case "use" => Parser.UseInfo case "ignore" => Parser.IgnoreInfo - case _ => + case _ => val a = infoSource.getOrElse("unknown source") modeName match { case "gen" => Parser.GenInfo(a) @@ -106,10 +114,12 @@ object InfoModeAnnotation extends HasShellOptions { val options = Seq( new ShellOption[String]( - longOption = "info-mode", + longOption = "info-mode", toAnnotationSeq = a => Seq(InfoModeAnnotation(a)), - helpText = s"Source file info handling mode (default: ${apply().modeName})", - helpValueName = Some("<ignore|use|gen|append>") ) ) + helpText = s"Source file info handling mode (default: ${apply().modeName})", + helpValueName = Some("<ignore|use|gen|append>") + ) + ) } @@ -128,10 +138,12 @@ object FirrtlSourceAnnotation extends HasShellOptions { val options = Seq( new ShellOption[String]( - longOption = "firrtl-source", + longOption = "firrtl-source", toAnnotationSeq = a => Seq(FirrtlSourceAnnotation(a)), - helpText = "An input FIRRTL circuit string", - helpValueName = Some("<string>") ) ) + helpText = "An input FIRRTL circuit string", + helpValueName = Some("<string>") + ) + ) } @@ -144,27 +156,29 @@ case class CompilerAnnotation(compiler: Compiler = new VerilogCompiler()) extend object CompilerAnnotation extends HasShellOptions { - private [firrtl] def apply(compilerName: String): CompilerAnnotation = { + private[firrtl] def apply(compilerName: String): CompilerAnnotation = { val c = compilerName match { - case "none" => new NoneCompiler() - case "high" => new HighFirrtlCompiler() - case "low" => new LowFirrtlCompiler() - case "middle" => new MiddleFirrtlCompiler() - case "verilog" => new VerilogCompiler() - case "mverilog" => new MinimumVerilogCompiler() - case "sverilog" => new SystemVerilogCompiler() - case _ => throw new OptionsException(s"Unknown compiler name '$compilerName'! (Did you misspell it?)") + case "none" => new NoneCompiler() + case "high" => new HighFirrtlCompiler() + case "low" => new LowFirrtlCompiler() + case "middle" => new MiddleFirrtlCompiler() + case "verilog" => new VerilogCompiler() + case "mverilog" => new MinimumVerilogCompiler() + case "sverilog" => new SystemVerilogCompiler() + case _ => throw new OptionsException(s"Unknown compiler name '$compilerName'! (Did you misspell it?)") } CompilerAnnotation(c) } val options = Seq( new ShellOption[String]( - longOption = "compiler", + longOption = "compiler", toAnnotationSeq = a => Seq(CompilerAnnotation(a)), - helpText = "The FIRRTL compiler to use (default: verilog)", - shortOption = Some("X"), - helpValueName = Some("<none|high|middle|low|verilog|mverilog|sverilog>") ) ) + helpText = "The FIRRTL compiler to use (default: verilog)", + shortOption = Some("X"), + helpValueName = Some("<none|high|middle|low|verilog|mverilog|sverilog>") + ) + ) } @@ -188,21 +202,26 @@ object RunFirrtlTransformAnnotation extends HasShellOptions { val tx = Class.forName(txName).asInstanceOf[Class[_ <: Transform]].newInstance() RunFirrtlTransformAnnotation(tx) } catch { - case e: ClassNotFoundException => throw new OptionsException( - s"Unable to locate custom transform $txName (did you misspell it?)", e) - case e: InstantiationException => throw new OptionsException( - s"Unable to create instance of Transform $txName (is this an anonymous class?)", e) - case e: Throwable => throw new OptionsException( - s"Unknown error when instantiating class $txName", e) }), + case e: ClassNotFoundException => + throw new OptionsException(s"Unable to locate custom transform $txName (did you misspell it?)", e) + case e: InstantiationException => + throw new OptionsException( + s"Unable to create instance of Transform $txName (is this an anonymous class?)", + e + ) + case e: Throwable => throw new OptionsException(s"Unknown error when instantiating class $txName", e) + } + ), helpText = "Run these transforms during compilation", shortOption = Some("fct"), - helpValueName = Some("<package>.<class>") ), + helpValueName = Some("<package>.<class>") + ), new ShellOption[String]( longOption = "change-name-case", toAnnotationSeq = _ match { case "lower" => Seq(RunFirrtlTransformAnnotation(new firrtl.features.LowerCaseNames)) case "upper" => Seq(RunFirrtlTransformAnnotation(new firrtl.features.UpperCaseNames)) - case a => throw new OptionsException(s"Unknown case '$a'. Did you misspell it?") + case a => throw new OptionsException(s"Unknown case '$a'. Did you misspell it?") }, helpText = "Convert all FIRRTL names to a specific case", helpValueName = Some("<lower|upper>") @@ -231,9 +250,9 @@ case object SuppressScalaVersionWarning extends NoTargetAnnotation with FirrtlOp def longOption: String = "Wno-scala-version-warning" val options = Seq( new ShellOption[Unit]( - longOption = longOption, + longOption = longOption, toAnnotationSeq = { _ => Seq(this) }, - helpText = "Suppress Scala 2.11 deprecation warning (ignored in Scala 2.12+)" + helpText = "Suppress Scala 2.11 deprecation warning (ignored in Scala 2.12+)" ) ) } diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala index 39b89bea..fb5aa09f 100644 --- a/src/main/scala/firrtl/stage/FirrtlCli.scala +++ b/src/main/scala/firrtl/stage/FirrtlCli.scala @@ -11,16 +11,18 @@ import firrtl.transforms.NoCircuitDedupAnnotation */ trait FirrtlCli { this: Shell => parser.note("FIRRTL Compiler Options") - Seq( FirrtlFileAnnotation, - OutputFileAnnotation, - InfoModeAnnotation, - FirrtlSourceAnnotation, - CompilerAnnotation, - RunFirrtlTransformAnnotation, - firrtl.EmitCircuitAnnotation, - firrtl.EmitAllModulesAnnotation, - NoCircuitDedupAnnotation, - SuppressScalaVersionWarning) + Seq( + FirrtlFileAnnotation, + OutputFileAnnotation, + InfoModeAnnotation, + FirrtlSourceAnnotation, + CompilerAnnotation, + RunFirrtlTransformAnnotation, + firrtl.EmitCircuitAnnotation, + firrtl.EmitAllModulesAnnotation, + NoCircuitDedupAnnotation, + SuppressScalaVersionWarning + ) .map(_.addOptions(parser)) phases.DriverCompatibility.TopNameAnnotation.addOptions(parser) diff --git a/src/main/scala/firrtl/stage/FirrtlOptions.scala b/src/main/scala/firrtl/stage/FirrtlOptions.scala index 61dec7c5..55d4cc31 100644 --- a/src/main/scala/firrtl/stage/FirrtlOptions.scala +++ b/src/main/scala/firrtl/stage/FirrtlOptions.scala @@ -9,19 +9,17 @@ import firrtl.ir.Circuit * @param infoModeName the policy for generating [[firrtl.ir Info]] when processing FIRRTL (default: "append") * @param firrtlCircuit a [[firrtl.ir Circuit]] */ -class FirrtlOptions private [stage] ( - val outputFileName: Option[String] = None, - val infoModeName: String = InfoModeAnnotation().modeName, - val firrtlCircuit: Option[Circuit] = None) { +class FirrtlOptions private[stage] ( + val outputFileName: Option[String] = None, + val infoModeName: String = InfoModeAnnotation().modeName, + val firrtlCircuit: Option[Circuit] = None) { - private [stage] def copy( - outputFileName: Option[String] = outputFileName, - infoModeName: String = infoModeName, - firrtlCircuit: Option[Circuit] = firrtlCircuit ): FirrtlOptions = { + private[stage] def copy( + outputFileName: Option[String] = outputFileName, + infoModeName: String = infoModeName, + firrtlCircuit: Option[Circuit] = firrtlCircuit + ): FirrtlOptions = { - new FirrtlOptions( - outputFileName = outputFileName, - infoModeName = infoModeName, - firrtlCircuit = firrtlCircuit ) + new FirrtlOptions(outputFileName = outputFileName, infoModeName = infoModeName, firrtlCircuit = firrtlCircuit) } } diff --git a/src/main/scala/firrtl/stage/FirrtlStage.scala b/src/main/scala/firrtl/stage/FirrtlStage.scala index 1042f979..58d07e43 100644 --- a/src/main/scala/firrtl/stage/FirrtlStage.scala +++ b/src/main/scala/firrtl/stage/FirrtlStage.scala @@ -7,8 +7,7 @@ import firrtl.options.{Dependency, Phase, PhaseManager, Shell, Stage, StageMain} import firrtl.options.phases.DeletedWrapper import firrtl.stage.phases.CatchExceptions -class FirrtlPhase - extends PhaseManager(targets=Seq(Dependency[firrtl.stage.phases.Compiler])) { +class FirrtlPhase extends PhaseManager(targets = Seq(Dependency[firrtl.stage.phases.Compiler])) { override def invalidates(a: Phase) = false diff --git a/src/main/scala/firrtl/stage/FirrtlStageUtils.scala b/src/main/scala/firrtl/stage/FirrtlStageUtils.scala index e2304a92..aa9781db 100644 --- a/src/main/scala/firrtl/stage/FirrtlStageUtils.scala +++ b/src/main/scala/firrtl/stage/FirrtlStageUtils.scala @@ -2,14 +2,14 @@ package firrtl.stage -private [stage] sealed trait FileExtension -private [stage] case object FirrtlFile extends FileExtension -private [stage] case object ProtoBufFile extends FileExtension +private[stage] sealed trait FileExtension +private[stage] case object FirrtlFile extends FileExtension +private[stage] case object ProtoBufFile extends FileExtension /** Utilities that help with processing FIRRTL options */ object FirrtlStageUtils { - private [stage] def getFileExtension(file: String): FileExtension = file.drop(file.lastIndexOf('.')) match { + private[stage] def getFileExtension(file: String): FileExtension = file.drop(file.lastIndexOf('.')) match { case ".pb" => ProtoBufFile case _ => FirrtlFile } diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala index 636d0609..a0c5ea0c 100644 --- a/src/main/scala/firrtl/stage/Forms.scala +++ b/src/main/scala/firrtl/stage/Forms.scala @@ -17,28 +17,34 @@ object Forms { val ChirrtlForm: Seq[TransformDependency] = Seq.empty val MinimalHighForm: Seq[TransformDependency] = ChirrtlForm ++ - Seq( Dependency(passes.CheckChirrtl), - Dependency(passes.CInferTypes), - Dependency(passes.CInferMDir), - Dependency(passes.RemoveCHIRRTL), - Dependency[annotations.transforms.CleanupNamedTargets] ) + Seq( + Dependency(passes.CheckChirrtl), + Dependency(passes.CInferTypes), + Dependency(passes.CInferMDir), + Dependency(passes.RemoveCHIRRTL), + Dependency[annotations.transforms.CleanupNamedTargets] + ) val WorkingIR: Seq[TransformDependency] = MinimalHighForm :+ Dependency(passes.ToWorkingIR) val Checks: Seq[TransformDependency] = - Seq( Dependency(passes.CheckHighForm), - Dependency(passes.CheckTypes), - Dependency(passes.CheckFlows), - Dependency(passes.CheckWidths) ) + Seq( + Dependency(passes.CheckHighForm), + Dependency(passes.CheckTypes), + Dependency(passes.CheckFlows), + Dependency(passes.CheckWidths) + ) val Resolved: Seq[TransformDependency] = WorkingIR ++ Checks ++ - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows), - Dependency[passes.InferBinaryPoints], - Dependency[passes.TrimIntervals], - Dependency[passes.InferWidths], - Dependency[firrtl.transforms.InferResets] ) + Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows), + Dependency[passes.InferBinaryPoints], + Dependency[passes.TrimIntervals], + Dependency[passes.InferWidths], + Dependency[firrtl.transforms.InferResets] + ) val Deduped: Seq[TransformDependency] = Resolved :+ Dependency[firrtl.transforms.DedupModules] @@ -49,61 +55,71 @@ object Forms { Deduped val MidForm: Seq[TransformDependency] = HighForm ++ - Seq( Dependency(passes.PullMuxes), - Dependency(passes.ReplaceAccesses), - Dependency(passes.ExpandConnects), - Dependency(passes.RemoveAccesses), - Dependency(passes.ZeroLengthVecs), - Dependency[passes.ExpandWhensAndCheck], - Dependency[passes.RemoveIntervals], - Dependency(passes.ConvertFixedToSInt), - Dependency(passes.ZeroWidth), - Dependency[firrtl.transforms.formal.AssertSubmoduleAssumptions] ) + Seq( + Dependency(passes.PullMuxes), + Dependency(passes.ReplaceAccesses), + Dependency(passes.ExpandConnects), + Dependency(passes.RemoveAccesses), + Dependency(passes.ZeroLengthVecs), + Dependency[passes.ExpandWhensAndCheck], + Dependency[passes.RemoveIntervals], + Dependency(passes.ConvertFixedToSInt), + Dependency(passes.ZeroWidth), + Dependency[firrtl.transforms.formal.AssertSubmoduleAssumptions] + ) val LowForm: Seq[TransformDependency] = MidForm ++ - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), - Dependency(firrtl.transforms.RemoveReset), - Dependency[firrtl.transforms.CheckCombLoops], - Dependency[checks.CheckResets], - Dependency[firrtl.transforms.RemoveWires] ) + Seq( + Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(firrtl.transforms.RemoveReset), + Dependency[firrtl.transforms.CheckCombLoops], + Dependency[checks.CheckResets], + Dependency[firrtl.transforms.RemoveWires] + ) val LowFormMinimumOptimized: Seq[TransformDependency] = LowForm ++ - Seq( Dependency(passes.RemoveValidIf), - Dependency(passes.PadWidths), - Dependency(passes.memlib.VerilogMemDelays), - Dependency(passes.SplitExpressions), - Dependency[firrtl.transforms.LegalizeAndReductionsTransform] ) + Seq( + Dependency(passes.RemoveValidIf), + Dependency(passes.PadWidths), + Dependency(passes.memlib.VerilogMemDelays), + Dependency(passes.SplitExpressions), + Dependency[firrtl.transforms.LegalizeAndReductionsTransform] + ) val LowFormOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.ConstantPropagation], - Dependency[firrtl.transforms.CombineCats], - Dependency(passes.CommonSubexpressionElimination), - Dependency[firrtl.transforms.DeadCodeElimination] ) + Seq( + Dependency[firrtl.transforms.ConstantPropagation], + Dependency[firrtl.transforms.CombineCats], + Dependency(passes.CommonSubexpressionElimination), + Dependency[firrtl.transforms.DeadCodeElimination] + ) val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], - Dependency(passes.VerilogPrep), - Dependency[firrtl.AddDescriptionNodes] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] + ) val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogMinimumOptimized val AssertsRemoved: Seq[TransformDependency] = - Seq( Dependency(firrtl.transforms.formal.ConvertAsserts), - Dependency[firrtl.transforms.formal.RemoveVerificationStatements] ) + Seq( + Dependency(firrtl.transforms.formal.ConvertAsserts), + Dependency[firrtl.transforms.formal.RemoveVerificationStatements] + ) val BackendEmitters = - Seq( Dependency[VerilogEmitter], - Dependency[MinimumVerilogEmitter], - Dependency[SystemVerilogEmitter] ) + Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter], Dependency[SystemVerilogEmitter]) val LowEmitters = Dependency[LowFirrtlEmitter] +: BackendEmitters diff --git a/src/main/scala/firrtl/stage/TransformManager.scala b/src/main/scala/firrtl/stage/TransformManager.scala index 1b3032be..aa96ca86 100644 --- a/src/main/scala/firrtl/stage/TransformManager.scala +++ b/src/main/scala/firrtl/stage/TransformManager.scala @@ -12,15 +12,17 @@ import firrtl.options.{Dependency, DependencyManager} * @param knownObjects existing transform objects that have already been constructed */ class TransformManager( - val targets: Seq[TransformManager.TransformDependency], + val targets: Seq[TransformManager.TransformDependency], val currentState: Seq[TransformManager.TransformDependency] = Seq.empty, - val knownObjects: Set[Transform] = Set.empty) extends Transform + val knownObjects: Set[Transform] = Set.empty) + extends Transform with DependencyAPIMigration with DependencyManager[CircuitState, Transform] { override def execute(state: CircuitState): CircuitState = transform(state) - override protected def copy(a: Seq[Dependency[Transform]], b: Seq[Dependency[Transform]], c: Set[Transform]) = new TransformManager(a, b, c) + override protected def copy(a: Seq[Dependency[Transform]], b: Seq[Dependency[Transform]], c: Set[Transform]) = + new TransformManager(a, b, c) } diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala index 123c763a..37e2d13c 100644 --- a/src/main/scala/firrtl/stage/package.scala +++ b/src/main/scala/firrtl/stage/package.scala @@ -25,46 +25,50 @@ package object stage { /** * @todo custom transforms are appended as discovered, can this be prepended safely? */ - def view(options: AnnotationSeq): FirrtlOptions = options - .collect { case a: FirrtlOption => a } - .foldLeft(new FirrtlOptions()){ (c, x) => + def view(options: AnnotationSeq): FirrtlOptions = options.collect { case a: FirrtlOption => a } + .foldLeft(new FirrtlOptions()) { (c, x) => x match { - case OutputFileAnnotation(f) => c.copy(outputFileName = Some(f)) - case InfoModeAnnotation(i) => c.copy(infoModeName = i) - case FirrtlCircuitAnnotation(cir) => c.copy(firrtlCircuit = Some(cir)) - case a : CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a") ; c - case SuppressScalaVersionWarning => c + case OutputFileAnnotation(f) => c.copy(outputFileName = Some(f)) + case InfoModeAnnotation(i) => c.copy(infoModeName = i) + case FirrtlCircuitAnnotation(cir) => c.copy(firrtlCircuit = Some(cir)) + case a: CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a"); c + case SuppressScalaVersionWarning => c } } } - private [firrtl] implicit object FirrtlExecutionResultView extends OptionsView[FirrtlExecutionResult] with LazyLogging { + private[firrtl] implicit object FirrtlExecutionResultView + extends OptionsView[FirrtlExecutionResult] + with LazyLogging { def view(options: AnnotationSeq): FirrtlExecutionResult = { - val emittedRes = options - .collect{ case a: EmittedAnnotation[_] => a.value.value } + val emittedRes = options.collect { case a: EmittedAnnotation[_] => a.value.value } .mkString("\n") - val emitters = options.collect{ case RunFirrtlTransformAnnotation(e: Emitter) => e } - if(emitters.length > 1) { - logger.warn("More than one emitter used which cannot be accurately represented" + - "in the deprecated FirrtlExecutionResult: " + emitters.map(_.name).mkString(", ")) + val emitters = options.collect { case RunFirrtlTransformAnnotation(e: Emitter) => e } + if (emitters.length > 1) { + logger.warn( + "More than one emitter used which cannot be accurately represented" + + "in the deprecated FirrtlExecutionResult: " + emitters.map(_.name).mkString(", ") + ) } - val compilers = options.collect{ case CompilerAnnotation(c) => c } + val compilers = options.collect { case CompilerAnnotation(c) => c } val emitType = emitters.headOption.orElse(compilers.headOption).map(_.name).getOrElse("N/A") val form = emitters.headOption.orElse(compilers.headOption).map(_.outputForm).getOrElse(UnknownForm) - options.collectFirst{ case a: FirrtlCircuitAnnotation => a.circuit } match { + options.collectFirst { case a: FirrtlCircuitAnnotation => a.circuit } match { case None => FirrtlExecutionFailure("No circuit found in AnnotationSeq!") - case Some(a) => FirrtlExecutionSuccess( - emitType = emitType, - emitted = emittedRes, - circuitState = CircuitState( - circuit = a, - form = form, - annotations = options, - renames = None - )) + case Some(a) => + FirrtlExecutionSuccess( + emitType = emitType, + emitted = emittedRes, + circuitState = CircuitState( + circuit = a, + form = form, + annotations = options, + renames = None + ) + ) } } } diff --git a/src/main/scala/firrtl/stage/phases/AddCircuit.scala b/src/main/scala/firrtl/stage/phases/AddCircuit.scala index f3ff3372..c00e71b6 100644 --- a/src/main/scala/firrtl/stage/phases/AddCircuit.scala +++ b/src/main/scala/firrtl/stage/phases/AddCircuit.scala @@ -39,11 +39,10 @@ class AddCircuit extends Phase { * @throws $infoModeException */ private def infoMode(annotations: AnnotationSeq): Parser.InfoMode = { - val infoModeAnnotation = annotations - .collectFirst{ case a: InfoModeAnnotation => a } - .getOrElse { throw new PhasePrerequisiteException( - "An InfoModeAnnotation must be present (did you forget to run AddDefaults?)") } - val infoSource = annotations.collectFirst{ + val infoModeAnnotation = annotations.collectFirst { case a: InfoModeAnnotation => a }.getOrElse { + throw new PhasePrerequisiteException("An InfoModeAnnotation must be present (did you forget to run AddDefaults?)") + } + val infoSource = annotations.collectFirst { case FirrtlFileAnnotation(f) => f case _: FirrtlSourceAnnotation => "anonymous source" }.getOrElse("not defined") @@ -58,7 +57,7 @@ class AddCircuit extends Phase { lazy val info = infoMode(annotations) annotations.map { case a: CircuitOption => a.toCircuit(info) - case a => a + case a => a } } diff --git a/src/main/scala/firrtl/stage/phases/AddDefaults.scala b/src/main/scala/firrtl/stage/phases/AddDefaults.scala index d4c5bab4..9f4163cc 100644 --- a/src/main/scala/firrtl/stage/phases/AddDefaults.scala +++ b/src/main/scala/firrtl/stage/phases/AddDefaults.scala @@ -26,21 +26,21 @@ class AddDefaults extends Phase { var bb, c, em, im = true annotations.foreach { case _: BlackBoxTargetDirAnno => bb = false - case _: CompilerAnnotation => c = false - case _: InfoModeAnnotation => im = false - case RunFirrtlTransformAnnotation(_ : firrtl.Emitter) => em = false + case _: CompilerAnnotation => c = false + case _: InfoModeAnnotation => im = false + case RunFirrtlTransformAnnotation(_: firrtl.Emitter) => em = false case _ => } val default = new FirrtlOptions() - val targetDir = annotations - .collectFirst { case d: TargetDirAnnotation => d } - .getOrElse(TargetDirAnnotation()).directory - - (if (bb) Seq(BlackBoxTargetDirAnno(targetDir)) else Seq() ) ++ - // if there is no compiler or emitter specified, add the default emitter - (if (c && em) Seq(RunFirrtlTransformAnnotation(DefaultEmitterTarget)) else Seq() ) ++ - (if (im) Seq(InfoModeAnnotation()) else Seq() ) ++ + val targetDir = annotations.collectFirst { case d: TargetDirAnnotation => d } + .getOrElse(TargetDirAnnotation()) + .directory + + (if (bb) Seq(BlackBoxTargetDirAnno(targetDir)) else Seq()) ++ + // if there is no compiler or emitter specified, add the default emitter + (if (c && em) Seq(RunFirrtlTransformAnnotation(DefaultEmitterTarget)) else Seq()) ++ + (if (im) Seq(InfoModeAnnotation()) else Seq()) ++ annotations } diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala index edf62c3a..3c0a2388 100644 --- a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala +++ b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala @@ -18,16 +18,19 @@ class AddImplicitEmitter extends Phase { override def invalidates(a: Phase) = false def transform(annos: AnnotationSeq): AnnotationSeq = { - val emit = annos.collectFirst{ case a: EmitAnnotation => a } - val emitter = annos.collectFirst{ case RunFirrtlTransformAnnotation(e : Emitter) => e } - val compiler = annos.collectFirst{ case CompilerAnnotation(a) => a } + val emit = annos.collectFirst { case a: EmitAnnotation => a } + val emitter = annos.collectFirst { case RunFirrtlTransformAnnotation(e: Emitter) => e } + val compiler = annos.collectFirst { case CompilerAnnotation(a) => a } if (emit.isEmpty && (compiler.nonEmpty || emitter.nonEmpty)) { - annos.flatMap{ - case a: CompilerAnnotation => Seq(a, - RunFirrtlTransformAnnotation(compiler.get.emitter), - EmitCircuitAnnotation(compiler.get.emitter.getClass)) - case a @ RunFirrtlTransformAnnotation(e : Emitter) => Seq(a, EmitCircuitAnnotation(e.getClass)) + annos.flatMap { + case a: CompilerAnnotation => + Seq( + a, + RunFirrtlTransformAnnotation(compiler.get.emitter), + EmitCircuitAnnotation(compiler.get.emitter.getClass) + ) + case a @ RunFirrtlTransformAnnotation(e: Emitter) => Seq(a, EmitCircuitAnnotation(e.getClass)) case a => Some(a) } } else { diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala index f57e9c39..10af13d5 100644 --- a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala +++ b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala @@ -30,13 +30,12 @@ class AddImplicitOutputFile extends Phase { /** Add an [[OutputFileAnnotation]] to an [[AnnotationSeq]] */ def transform(annotations: AnnotationSeq): AnnotationSeq = - annotations - .collectFirst { case _: OutputFileAnnotation | _: EmitAllModulesAnnotation => annotations } - .getOrElse { - val topName = Viewer[FirrtlOptions].view(annotations) - .firrtlCircuit - .map(_.main) - .getOrElse("a") - OutputFileAnnotation(topName) +: annotations - } + annotations.collectFirst { case _: OutputFileAnnotation | _: EmitAllModulesAnnotation => annotations }.getOrElse { + val topName = Viewer[FirrtlOptions] + .view(annotations) + .firrtlCircuit + .map(_.main) + .getOrElse("a") + OutputFileAnnotation(topName) +: annotations + } } diff --git a/src/main/scala/firrtl/stage/phases/CatchExceptions.scala b/src/main/scala/firrtl/stage/phases/CatchExceptions.scala index f65ed481..5181653b 100644 --- a/src/main/scala/firrtl/stage/phases/CatchExceptions.scala +++ b/src/main/scala/firrtl/stage/phases/CatchExceptions.scala @@ -4,8 +4,12 @@ package firrtl.stage.phases import firrtl.options.{DependencyManagerException, OptionsException, Phase, PhaseException} import firrtl.{ - AnnotationSeq, CustomTransformException, FIRRTLException, - FirrtlInternalException, FirrtlUserException, Utils + AnnotationSeq, + CustomTransformException, + FIRRTLException, + FirrtlInternalException, + FirrtlUserException, + Utils } import scala.util.control.ControlThrowable @@ -27,15 +31,15 @@ class CatchExceptions(val underlying: Phase) extends Phase { } catch { /* Rethrow the exceptions which are expected or due to the runtime environment (out of memory, stack overflow, etc.). * Any UNEXPECTED exceptions should be treated as internal errors. */ - case p @ (_: ControlThrowable | _: FIRRTLException | _: OptionsException | _: FirrtlUserException - | _: FirrtlInternalException | _: PhaseException | _: DependencyManagerException) => throw p + case p @ (_: ControlThrowable | _: FIRRTLException | _: OptionsException | _: FirrtlUserException | + _: FirrtlInternalException | _: PhaseException | _: DependencyManagerException) => + throw p case CustomTransformException(cause) => throw cause case e: Exception => Utils.throwInternalError(exception = Some(e)) } } - object CatchExceptions { def apply(p: Phase): CatchExceptions = new CatchExceptions(p) diff --git a/src/main/scala/firrtl/stage/phases/Checks.scala b/src/main/scala/firrtl/stage/phases/Checks.scala index 7ecdc47e..6576d311 100644 --- a/src/main/scala/firrtl/stage/phases/Checks.scala +++ b/src/main/scala/firrtl/stage/phases/Checks.scala @@ -32,67 +32,70 @@ class Checks extends Phase { */ def transform(annos: AnnotationSeq): AnnotationSeq = { val inF, inS, eam, ec, outF, comp, emitter, im, inC = collection.mutable.ListBuffer[Annotation]() - annos.foreach( - _ match { - case a: FirrtlFileAnnotation => a +=: inF - case a: FirrtlSourceAnnotation => a +=: inS - case a: EmitAllModulesAnnotation => a +=: eam - case a: EmitCircuitAnnotation => a +=: ec - case a: OutputFileAnnotation => a +=: outF - case a: CompilerAnnotation => a +=: comp - case a: InfoModeAnnotation => a +=: im - case a: FirrtlCircuitAnnotation => a +=: inC - case a @ RunFirrtlTransformAnnotation(_ : firrtl.Emitter) => a +=: emitter - case _ => }) + annos.foreach(_ match { + case a: FirrtlFileAnnotation => a +=: inF + case a: FirrtlSourceAnnotation => a +=: inS + case a: EmitAllModulesAnnotation => a +=: eam + case a: EmitCircuitAnnotation => a +=: ec + case a: OutputFileAnnotation => a +=: outF + case a: CompilerAnnotation => a +=: comp + case a: InfoModeAnnotation => a +=: im + case a: FirrtlCircuitAnnotation => a +=: inC + case a @ RunFirrtlTransformAnnotation(_: firrtl.Emitter) => a +=: emitter + case _ => + }) /* At this point, only a FIRRTL Circuit should exist */ if (inF.isEmpty && inS.isEmpty && inC.isEmpty) { - throw new OptionsException( - s"""|Unable to determine FIRRTL source to read. None of the following were found: - | - an input file: -i, --input-file, FirrtlFileAnnotation - | - FIRRTL source: --firrtl-source, FirrtlSourceAnnotation - | - FIRRTL circuit: FirrtlCircuitAnnotation""".stripMargin )} + throw new OptionsException(s"""|Unable to determine FIRRTL source to read. None of the following were found: + | - an input file: -i, --input-file, FirrtlFileAnnotation + | - FIRRTL source: --firrtl-source, FirrtlSourceAnnotation + | - FIRRTL circuit: FirrtlCircuitAnnotation""".stripMargin) + } /* Only one FIRRTL input can exist */ if (inF.size + inS.size + inC.size > 1) { - throw new OptionsException( - s"""|Multiply defined input FIRRTL sources. More than one of the following was found: - | - an input file (${inF.size} times): -i, --input-file, FirrtlFileAnnotation - | - FIRRTL source (${inS.size} times): --firrtl-source, FirrtlSourceAnnotation - | - FIRRTL circuit (${inC.size} times): FirrtlCircuitAnnotation""".stripMargin )} + throw new OptionsException(s"""|Multiply defined input FIRRTL sources. More than one of the following was found: + | - an input file (${inF.size} times): -i, --input-file, FirrtlFileAnnotation + | - FIRRTL source (${inS.size} times): --firrtl-source, FirrtlSourceAnnotation + | - FIRRTL circuit (${inC.size} times): FirrtlCircuitAnnotation""".stripMargin) + } /* Specifying an output file and one-file-per module conflict */ if (eam.nonEmpty && outF.nonEmpty) { throw new OptionsException( s"""|Output file is incompatible with emit all modules annotation, but multiples were found: | - explicit output file (${outF.size} times): -o, --output-file, OutputFileAnnotation - | - one file per module (${eam.size} times): -e, --emit-modules, EmitAllModulesAnnotation""" - .stripMargin )} + | - one file per module (${eam.size} times): -e, --emit-modules, EmitAllModulesAnnotation""".stripMargin + ) + } /* Only one output file can be specified */ if (outF.size > 1) { - val x = outF.map{ case OutputFileAnnotation(x) => x } + val x = outF.map { case OutputFileAnnotation(x) => x } throw new OptionsException( s"""|No more than one output file can be specified, but found '${x.mkString(", ")}' specified via: - | - option or annotation: -o, --output-file, OutputFileAnnotation""".stripMargin) } + | - option or annotation: -o, --output-file, OutputFileAnnotation""".stripMargin + ) + } /* One mandatory compiler (or emitter) must be specified */ if (comp.size != 1 && emitter.isEmpty) { - val x = comp.map{ case CompilerAnnotation(x) => x } - val (msg, suggest) = if (comp.size == 0) { ("none found", "forget one of") } - else { (s"""found '${x.mkString(", ")}'""", "use multiple of") } - throw new OptionsException( - s"""|Exactly one compiler must be specified, but $msg. Did you $suggest the following? - | - an option or annotation: -X, --compiler, CompilerAnnotation""".stripMargin )} + val x = comp.map { case CompilerAnnotation(x) => x } + val (msg, suggest) = if (comp.size == 0) { ("none found", "forget one of") } + else { (s"""found '${x.mkString(", ")}'""", "use multiple of") } + throw new OptionsException(s"""|Exactly one compiler must be specified, but $msg. Did you $suggest the following? + | - an option or annotation: -X, --compiler, CompilerAnnotation""".stripMargin) + } /* One mandatory info mode must be specified */ if (im.size != 1) { - val x = im.map{ case InfoModeAnnotation(x) => x } - val (msg, suggest) = if (im.size == 0) { ("none found", "forget one of") } - else { (s"""found '${x.mkString(", ")}'""", "use multiple of") } - throw new OptionsException( - s"""|Exactly one info mode must be specified, but $msg. Did you $suggest the following? - | - an option or annotation: --info-mode, InfoModeAnnotation""".stripMargin )} + val x = im.map { case InfoModeAnnotation(x) => x } + val (msg, suggest) = if (im.size == 0) { ("none found", "forget one of") } + else { (s"""found '${x.mkString(", ")}'""", "use multiple of") } + throw new OptionsException(s"""|Exactly one info mode must be specified, but $msg. Did you $suggest the following? + | - an option or annotation: --info-mode, InfoModeAnnotation""".stripMargin) + } annos } diff --git a/src/main/scala/firrtl/stage/phases/Compiler.scala b/src/main/scala/firrtl/stage/phases/Compiler.scala index b73e3058..0d1181a6 100644 --- a/src/main/scala/firrtl/stage/phases/Compiler.scala +++ b/src/main/scala/firrtl/stage/phases/Compiler.scala @@ -10,17 +10,17 @@ import firrtl.stage.TransformManager.TransformDependency import scala.collection.mutable /** An encoding of the information necessary to run the FIRRTL compiler once */ -private [stage] case class CompilerRun( - stateIn: CircuitState, - stateOut: Option[CircuitState], +private[stage] case class CompilerRun( + stateIn: CircuitState, + stateOut: Option[CircuitState], transforms: Seq[Transform], - compiler: Option[FirrtlCompiler] ) + compiler: Option[FirrtlCompiler]) /** An encoding of possible defaults for a [[CompilerRun]] */ -private [stage] case class Defaults( +private[stage] case class Defaults( annotations: AnnotationSeq = Seq.empty, - transforms: Seq[Transform] = Seq.empty, - compiler: Option[FirrtlCompiler] = None) + transforms: Seq[Transform] = Seq.empty, + compiler: Option[FirrtlCompiler] = None) /** Runs the FIRRTL compilers on an [[AnnotationSeq]]. If the input [[AnnotationSeq]] contains more than one circuit * (i.e., more than one [[firrtl.stage.FirrtlCircuitAnnotation FirrtlCircuitAnnotation]]), then annotations will be @@ -45,11 +45,13 @@ private [stage] case class Defaults( class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { override def prerequisites = - Seq(Dependency[AddDefaults], - Dependency[AddImplicitEmitter], - Dependency[Checks], - Dependency[AddCircuit], - Dependency[AddImplicitOutputFile]) + Seq( + Dependency[AddDefaults], + Dependency[AddImplicitEmitter], + Dependency[Checks], + Dependency[AddCircuit], + Dependency[AddImplicitOutputFile] + ) override def optionalPrerequisiteOf = Seq.empty @@ -59,28 +61,30 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { protected def aToB(a: AnnotationSeq): Seq[CompilerRun] = { var foundFirstCircuit = false val c = mutable.ArrayBuffer.empty[CompilerRun] - a.foldLeft(Defaults()){ + a.foldLeft(Defaults()) { case (d, FirrtlCircuitAnnotation(circuit)) => foundFirstCircuit = true CompilerRun(CircuitState(circuit, ChirrtlForm, d.annotations, None), None, d.transforms, d.compiler) +=: c d - case (d, a) if foundFirstCircuit => a match { - case RunFirrtlTransformAnnotation(transform) => - c(0) = c(0).copy(transforms = transform +: c(0).transforms) - d - case CompilerAnnotation(compiler) => - c(0) = c(0).copy(compiler = Some(compiler)) - d - case annotation => - val state = c(0).stateIn - c(0) = c(0).copy(stateIn = state.copy(annotations = annotation +: state.annotations)) - d - } - case (d, a) if !foundFirstCircuit => a match { - case RunFirrtlTransformAnnotation(transform) => d.copy(transforms = transform +: d.transforms) - case CompilerAnnotation(compiler) => d.copy(compiler = Some(compiler)) - case annotation => d.copy(annotations = annotation +: d.annotations) - } + case (d, a) if foundFirstCircuit => + a match { + case RunFirrtlTransformAnnotation(transform) => + c(0) = c(0).copy(transforms = transform +: c(0).transforms) + d + case CompilerAnnotation(compiler) => + c(0) = c(0).copy(compiler = Some(compiler)) + d + case annotation => + val state = c(0).stateIn + c(0) = c(0).copy(stateIn = state.copy(annotations = annotation +: state.annotations)) + d + } + case (d, a) if !foundFirstCircuit => + a match { + case RunFirrtlTransformAnnotation(transform) => d.copy(transforms = transform +: d.transforms) + case CompilerAnnotation(compiler) => d.copy(compiler = Some(compiler)) + case annotation => d.copy(annotations = annotation +: d.annotations) + } } c.toSeq } @@ -89,7 +93,7 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { * removed ([[CompilerAnnotation]]s and [[RunFirrtlTransformAnnotation]]s). */ protected def bToA(b: Seq[CompilerRun]): AnnotationSeq = - b.flatMap( bb => FirrtlCircuitAnnotation(bb.stateOut.get.circuit) +: bb.stateOut.get.annotations ) + b.flatMap(bb => FirrtlCircuitAnnotation(bb.stateOut.get.circuit) +: bb.stateOut.get.annotations) /** Run the FIRRTL compiler some number of times. If more than one run is specified, a parallel collection will be * used. @@ -98,9 +102,9 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { def f(c: CompilerRun): CompilerRun = { val targets = c.compiler match { case Some(d) => c.transforms.reverse.map(Dependency.fromTransform(_)) ++ compilerToTransforms(d) - case None => + case None => val hasEmitter = c.transforms.collectFirst { case _: firrtl.Emitter => true }.isDefined - if(!hasEmitter) { + if (!hasEmitter) { throw new PhasePrerequisiteException("No compiler specified!") } else { c.transforms.reverse.map(Dependency.fromTransform) @@ -118,18 +122,19 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] { c.copy(stateOut = Some(annotationsOut)) } - if (b.size <= 1) { b.map(f) } else { - collection.parallel.immutable.ParVector(b :_*).par.map(f).seq + if (b.size <= 1) { b.map(f) } + else { + collection.parallel.immutable.ParVector(b: _*).par.map(f).seq } } private def compilerToTransforms(a: FirrtlCompiler): Seq[TransformDependency] = a match { - case _: firrtl.NoneCompiler => Forms.ChirrtlForm - case _: firrtl.HighFirrtlCompiler => Forms.MinimalHighForm - case _: firrtl.MiddleFirrtlCompiler => Forms.MidForm - case _: firrtl.LowFirrtlCompiler => Forms.LowForm + case _: firrtl.NoneCompiler => Forms.ChirrtlForm + case _: firrtl.HighFirrtlCompiler => Forms.MinimalHighForm + case _: firrtl.MiddleFirrtlCompiler => Forms.MidForm + case _: firrtl.LowFirrtlCompiler => Forms.LowForm case _: firrtl.VerilogCompiler | _: firrtl.SystemVerilogCompiler => Forms.LowFormOptimized - case _: firrtl.MinimumVerilogCompiler => Forms.LowFormMinimumOptimized + case _: firrtl.MinimumVerilogCompiler => Forms.LowFormMinimumOptimized } } diff --git a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala index b149a791..0b558cc0 100644 --- a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala +++ b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala @@ -47,7 +47,10 @@ object DriverCompatibility { /** Holds the name of the top (main) module in an input circuit * @param value top module name */ - @deprecated(""""top-name" is deprecated as part of the Stage/Phase refactor. Use explicit input/output files.""", "1.2") + @deprecated( + """"top-name" is deprecated as part of the Stage/Phase refactor. Use explicit input/output files.""", + "1.2" + ) case class TopNameAnnotation(topName: String) extends NoTargetAnnotation object TopNameAnnotation { @@ -57,7 +60,7 @@ object DriverCompatibility { .abbr("tn") .hidden .unbounded - .action( (_, _) => throw new OptionsException(optionRemoved("--top-name/-tn")) ) + .action((_, _) => throw new OptionsException(optionRemoved("--top-name/-tn"))) } /** Indicates that the implicit emitter, derived from a [[CompilerAnnotation]] should be an [[EmitAllModulesAnnotation]] @@ -70,7 +73,7 @@ object DriverCompatibility { .abbr("fsm") .hidden .unbounded - .action( (_, _) => throw new OptionsException(optionRemoved("--split-modules/-fsm")) ) + .action((_, _) => throw new OptionsException(optionRemoved("--split-modules/-fsm"))) } @@ -84,13 +87,16 @@ object DriverCompatibility { * @return the top module ''if it can be determined'' */ private def topName(annotations: AnnotationSeq): Option[String] = - annotations.collectFirst{ case TopNameAnnotation(n) => n }.orElse( - annotations.collectFirst{ case FirrtlCircuitAnnotation(c) => c.main }.orElse( - annotations.collectFirst{ case FirrtlSourceAnnotation(s) => Parser.parse(s).main }.orElse( - annotations.collectFirst{ case FirrtlFileAnnotation(f) => - FirrtlStageUtils.getFileExtension(f) match { - case ProtoBufFile => FromProto.fromFile(f).main - case FirrtlFile => Parser.parse(FileUtils.getText(f)).main } } ))) + annotations.collectFirst { case TopNameAnnotation(n) => n } + .orElse(annotations.collectFirst { case FirrtlCircuitAnnotation(c) => c.main }.orElse(annotations.collectFirst { + case FirrtlSourceAnnotation(s) => Parser.parse(s).main + }.orElse(annotations.collectFirst { + case FirrtlFileAnnotation(f) => + FirrtlStageUtils.getFileExtension(f) match { + case ProtoBufFile => FromProto.fromFile(f).main + case FirrtlFile => Parser.parse(FileUtils.getText(f)).main + } + }))) /** Determine the target directory with the following precedence (highest to lowest): * - Explicitly from the user-specified [[firrtl.options.TargetDirAnnotation TargetDirAnnotation]] @@ -131,22 +137,27 @@ object DriverCompatibility { override def invalidates(a: Phase) = false /** Try to add an [[firrtl.options.InputAnnotationFileAnnotation InputAnnotationFileAnnotation]] implicitly specified by - * an [[AnnotationSeq]]. */ - def transform(annotations: AnnotationSeq): AnnotationSeq = annotations - .collectFirst{ case a: InputAnnotationFileAnnotation => a } match { - case Some(_) => annotations - case None => topName(annotations) match { + * an [[AnnotationSeq]]. + */ + def transform(annotations: AnnotationSeq): AnnotationSeq = annotations.collectFirst { + case a: InputAnnotationFileAnnotation => a + } match { + case Some(_) => annotations + case None => + topName(annotations) match { case Some(n) => val filename = targetDir(annotations) + "/" + n + ".anno" if (new File(filename).exists) { StageUtils.dramaticWarning( - s"Implicit reading of the annotation file is deprecated! Use an explict --annotation-file argument.") + s"Implicit reading of the annotation file is deprecated! Use an explict --annotation-file argument." + ) annotations :+ InputAnnotationFileAnnotation(filename) } else { annotations } case None => annotations - } } + } + } } @@ -180,7 +191,8 @@ object DriverCompatibility { annotations } else if (main.nonEmpty) { StageUtils.dramaticWarning( - s"Implicit reading of the input file is deprecated! Use an explict --input-file argument.") + s"Implicit reading of the input file is deprecated! Use an explict --input-file argument." + ) FirrtlFileAnnotation(Viewer[StageOptions].view(annotations).getBuildFileName(s"${main.get}.fir")) +: annotations } else { annotations @@ -194,8 +206,10 @@ object DriverCompatibility { * this adds an [[EmitCircuitAnnotation]]. This replicates old behavior where specifying a compiler automatically * meant that an emitter would also run. */ - @deprecated("""AddImplicitEmitter should only be used to build Driver compatibility wrappers. Switch to Stage.""", - "1.2") + @deprecated( + """AddImplicitEmitter should only be used to build Driver compatibility wrappers. Switch to Stage.""", + "1.2" + ) class AddImplicitEmitter extends Phase { override def prerequisites = Seq.empty @@ -206,13 +220,13 @@ object DriverCompatibility { /** Add one [[EmitAnnotation]] foreach [[CompilerAnnotation]]. */ def transform(annotations: AnnotationSeq): AnnotationSeq = { - val splitModules = annotations.collectFirst{ case a: EmitOneFilePerModuleAnnotation.type => a }.isDefined + val splitModules = annotations.collectFirst { case a: EmitOneFilePerModuleAnnotation.type => a }.isDefined annotations.flatMap { case a @ CompilerAnnotation(c) => val b = RunFirrtlTransformAnnotation(a.compiler.emitter) if (splitModules) { Seq(a, b, EmitAllModulesAnnotation(c.emitter.getClass)) } - else { Seq(a, b, EmitCircuitAnnotation (c.emitter.getClass)) } + else { Seq(a, b, EmitCircuitAnnotation(c.emitter.getClass)) } case a => Seq(a) } } @@ -222,8 +236,10 @@ object DriverCompatibility { /** Adds an [[OutputFileAnnotation]] derived from a [[TopNameAnnotation]] if no [[OutputFileAnnotation]] already * exists. If no [[TopNameAnnotation]] exists, then no [[OutputFileAnnotation]] is added. */ - @deprecated("""AddImplicitOutputFile should only be used to build Driver compatibility wrappers. Switch to Stage.""", - "1.2") + @deprecated( + """AddImplicitOutputFile should only be used to build Driver compatibility wrappers. Switch to Stage.""", + "1.2" + ) class AddImplicitOutputFile extends Phase { override def prerequisites = Seq(Dependency[AddImplicitFirrtlFile]) @@ -234,9 +250,9 @@ object DriverCompatibility { /** Add an [[OutputFileAnnotation]] derived from a [[TopNameAnnotation]] if needed. */ def transform(annotations: AnnotationSeq): AnnotationSeq = { - val hasOutputFile = annotations - .collectFirst{ case a @(_: EmitOneFilePerModuleAnnotation.type | _: OutputFileAnnotation) => a } - .isDefined + val hasOutputFile = annotations.collectFirst { + case a @ (_: EmitOneFilePerModuleAnnotation.type | _: OutputFileAnnotation) => a + }.isDefined val top = topName(annotations) if (!hasOutputFile && top.isDefined) { diff --git a/src/main/scala/firrtl/stage/phases/WriteEmitted.scala b/src/main/scala/firrtl/stage/phases/WriteEmitted.scala index e2db2a94..614ce62f 100644 --- a/src/main/scala/firrtl/stage/phases/WriteEmitted.scala +++ b/src/main/scala/firrtl/stage/phases/WriteEmitted.scala @@ -2,7 +2,7 @@ package firrtl.stage.phases -import firrtl.{AnnotationSeq, EmittedModuleAnnotation, EmittedCircuitAnnotation} +import firrtl.{AnnotationSeq, EmittedCircuitAnnotation, EmittedModuleAnnotation} import firrtl.options.{Phase, StageOptions, Viewer} import firrtl.stage.FirrtlOptions @@ -24,8 +24,11 @@ import java.io.PrintWriter * * Any annotations written to files will be deleted. */ -@deprecated("Annotations that mixin the CustomFileEmission trait are automatically serialized by stages." + - "This will be removed in FIRRTL 1.5", "FIRRTL 1.4.0") +@deprecated( + "Annotations that mixin the CustomFileEmission trait are automatically serialized by stages." + + "This will be removed in FIRRTL 1.5", + "FIRRTL 1.4.0" +) class WriteEmitted extends Phase { override def prerequisites = Seq.empty @@ -47,7 +50,8 @@ class WriteEmitted extends Phase { None case a: EmittedCircuitAnnotation[_] => val pw = new PrintWriter( - sopts.getBuildFileName(fopts.outputFileName.getOrElse(a.value.name), Some(a.value.outputSuffix))) + sopts.getBuildFileName(fopts.outputFileName.getOrElse(a.value.name), Some(a.value.outputSuffix)) + ) pw.write(a.value.value) pw.close() None diff --git a/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala b/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala index ebcd7cfb..742d2b7e 100644 --- a/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala +++ b/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala @@ -9,7 +9,8 @@ class CatchCustomTransformExceptions(val underlying: Transform) extends Transfor override def execute(c: CircuitState): CircuitState = try { underlying.transform(c) } catch { - case e: Exception if CatchCustomTransformExceptions.isCustomTransform(trueUnderlying) => throw CustomTransformException(e) + case e: Exception if CatchCustomTransformExceptions.isCustomTransform(trueUnderlying) => + throw CustomTransformException(e) } } diff --git a/src/main/scala/firrtl/stage/transforms/Compiler.scala b/src/main/scala/firrtl/stage/transforms/Compiler.scala index 9988e443..251f4387 100644 --- a/src/main/scala/firrtl/stage/transforms/Compiler.scala +++ b/src/main/scala/firrtl/stage/transforms/Compiler.scala @@ -7,12 +7,12 @@ import firrtl.stage.TransformManager import firrtl.{Transform, VerilogEmitter} /** A [[firrtl.stage.TransformManager TransformManager]] of - * */ class Compiler( - targets: Seq[TransformManager.TransformDependency], + targets: Seq[TransformManager.TransformDependency], currentState: Seq[TransformManager.TransformDependency] = Seq.empty, - knownObjects: Set[Transform] = Set.empty) extends TransformManager(targets, currentState, knownObjects) { + knownObjects: Set[Transform] = Set.empty) + extends TransformManager(targets, currentState, knownObjects) { override val wrappers = Seq( (a: Transform) => ExpandPrepares(a), @@ -21,9 +21,10 @@ class Compiler( ) override def customPrintHandling( - tab: String, + tab: String, charSet: CharSet, - size: Int): Option[PartialFunction[(Transform, Int), Seq[String]]] = { + size: Int + ): Option[PartialFunction[(Transform, Int), Seq[String]]] = { val (l, n, c) = (charSet.lastNode, charSet.notLastNode, charSet.continuation) val last = size - 1 diff --git a/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala b/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala index 7a0621e4..d0514f15 100644 --- a/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala +++ b/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala @@ -8,8 +8,10 @@ class ExpandPrepares(val underlying: Transform) extends Transform with WrappedTr /* Assert that this is not wrapping other transforms. */ underlying match { - case _: WrappedTransform => throw new Exception( - s"'ExpandPrepares' must not wrap other 'WrappedTransforms', but wraps '${underlying.getClass.getName}'") + case _: WrappedTransform => + throw new Exception( + s"'ExpandPrepares' must not wrap other 'WrappedTransforms', but wraps '${underlying.getClass.getName}'" + ) case _ => } diff --git a/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala index 913ab5d2..c268332a 100644 --- a/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala +++ b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala @@ -8,8 +8,10 @@ import firrtl.options.{Dependency, DependencyManagerException} case class TransformHistoryAnnotation(history: Seq[Transform], state: Set[Transform]) extends NoTargetAnnotation { - def add(transform: Transform, - invalidates: (Transform) => Boolean = (a: Transform) => false): TransformHistoryAnnotation = + def add( + transform: Transform, + invalidates: (Transform) => Boolean = (a: Transform) => false + ): TransformHistoryAnnotation = this.copy( history = transform +: this.history, state = (this.state + transform).filterNot(invalidates) @@ -44,8 +46,7 @@ class TrackTransforms(val underlying: Transform) extends Transform with WrappedT } override def execute(c: CircuitState): CircuitState = { - val state = c.annotations - .collectFirst{ case TransformHistoryAnnotation(_, state) => state } + val state = c.annotations.collectFirst { case TransformHistoryAnnotation(_, state) => state } .getOrElse(Set.empty[Transform]) .map(Dependency.fromTransform(_)) @@ -53,7 +54,8 @@ class TrackTransforms(val underlying: Transform) extends Transform with WrappedT throw new DependencyManagerException( s"""|Tried to execute Transform '$trueUnderlying' for which run-time prerequisites were not satisfied: | state: ${state.mkString("\n -", "\n -", "")} - | prerequisites: ${trueUnderlying.prerequisites.mkString("\n -", "\n -", "")}""".stripMargin) + | prerequisites: ${trueUnderlying.prerequisites.mkString("\n -", "\n -", "")}""".stripMargin + ) } val out = underlying.transform(c) diff --git a/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala index cc0fbc6f..e36eef9b 100644 --- a/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala +++ b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala @@ -5,7 +5,9 @@ package firrtl.stage.transforms import firrtl.{CircuitState, Transform} import firrtl.options.Translator -class UpdateAnnotations(val underlying: Transform) extends Transform with WrappedTransform +class UpdateAnnotations(val underlying: Transform) + extends Transform + with WrappedTransform with Translator[CircuitState, (CircuitState, CircuitState)] { override def execute(c: CircuitState): CircuitState = underlying.transform(c) diff --git a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala index a57973d5..5000e07a 100644 --- a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala +++ b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala @@ -2,7 +2,7 @@ package firrtl.transforms -import java.io.{File, FileNotFoundException, FileInputStream, FileOutputStream, PrintWriter} +import java.io.{File, FileInputStream, FileNotFoundException, FileOutputStream, PrintWriter} import firrtl._ import firrtl.annotations._ @@ -11,31 +11,32 @@ import scala.collection.immutable.ListSet sealed trait BlackBoxHelperAnno extends Annotation -case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno - with NoTargetAnnotation { +case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno with NoTargetAnnotation { override def serialize: String = s"targetDir\n$targetDir" } -case class BlackBoxResourceAnno(target: ModuleName, resourceId: String) extends BlackBoxHelperAnno +case class BlackBoxResourceAnno(target: ModuleName, resourceId: String) + extends BlackBoxHelperAnno with SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) override def serialize: String = s"resource\n$resourceId" } -case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String) extends BlackBoxHelperAnno +case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String) + extends BlackBoxHelperAnno with SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) override def serialize: String = s"inline\n$name\n$text" } -case class BlackBoxPathAnno(target: ModuleName, path: String) extends BlackBoxHelperAnno +case class BlackBoxPathAnno(target: ModuleName, path: String) + extends BlackBoxHelperAnno with SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) override def serialize: String = s"path\n$path" } -case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno - with NoTargetAnnotation { +case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno with NoTargetAnnotation { override def serialize: String = s"resourceFileName\n$resourceFileName" } @@ -43,8 +44,10 @@ case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackB * @param fileName the name of the BlackBox file (only used for error message generation) * @param e an underlying exception that generated this */ -class BlackBoxNotFoundException(fileName: String, message: String) extends FirrtlUserException( - s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message") +class BlackBoxNotFoundException(fileName: String, message: String) + extends FirrtlUserException( + s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message" + ) /** Handle source for Verilog ExtModules (BlackBoxes) * @@ -72,15 +75,16 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration { */ def collectAnnos(annos: Seq[Annotation]): (ListSet[BlackBoxHelperAnno], File, File) = annos.foldLeft((ListSet.empty[BlackBoxHelperAnno], DefaultTargetDir, new File(defaultFileListName))) { - case ((acc, tdir, flistName), anno) => anno match { - case BlackBoxTargetDirAnno(dir) => - val targetDir = new File(dir) - if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) } - (acc, targetDir, flistName) - case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName)) - case a: BlackBoxHelperAnno => (acc + a, tdir, flistName) - case _ => (acc, tdir, flistName) - } + case ((acc, tdir, flistName), anno) => + anno match { + case BlackBoxTargetDirAnno(dir) => + val targetDir = new File(dir) + if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) } + (acc, targetDir, flistName) + case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName)) + case a: BlackBoxHelperAnno => (acc + a, tdir, flistName) + case _ => (acc, tdir, flistName) + } } /** @@ -112,14 +116,15 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration { case BlackBoxInlineAnno(_, name, text) => val outFile = new File(targetDir, name) (text, outFile) - }.map { case (text, file) => - writeTextToFile(text, file) - file + }.map { + case (text, file) => + writeTextToFile(text, file) + file } // Issue #917 - We don't want to list Verilog header files ("*.vh") in our file list - they will automatically be included by reference. def isHeader(name: String) = name.endsWith(".h") || name.endsWith(".vh") || name.endsWith(".svh") - val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot{ f => isHeader(f.getName()) } + val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot { f => isHeader(f.getName()) } val filelistFile = if (flistName.isAbsolute()) flistName else new File(targetDir, flistName.getName()) // We need the canonical path here, so verilator will create a path to the file that works from the targetDir, @@ -137,12 +142,14 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration { } object BlackBoxSourceHelper { + /** Safely access a file converting [[FileNotFoundException]]s and [[NullPointerException]]s into * [[BlackBoxNotFoundException]]s * @param fileName the name of the file to be accessed (only used for error message generation) * @param code some code to run */ - private def safeFile[A](fileName: String)(code: => A) = try { code } catch { + private def safeFile[A](fileName: String)(code: => A) = try { code } + catch { case e @ (_: FileNotFoundException | _: NullPointerException) => throw new BlackBoxNotFoundException(fileName, e.getMessage) } diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index 6403be23..ee4c1d0b 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -24,6 +24,7 @@ import firrtl.options.{Dependency, RegisteredTransform, ShellOption} case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None) object LogicNode { + /** * Construct a LogicNode from a *Low FIRRTL* reference or subfield that refers to a component. * Since aggregate types appear in Low FIRRTL only as the full types of instances or memories, @@ -39,11 +40,11 @@ object LogicNode { case s: WSubField => s.expr match { case modref: WRef => - LogicNode(s.name,Some(modref.name)) + LogicNode(s.name, Some(modref.name)) case memport: WSubField => memport.expr match { case memref: WRef => - LogicNode(s.name,Some(memref.name),Some(memport.name)) + LogicNode(s.name, Some(memref.name), Some(memport.name)) case _ => throwInternalError(s"LogicNode: unrecognized subsubfield expression - $memport") } case _ => throwInternalError(s"LogicNode: unrecognized subfield expression - $s") @@ -56,9 +57,8 @@ object CheckCombLoops { type ConnMap = DiGraph[LogicNode] with EdgeData[LogicNode, Info] type MutableConnMap = MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info] - - class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException( - s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n")) + class CombLoopException(info: Info, mname: String, cycle: Seq[String]) + extends PassException(s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n")) } case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation @@ -73,7 +73,7 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge override def update(renames: RenameMap): Seq[Annotation] = { val sources = renames.get(source).getOrElse(Seq(source)) val sinks = renames.get(sink).getOrElse(Seq(sink)) - val paths = sources flatMap { s => sinks.map((s, _)) } + val paths = sources.flatMap { s => sinks.map((s, _)) } paths.collect { case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink) } @@ -82,8 +82,8 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget]) extends Annotation { override def update(renames: RenameMap): Seq[Annotation] = { - val newSources = sources.flatMap { s => renames(s) }.collect {case x: ReferenceTarget if x.isLocal => x} - val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x} + val newSources = sources.flatMap { s => renames(s) }.collect { case x: ReferenceTarget if x.isLocal => x } + val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x } newSinks.map(snk => CombinationalPath(snk, newSources)) } } @@ -98,14 +98,10 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules * @note The pass will throw exceptions on "false paths" */ -class CheckCombLoops extends Transform - with RegisteredTransform - with DependencyAPIMigration { +class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), - Dependency(firrtl.transforms.RemoveReset) ) + Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset)) override def optionalPrerequisites = Seq.empty @@ -119,17 +115,21 @@ class CheckCombLoops extends Transform new ShellOption[Unit]( longOption = "no-check-comb-loops", toAnnotationSeq = (_: Unit) => Seq(DontCheckCombLoopsAnnotation), - helpText = "Disable combinational loop checking" ) ) + helpText = "Disable combinational loop checking" + ) + ) private def getExprDeps(deps: MutableConnMap, v: LogicNode, info: Info)(e: Expression): Unit = e match { - case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info) + case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info) case s: WSubField => deps.addEdgeIfValid(v, LogicNode(s), info) case _ => e.foreach(getExprDeps(deps, v, info)) } private def getStmtDeps( simplifiedModules: mutable.Map[String, AbstractConnMap], - deps: MutableConnMap)(s: Statement): Unit = s match { + deps: MutableConnMap + )(s: Statement + ): Unit = s match { case Connect(info, loc, expr) => val lhs = LogicNode(loc) if (deps.contains(lhs)) { @@ -152,9 +152,9 @@ class CheckCombLoops extends Transform case i: WDefInstance => val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) iGraph.getVertices.foreach(deps.addVertex(_)) - iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) + iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v, _) } }) case _ => - s.foreach(getStmtDeps(simplifiedModules,deps)) + s.foreach(getStmtDeps(simplifiedModules, deps)) } // Pretty-print a LogicNode with a prepended hierarchical path @@ -169,24 +169,26 @@ class CheckCombLoops extends Transform * recovered. */ private def expandInstancePaths( - m: String, + m: String, moduleGraphs: mutable.Map[String, ConnMap], - moduleDeps: Map[String, Map[String, String]], - hierPrefix: Seq[String], - path: Seq[LogicNode]): Seq[String] = { + moduleDeps: Map[String, Map[String, String]], + hierPrefix: Seq[String], + path: Seq[LogicNode] + ): Seq[String] = { // Recover info from edge data, add to error string def info(u: LogicNode, v: LogicNode): String = moduleGraphs(m).getEdgeData(u, v).map(_.toString).mkString("\t", "", "") // lhs comes after rhs - val pathNodes = (path zip path.tail) map { case (rhs, lhs) => - if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) { - val child = moduleDeps(m)(lhs.inst.get) - val newHierPrefix = hierPrefix :+ lhs.inst.get - val subpath = moduleGraphs(child).path(lhs.copy(inst=None),rhs.copy(inst=None)).reverse - expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath) - } else { - Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs)) - } + val pathNodes = (path.zip(path.tail)).map { + case (rhs, lhs) => + if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) { + val child = moduleDeps(m)(lhs.inst.get) + val newHierPrefix = hierPrefix :+ lhs.inst.get + val subpath = moduleGraphs(child).path(lhs.copy(inst = None), rhs.copy(inst = None)).reverse + expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath) + } else { + Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs)) + } } pathNodes.flatten } @@ -238,12 +240,13 @@ class CheckCombLoops extends Transform val errors = new Errors() val extModulePaths = state.annotations.groupBy { case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module) - case ann: Annotation => CircuitTarget(c.main) + case ann: Annotation => CircuitTarget(c.main) } - val moduleMap = c.modules.map({m => (m.name,m) }).toMap + val moduleMap = c.modules.map({ m => (m.name, m) }).toMap val iGraph = InstanceKeyGraph(c).graph - val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap - val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } + val moduleDeps = + iGraph.getEdgeMap.map({ case (k, v) => (k.module, (v.map { i => (i.name, i.module) }).toMap) }).toMap + val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse.map { moduleMap(_) } val moduleGraphs = new mutable.HashMap[String, ConnMap] val simplifiedModuleGraphs = new mutable.HashMap[String, AbstractConnMap] topoSortedModules.foreach { @@ -252,7 +255,8 @@ class CheckCombLoops extends Transform val extModuleDeps = new MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info] portSet.foreach(extModuleDeps.addVertex(_)) extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect { - case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref)) + case a: ExtModulePathAnnotation => + extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref)) } moduleGraphs(em.name) = extModuleDeps simplifiedModuleGraphs(em.name) = extModuleDeps.simplify(portSet) @@ -270,7 +274,7 @@ class CheckCombLoops extends Transform for (scc <- internalDeps.findSCCs.filter(_.length > 1)) { val sccSubgraph = internalDeps.subgraph(scc.toSet) val cycle = findCycleInSCC(sccSubgraph) - (cycle zip cycle.tail).foreach({ case (a,b) => require(internalDeps.getEdges(a).contains(b)) }) + (cycle.zip(cycle.tail)).foreach({ case (a, b) => require(internalDeps.getEdges(a).contains(b)) }) // Reverse to make sure LHS comes after RHS, print repeated vertex at start for legibility val intuitiveCycle = cycle.reverse val repeatedInitial = prettyPrintAbsoluteRef(Seq(m.name), intuitiveCycle.head) @@ -280,10 +284,11 @@ class CheckCombLoops extends Transform case m => throwInternalError(s"Module ${m.name} has unrecognized type") } val mt = ModuleTarget(c.main, c.main) - val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty => - val sink = mt.ref(from.name) - val sources = tos.map(to => mt.ref(to.name)) - CombinationalPath(sink, sources.toSeq) + val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { + case (from, tos) if tos.nonEmpty => + val sink = mt.ref(from.name) + val sources = tos.map(to => mt.ref(to.name)) + CombinationalPath(sink, sources.toSeq) } (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs, moduleGraphs) } @@ -291,7 +296,7 @@ class CheckCombLoops extends Transform /** * Returns a Map from Module name to port connectivity */ - def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = { + def analyze(state: CircuitState): collection.Map[String, DiGraph[String]] = { val (result, errors, connectivity, _) = run(state) connectivity.map { case (k, v) => (k, v.transformNodes(ln => ln.name)) @@ -301,7 +306,7 @@ class CheckCombLoops extends Transform /** * Returns a Map from Module name to complete netlist connectivity */ - def analyzeFull(state: CircuitState): collection.Map[String,DiGraph[LogicNode]] = { + def analyzeFull(state: CircuitState): collection.Map[String, DiGraph[LogicNode]] = { run(state)._4 } diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala index 7fa01e46..3014d0e3 100644 --- a/src/main/scala/firrtl/transforms/CombineCats.scala +++ b/src/main/scala/firrtl/transforms/CombineCats.scala @@ -1,4 +1,3 @@ - package firrtl package transforms @@ -14,26 +13,30 @@ import scala.collection.mutable case class MaxCatLenAnnotation(maxCatLen: Int) extends NoTargetAnnotation object CombineCats { + /** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them paired with their Cat length */ type Netlist = mutable.HashMap[WrappedExpression, (Int, Expression)] def expandCatArgs(maxCatLen: Int, netlist: Netlist)(expr: Expression): (Int, Expression) = expr match { - case cat@DoPrim(Cat, args, _, _) => + case cat @ DoPrim(Cat, args, _, _) => val (a0Len, a0Expanded) = expandCatArgs(maxCatLen - 1, netlist)(args.head) val (a1Len, a1Expanded) = expandCatArgs(maxCatLen - a0Len, netlist)(args(1)) (a0Len + a1Len, cat.copy(args = Seq(a0Expanded, a1Expanded)).asInstanceOf[Expression]) case other => - netlist.get(we(expr)).collect { - case (len, cat@DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat) - }.getOrElse((1, other)) + netlist + .get(we(expr)) + .collect { + case (len, cat @ DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat) + } + .getOrElse((1, other)) } def onStmt(maxCatLen: Int, netlist: Netlist)(stmt: Statement): Statement = { stmt.map(onStmt(maxCatLen, netlist)) match { - case node@DefNode(_, name, value) => + case node @ DefNode(_, name, value) => val catLenAndVal = value match { - case cat@DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat) - case other => (1, other) + case cat @ DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat) + case other => (1, other) } netlist(we(WRef(name))) = catLenAndVal node.copy(value = catLenAndVal._2) @@ -55,16 +58,16 @@ object CombineCats { class CombineCats extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions) ) + Seq( + Dependency(passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions) + ) override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Seq( - Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter]) override def invalidates(a: Transform) = false diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index ce36dd72..dc9b2bbe 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -28,7 +28,7 @@ object ConstantPropagation { /** Pads e to the width of t */ def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) case (we, wt) if we == wt => e } @@ -44,38 +44,40 @@ object ConstantPropagation { case lit: Literal => require(hi >= lo) UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe)) - case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match { - case t: UIntType => x - case _ => asUInt(x, e.tpe) - } + case x if bitWidth(e.tpe) == bitWidth(x.tpe) => + x.tpe match { + case t: UIntType => x + case _ => asUInt(x, e.tpe) + } case _ => e } } def foldShiftRight(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head - case x => e.args.head match { - // TODO when amount >= x.width, return a zero-width wire - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) - // take sign bit if shift amount is larger than arg width - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) - case _ => e - } + case x => + e.args.head match { + // TODO when amount >= x.width, return a zero-width wire + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1))) + // take sign bit if shift amount is larger than arg width + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1))) + case _ => e + } } - - /********************************************** - * REGISTER CONSTANT PROPAGATION HELPER TYPES * - **********************************************/ + /** ******************************************** + * REGISTER CONSTANT PROPAGATION HELPER TYPES * + * ******************************************** + */ // A utility class that is somewhat like an Option but with two variants containing Nothing. // for register constant propagation (register or literal). private abstract class ConstPropBinding[+T] { def resolve[V >: T](that: ConstPropBinding[V]): ConstPropBinding[V] = (this, that) match { - case (x, y) if (x == y) => x + case (x, y) if (x == y) => x case (x, UnboundConstant) => x case (UnboundConstant, y) => y - case _ => NonConstant + case _ => NonConstant } } @@ -103,21 +105,23 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res override def prerequisites = ((new mutable.LinkedHashSet()) - ++ firrtl.stage.Forms.LowForm - - Dependency(firrtl.passes.Legalize) - + Dependency(firrtl.passes.RemoveValidIf)).toSeq + ++ firrtl.stage.Forms.LowForm + - Dependency(firrtl.passes.Legalize) + + Dependency(firrtl.passes.RemoveValidIf)).toSeq override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = - Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[SystemVerilogEmitter], - Dependency[VerilogEmitter] ) + Seq( + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[SystemVerilogEmitter], + Dependency[VerilogEmitter] + ) override def invalidates(a: Transform): Boolean = a match { case firrtl.passes.Legalize => true - case _ => false + case _ => false } override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation]) @@ -130,7 +134,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } sealed trait FoldCommutativeOp extends SimplifyBinaryOp { - def fold(c1: Literal, c2: Literal): Expression + def fold(c1: Literal, c2: Literal): Expression def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match { @@ -138,7 +142,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe) case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe) case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs) - case _ => e + case _ => e } } @@ -177,20 +181,20 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res */ def apply(prim: DoPrim): Expression = prim.args.head match { case a: Literal => simplifyLiteral(a) - case _ => prim + case _ => prim } } object FoldADD extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match { - case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) - case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1)) + case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1)) + case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1)) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == BigInt(0) => rhs case SIntLiteral(v, w) if v == BigInt(0) => rhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = e } @@ -209,77 +213,81 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res object FoldAND extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked(c1.value & c2.value, width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) - case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) + case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked((c1.value | c2.value), width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { - case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case UIntLiteral(v, _) if v == BigInt(0) => rhs + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe) } object FoldXOR extends FoldCommutativeOp { def fold(c1: Literal, c2: Literal) = { - val width = (c1.width max c2.width).asInstanceOf[IntWidth] + val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth] UIntLiteral.masked((c1.value ^ c2.value), width) } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) - case _ => e + case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe)) } object FoldEqual extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => + DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1) } object FoldNotEqual extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs - case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe) + case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => + DoPrim(Not, Seq(rhs), Nil, e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0) } private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match { - case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => + UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) case _ => e } private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match { case 0 => e.args.head - case x => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) - case _ => e - } + case x => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) + case _ => e + } } private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match { @@ -296,53 +304,55 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case _ => e } - private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { def isUInt(e: Expression): Boolean = e.tpe match { case UIntType(_) => true - case _ => false + case _ => false } def isZero(e: Expression) = e match { - case UIntLiteral(value, _) => value == BigInt(0) - case SIntLiteral(value, _) => value == BigInt(0) - case _ => false - } + case UIntLiteral(value, _) => value == BigInt(0) + case SIntLiteral(value, _) => value == BigInt(0) + case _ => false + } x match { - case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero - case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one - case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero - case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one - case ex => ex + case DoPrim(Lt, Seq(a, b), _, _) if isUInt(a) && isZero(b) => zero + case DoPrim(Leq, Seq(a, b), _, _) if isZero(a) && isUInt(b) => one + case DoPrim(Gt, Seq(a, b), _, _) if isZero(a) && isUInt(b) => zero + case DoPrim(Geq, Seq(a, b), _, _) if isUInt(a) && isZero(b) => one + case ex => ex } } def foldIfOutsideRange(x: Expression): Expression = { //Note, only abides by a partial ordering case class Range(min: BigInt, max: BigInt) { - def === (that: Range) = + def ===(that: Range) = Seq(this.min, this.max, that.min, that.max) - .sliding(2,1) + .sliding(2, 1) .map(x => x.head == x(1)) .reduce(_ && _) - def > (that: Range) = this.min > that.max - def >= (that: Range) = this.min >= that.max - def < (that: Range) = this.max < that.min - def <= (that: Range) = this.max <= that.min + def >(that: Range) = this.min > that.max + def >=(that: Range) = this.min >= that.max + def <(that: Range) = this.max < that.min + def <=(that: Range) = this.max <= that.min } def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) - case _ => e.tpe match { - case SIntType(IntWidth(width)) => Range( - min = BigInt(0) - BigInt(2).pow(width.toInt - 1), - max = BigInt(2).pow(width.toInt - 1) - BigInt(1) - ) - case UIntType(IntWidth(width)) => Range( - min = BigInt(0), - max = BigInt(2).pow(width.toInt) - BigInt(1) - ) - } + case _ => + e.tpe match { + case SIntType(IntWidth(width)) => + Range( + min = BigInt(0) - BigInt(2).pow(width.toInt - 1), + max = BigInt(2).pow(width.toInt - 1) - BigInt(1) + ) + case UIntType(IntWidth(width)) => + Range( + min = BigInt(0), + max = BigInt(2).pow(width.toInt) - BigInt(1) + ) + } } // Calculates an expression's range of values x match { @@ -351,27 +361,28 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res def r1 = range(ex.args(1)) ex.op match { // Always true - case Lt if r0 < r1 => one + case Lt if r0 < r1 => one case Leq if r0 <= r1 => one - case Gt if r0 > r1 => one + case Gt if r0 > r1 => one case Geq if r0 >= r1 => one // Always false - case Lt if r0 >= r1 => zero + case Lt if r0 >= r1 => zero case Leq if r0 > r1 => zero - case Gt if r0 <= r1 => zero + case Gt if r0 <= r1 => zero case Geq if r0 < r1 => zero - case _ => ex + case _ => ex } case ex => ex } } def foldIfMatchingArgs(x: Expression) = x match { - case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match { - case (Lt | Gt) => zero - case (Leq | Geq) => one - case _ => x - } + case DoPrim(op, Seq(a, b), _, _) if (a == b) => + op match { + case (Lt | Gt) => zero + case (Leq | Geq) => one + case _ => x + } case _ => x } foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e))) @@ -393,43 +404,47 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } private def constPropPrim(e: DoPrim): Expression = e.op match { - case Shl => foldShiftLeft(e) - case Dshl => foldDynamicShiftLeft(e) - case Shr => foldShiftRight(e) - case Dshr => foldDynamicShiftRight(e) - case Cat => foldConcat(e) - case Add => FoldADD(e) - case Sub => SimplifySUB(e) - case Div => SimplifyDIV(e) - case Rem => SimplifyREM(e) - case And => FoldAND(e) - case Or => FoldOR(e) - case Xor => FoldXOR(e) - case Eq => FoldEqual(e) - case Neq => FoldNotEqual(e) - case Andr => FoldANDR(e) - case Orr => FoldORR(e) - case Xorr => FoldXORR(e) + case Shl => foldShiftLeft(e) + case Dshl => foldDynamicShiftLeft(e) + case Shr => foldShiftRight(e) + case Dshr => foldDynamicShiftRight(e) + case Cat => foldConcat(e) + case Add => FoldADD(e) + case Sub => SimplifySUB(e) + case Div => SimplifyDIV(e) + case Rem => SimplifyREM(e) + case And => FoldAND(e) + case Or => FoldOR(e) + case Xor => FoldXOR(e) + case Eq => FoldEqual(e) + case Neq => FoldNotEqual(e) + case Andr => FoldANDR(e) + case Orr => FoldORR(e) + case Xorr => FoldXORR(e) case (Lt | Leq | Gt | Geq) => foldComparison(e) - case Not => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) - case _ => e - } + case Not => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case _ => e + } case AsUInt => e.args.head match { case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w)) - case arg => arg.tpe match { - case _: UIntType => arg - case _ => e - } + case arg => + arg.tpe match { + case _: UIntType => arg + case _ => e + } } - case AsSInt => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w)) - case arg => arg.tpe match { - case _: SIntType => arg - case _ => e + case AsSInt => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w)) + case arg => + arg.tpe match { + case _: SIntType => arg + case _ => e + } } - } case AsClock => val arg = e.args.head arg.tpe match { @@ -442,25 +457,27 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res case AsyncResetType => arg case _ => e } - case Pad => e.args.head match { - case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w)) - case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w)) - case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head - case _ => e - } + case Pad => + e.args.head match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w))) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w))) + case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head + case _ => e + } case (Bits | Head | Tail) => constPropBitExtract(e) - case _ => e + case _ => e } private def constPropMuxCond(m: Mux) = m.cond match { case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe) - case _ => m + case _ => m } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval case (t: UIntLiteral, f: UIntLiteral) - if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => m.cond + if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => + m.cond case (t: UIntLiteral, _) if t.value == BigInt(1) && bitWidth(m.tpe) == BigInt(1) => DoPrim(Or, Seq(m.cond, m.fval), Nil, m.tpe) case (_, f: UIntLiteral) if f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => @@ -479,15 +496,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') - def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) - - private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = { - val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) + def optimize(e: Expression): Expression = + constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + def optimize(e: Expression, nodeMap: NodeMap): Expression = + constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e) + + private def constPropExpression( + nodeMap: NodeMap, + instMap: collection.Map[Instance, OfModule], + constSubOutputs: Map[OfModule, Map[String, Literal]] + )(e: Expression + ): Expression = { + val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs)) val propagated = old match { case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) => val module = instMap(inst.Instance) @@ -506,17 +530,17 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res * @todo generalize source locator propagation across Expressions and delete this method * @todo is the `orElse` the way we want to do propagation here? */ - private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String]) - (stmt: Statement): Statement = stmt match { - // We check rname because inlining it would cause the original declaration to go away - case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => - val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) - node.copy(info = InfoExpr.orElse(info1, info0)) - case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => - val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) - con.copy(info = InfoExpr.orElse(info1, info0)) - case other => other - } + private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])(stmt: Statement): Statement = + stmt match { + // We check rname because inlining it would cause the original declaration to go away + case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + node.copy(info = InfoExpr.orElse(info1, info0)) + case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) => + val (info1, _) = InfoExpr.unwrap(nodeMap(rname)) + con.copy(info = InfoExpr.orElse(info1, info0)) + case other => other + } /* Constant propagate a Module * @@ -538,12 +562,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res */ @tailrec private def constPropModule( - m: Module, - dontTouches: Set[String], - instMap: collection.Map[Instance, OfModule], - constInputs: Map[String, Literal], - constSubOutputs: Map[OfModule, Map[String, Literal]] - ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { + m: Module, + dontTouches: Set[String], + instMap: collection.Map[Instance, OfModule], + constInputs: Map[String, Literal], + constSubOutputs: Map[OfModule, Map[String, Literal]] + ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = { var nPropagated = 0L val nodeMap = new NodeMap() @@ -571,13 +595,13 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // to constant wires, we don't need to worry about propagating primops or muxes since we'll do // that on the next iteration if necessary def backPropExpr(expr: Expression): Expression = { - val old = expr map backPropExpr + val old = expr.map(backPropExpr) val propagated = old match { // When swapping, we swap both rhs and lhs - case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) => + case ref @ WRef(rname, _, _, _) if swapMap.contains(rname) => ref.copy(name = swapMap(rname)) // Only const prop on the rhs - case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) => + case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) => constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2) case x => x } @@ -590,27 +614,29 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res def backPropStmt(stmt: Statement): Statement = stmt match { case reg: DefRegister if (WrappedExpression.weq(reg.init, WRef(reg))) => // Self-init reset is an idiom for "no reset," and must be handled separately - swapMap.get(reg.name) - .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName))) - .getOrElse(reg) - case s => s map backPropExpr match { - case decl: IsDeclaration if swapMap.contains(decl.name) => - val newName = swapMap(decl.name) - nPropagated += 1 - decl match { - case node: DefNode => node.copy(name = newName) - case wire: DefWire => wire.copy(name = newName) - case reg: DefRegister => reg.copy(name = newName) - case other => throwInternalError() - } - case other => other map backPropStmt - } + swapMap + .get(reg.name) + .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName))) + .getOrElse(reg) + case s => + s.map(backPropExpr) match { + case decl: IsDeclaration if swapMap.contains(decl.name) => + val newName = swapMap(decl.name) + nPropagated += 1 + decl match { + case node: DefNode => node.copy(name = newName) + case wire: DefWire => wire.copy(name = newName) + case reg: DefRegister => reg.copy(name = newName) + case other => throwInternalError() + } + case other => other.map(backPropStmt) + } } // When propagating a reference, check if we want to keep the name that would be deleted def propagateRef(lname: String, value: Expression, info: Info): Unit = { value match { - case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => + case WRef(rname, _, kind, _) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind => assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a // node declaration or the single connection to a wire or register swapMap += (lname -> rname, rname -> lname) @@ -639,25 +665,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns // This requires that reset has been made explicit case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) => - - /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment. - * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of - * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can - * be resolved: each side is allowed one candidate register and one candidate literal to - * appear in their source trees, referring to the potential constant propagation case that - * they could allow. If the two are compatible (no different bound sources of either of - * the two types), they can be resolved by combining sources. Otherwise, they propagate - * NonConstant values. When encountering a node reference, it expands the node by to its - * RHS assignment and recurses. - * - * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This - * can stifle register constant propagations, which looks at drivers through value-preserving - * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into - * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole - * de-optimization that does not actually appear in the output code. - * - * @return a RegCPEntry describing the constant prop-compatible sources driving this expression - */ + /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment. + * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of + * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can + * be resolved: each side is allowed one candidate register and one candidate literal to + * appear in their source trees, referring to the potential constant propagation case that + * they could allow. If the two are compatible (no different bound sources of either of + * the two types), they can be resolved by combining sources. Otherwise, they propagate + * NonConstant values. When encountering a node reference, it expands the node by to its + * RHS assignment and recurses. + * + * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This + * can stifle register constant propagations, which looks at drivers through value-preserving + * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into + * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole + * de-optimization that does not actually appear in the output code. + * + * @return a RegCPEntry describing the constant prop-compatible sources driving this expression + */ val unbound = RegCPEntry(UnboundConstant, UnboundConstant) val selfBound = RegCPEntry(BoundConstant(lname), UnboundConstant) @@ -684,11 +709,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Updates nodeMap after analyzing the returned value from regConstant def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e, selfBound) match { - case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero) + case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero) case RegCPEntry(BoundConstant(_), UnboundConstant) => nodeMap(lname) = padCPExp(zero) - case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) + case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) case RegCPEntry(BoundConstant(_), BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit) - case _ => + case _ => } def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe)) @@ -733,11 +758,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Unify two maps using f to combine values of duplicate keys private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] = - b.foldLeft(a) { case (acc, (k, v)) => - acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) + b.foldLeft(a) { + case (acc, (k, v)) => + acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v)) } - private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = { val iGraph = InstanceKeyGraph(c) val moduleDeps = iGraph.getChildInstanceMap @@ -754,9 +779,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // are driven with the same constant value. Then, if we find a Module input where each instance // is driven with the same constant (and not seen in a previous iteration), we iterate again @tailrec - def iterate(toVisit: Set[OfModule], - modules: Map[OfModule, Module], - constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = { + def iterate( + toVisit: Set[OfModule], + modules: Map[OfModule, Module], + constInputs: Map[OfModule, Map[String, Literal]] + ): Map[OfModule, DefModule] = { if (toVisit.isEmpty) modules else { // Order from leaf modules to root so that any module driving an output @@ -767,31 +794,36 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res // Aggreagte Module outputs that are driven constant for use by instaniating Modules // Aggregate submodule inputs driven constant for checking later val (modulesx, _, constInputsx) = - order.foldLeft((modules, - Map[OfModule, Map[String, Literal]](), - Map[OfModule, Map[String, Seq[Literal]]]())) { + order.foldLeft((modules, Map[OfModule, Map[String, Literal]](), Map[OfModule, Map[String, Seq[Literal]]]())) { case ((mmap, constOutputs, constInputsAcc), mname) => val dontTouches = dontTouchMap.getOrElse(mname, Set.empty) - val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname), - constInputs.getOrElse(mname, Map.empty), constOutputs) + val (mx, mco, mci) = constPropModule( + modules(mname), + dontTouches, + moduleDeps(mname), + constInputs.getOrElse(mname, Map.empty), + constOutputs + ) // Accumulate all Literals used to drive a particular Module port val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d)) (mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx) } // Determine which module inputs have all of the same, new constants driving them - val newProppedInputs = constInputsx.flatMap { case (mname, ports) => - val portsx = ports.flatMap { case (pname, lits) => - val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false) - val isModule = modules.contains(mname) // ExtModules are not contained in modules - val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1 - if (isModule && newPort && allSameConst) Some(pname -> lits.head) - else None - } - if (portsx.nonEmpty) Some(mname -> portsx) else None + val newProppedInputs = constInputsx.flatMap { + case (mname, ports) => + val portsx = ports.flatMap { + case (pname, lits) => + val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false) + val isModule = modules.contains(mname) // ExtModules are not contained in modules + val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1 + if (isModule && newPort && allSameConst) Some(pname -> lits.head) + else None + } + if (portsx.nonEmpty) Some(mname -> portsx) else None } val modsWithConstInputs = newProppedInputs.keySet val newToVisit = modsWithConstInputs ++ - modsWithConstInputs.flatMap(parentGraph.reachableFrom) + modsWithConstInputs.flatMap(parentGraph.reachableFrom) // Combine const inputs (there can't be duplicate values in the inner maps) val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b) iterate(newToVisit.toSet, modulesx, nextConstInputs) @@ -805,7 +837,6 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res c.modules.map(m => mmap.getOrElse(m.OfModule, m)) } - Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index c883bdfb..fb1bd1f6 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -1,4 +1,3 @@ - package firrtl.transforms import firrtl._ @@ -8,7 +7,7 @@ import firrtl.annotations._ import firrtl.graph._ import firrtl.analyses.InstanceKeyGraph import firrtl.Mappers._ -import firrtl.Utils.{throwInternalError, kind} +import firrtl.Utils.{kind, throwInternalError} import firrtl.MemoizedHash._ import firrtl.options.{Dependency, RegisteredTransform, ShellOption} @@ -29,29 +28,34 @@ import collection.mutable * circumstances of their instantiation in their parent module, they will still not be removed. To * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. */ -class DeadCodeElimination extends Transform +class DeadCodeElimination + extends Transform with ResolvedAnnotationPaths with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats], - Dependency(passes.CommonSubexpressionElimination) ) + Seq( + Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats], + Dependency(passes.CommonSubexpressionElimination) + ) override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], - Dependency(passes.VerilogPrep), - Dependency[firrtl.AddDescriptionNodes] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] + ) override def invalidates(a: Transform) = false @@ -59,7 +63,9 @@ class DeadCodeElimination extends Transform new ShellOption[Unit]( longOption = "no-dce", toAnnotationSeq = (_: Unit) => Seq(NoDCEAnnotation), - helpText = "Disable dead code elimination" ) ) + helpText = "Disable dead code elimination" + ) + ) /** Based on LogicNode ins CheckCombLoops, currently kind of faking it */ private type LogicNode = MemoizedHash[WrappedExpression] @@ -72,6 +78,7 @@ class DeadCodeElimination extends Transform val loweredName = LowerTypes.loweredName(component.name.split('.')) apply(component.module.name, WRef(loweredName)) } + /** External Modules are representated as a single node driven by all inputs and driving all * outputs */ @@ -87,7 +94,7 @@ class DeadCodeElimination extends Transform def rec(e: Expression): Expression = { e match { case ref @ (_: WRef | _: WSubField) => refs += ref - case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec + case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.map(rec) case ignore @ (_: Literal) => // Do nothing case unexpected => throwInternalError() } @@ -98,9 +105,7 @@ class DeadCodeElimination extends Transform } // Gets all dependencies and constructs LogicNodes from them - private def getDepsImpl(mname: String, - instMap: collection.Map[String, String]) - (expr: Expression): Seq[LogicNode] = + private def getDepsImpl(mname: String, instMap: collection.Map[String, String])(expr: Expression): Seq[LogicNode] = extractRefs(expr).map { e => if (kind(e) == InstanceKind) { val (inst, tail) = Utils.splitRef(e) @@ -110,11 +115,12 @@ class DeadCodeElimination extends Transform } } - /** Construct the dependency graph within this module */ - private def setupDepGraph(depGraph: MutableDiGraph[LogicNode], - instMap: collection.Map[String, String]) - (mod: Module): Unit = { + private def setupDepGraph( + depGraph: MutableDiGraph[LogicNode], + instMap: collection.Map[String, String] + )(mod: Module + ): Unit = { def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr) def onStmt(stmt: Statement): Unit = stmt match { @@ -150,7 +156,7 @@ class DeadCodeElimination extends Transform val node = getDeps(loc) match { case Seq(elt) => elt } getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref)) // Simulation constructs are treated as top-level outputs - case Stop(_,_, clk, en) => + case Stop(_, _, clk, en) => Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) case Print(_, _, args, clk, en) => (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) @@ -172,12 +178,14 @@ class DeadCodeElimination extends Transform } // TODO Make immutable? - private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]], - doTouchExtMods: Set[String], - c: Circuit): MutableDiGraph[LogicNode] = { + private def createDependencyGraph( + instMaps: collection.Map[String, collection.Map[String, String]], + doTouchExtMods: Set[String], + c: Circuit + ): MutableDiGraph[LogicNode] = { val depGraph = new MutableDiGraph[LogicNode] c.modules.foreach { - case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) + case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) case ext: ExtModule => // Connect all inputs to all outputs val node = LogicNode(ext) @@ -205,23 +213,25 @@ class DeadCodeElimination extends Transform depGraph } - private def deleteDeadCode(instMap: collection.Map[String, String], - deadNodes: collection.Set[LogicNode], - moduleMap: collection.Map[String, DefModule], - renames: RenameMap, - topName: String, - doTouchExtMods: Set[String]) - (mod: DefModule): Option[DefModule] = { + private def deleteDeadCode( + instMap: collection.Map[String, String], + deadNodes: collection.Set[LogicNode], + moduleMap: collection.Map[String, DefModule], + renames: RenameMap, + topName: String, + doTouchExtMods: Set[String] + )(mod: DefModule + ): Option[DefModule] = { // For log-level debug def deleteMsg(decl: IsDeclaration): String = { val tpe = decl match { - case _: DefNode => "node" + case _: DefNode => "node" case _: DefRegister => "reg" - case _: DefWire => "wire" - case _: Port => "port" - case _: DefMemory => "mem" + case _: DefWire => "wire" + case _: Port => "port" + case _: DefMemory => "mem" case (_: DefInstance | _: WDefInstance) => "inst" - case _: Module => "module" + case _: Module => "module" case _: ExtModule => "extmodule" } val ref = decl match { @@ -237,7 +247,7 @@ class DeadCodeElimination extends Transform def deleteIfNotEnabled(stmt: Statement, en: Expression): Statement = en match { case UIntLiteral(v, _) if v == BigInt(0) => EmptyStmt - case _ => stmt + case _ => stmt } def onStmt(stmt: Statement): Statement = { @@ -256,12 +266,11 @@ class DeadCodeElimination extends Transform logger.debug(deleteMsg(decl)) renames.delete(decl.name) EmptyStmt - } - else decl - case print: Print => deleteIfNotEnabled(print, print.en) - case stop: Stop => deleteIfNotEnabled(stop, stop.en) + } else decl + case print: Print => deleteIfNotEnabled(print, print.en) + case stop: Stop => deleteIfNotEnabled(stop, stop.en) case formal: Verification => deleteIfNotEnabled(formal, formal.en) - case con: Connect => + case con: Connect => val node = getDeps(con.loc) match { case Seq(elt) => elt } if (deadNodes.contains(node)) EmptyStmt else con case Attach(info, exprs) => // If any exprs are dead then all are @@ -270,7 +279,7 @@ class DeadCodeElimination extends Transform case IsInvalid(info, expr) => val node = getDeps(expr) match { case Seq(elt) => elt } if (deadNodes.contains(node)) EmptyStmt else IsInvalid(info, expr) - case block: Block => block map onStmt + case block: Block => block.map(onStmt) case other => other } stmtx match { // Check if module empty @@ -300,8 +309,7 @@ class DeadCodeElimination extends Transform if (portsx.isEmpty && doTouchExtMods.contains(ext.name)) { logger.debug(deleteMsg(mod)) None - } - else { + } else { if (ext.ports != portsx) throwInternalError() // Sanity check Some(ext.copy(ports = portsx)) } @@ -309,14 +317,13 @@ class DeadCodeElimination extends Transform } - def run(state: CircuitState, - dontTouches: Seq[LogicNode], - doTouchExtMods: Set[String]): CircuitState = { + def run(state: CircuitState, dontTouches: Seq[LogicNode], doTouchExtMods: Set[String]): CircuitState = { val c = state.circuit val moduleMap = c.modules.map(m => m.name -> m).toMap val iGraph = InstanceKeyGraph(c) - val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) => - k.module -> v.map(i => i.name -> i.module).toMap + val moduleDeps = iGraph.graph.getEdgeMap.map({ + case (k, v) => + k.module -> v.map(i => i.name -> i.module).toMap }) val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_)) @@ -347,11 +354,12 @@ class DeadCodeElimination extends Transform // themselves. We iterate over the modules in a topological order from leaves to the top. The // current status of the modulesxMap is used to either delete instances or update their types val modulesxMap = mutable.HashMap.empty[String, DefModule] - topoSortedModules.foreach { case mod => - deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match { - case Some(m) => modulesxMap += m.name -> m - case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) - } + topoSortedModules.foreach { + case mod => + deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match { + case Some(m) => modulesxMap += m.name -> m + case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) + } } // Preserve original module order diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala index 627af11f..18e32cbc 100644 --- a/src/main/scala/firrtl/transforms/Dedup.scala +++ b/src/main/scala/firrtl/transforms/Dedup.scala @@ -20,7 +20,6 @@ import scala.annotation.tailrec // Datastructures import scala.collection.mutable - /** A component, e.g. register etc. Must be declared only once under the TopAnnotation */ case class NoDedupAnnotation(target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] { def duplicate(n: ModuleTarget): NoDedupAnnotation = NoDedupAnnotation(n) @@ -36,7 +35,9 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt new ShellOption[Unit]( longOption = "no-dedup", toAnnotationSeq = _ => Seq(NoCircuitDedupAnnotation), - helpText = "Do NOT dedup modules" ) ) + helpText = "Do NOT dedup modules" + ) + ) } @@ -46,12 +47,13 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt * @param original Original module * @param index the normalized position of the original module in the original module list, fraction between 0 and 1 */ -case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) extends MultiTargetAnnotation { +case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) + extends MultiTargetAnnotation { override val targets: Seq[Seq[Target]] = Seq(Seq(original), duplicate.toList) override def duplicate(n: Seq[Seq[Target]]): Annotation = { n.toList match { case Seq(_, List(dup: IsModule)) => DedupedResult(original, Some(dup), index) - case _ => DedupedResult(original, None, -1) + case _ => DedupedResult(original, None, -1) } } } @@ -96,7 +98,7 @@ class DedupModules extends Transform with DependencyAPIMigration { val noDedups = state.circuit.main +: state.annotations.collect { case NoDedupAnnotation(ModuleTarget(_, m)) => m } val (remainingAnnotations, dupResults) = state.annotations.partition { case _: DupedResult => false - case _ => true + case _ => true } val previouslyDupedMap = dupResults.flatMap { case DupedResult(newModules, original) => @@ -114,9 +116,11 @@ class DedupModules extends Transform with DependencyAPIMigration { * @param noDedups Modules not to dedup * @return Deduped Circuit and corresponding RenameMap */ - def run(c: Circuit, - noDedups: Seq[String], - previouslyDupedMap: Map[String, String]): (Circuit, RenameMap, AnnotationSeq) = { + def run( + c: Circuit, + noDedups: Seq[String], + previouslyDupedMap: Map[String, String] + ): (Circuit, RenameMap, AnnotationSeq) = { // RenameMap val componentRenameMap = RenameMap() @@ -124,13 +128,16 @@ class DedupModules extends Transform with DependencyAPIMigration { // Maps module name to corresponding dedup module val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, previouslyDupedMap, componentRenameMap) - val dedupCliques = dedupMap.foldLeft(Map.empty[String, Set[String]]) { - case (dedupCliqueMap, (orig: String, dupMod: DefModule)) => - val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig - dedupCliqueMap + (dupMod.name -> set) - }.flatMap { case (dedupName, set) => - set.map { _ -> set } - } + val dedupCliques = dedupMap + .foldLeft(Map.empty[String, Set[String]]) { + case (dedupCliqueMap, (orig: String, dupMod: DefModule)) => + val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig + dedupCliqueMap + (dupMod.name -> set) + } + .flatMap { + case (dedupName, set) => + set.map { _ -> set } + } // Use old module list to preserve ordering // Lookup what a module deduped to, if its a duplicate, remove it @@ -149,9 +156,10 @@ class DedupModules extends Transform with DependencyAPIMigration { val ct = CircuitTarget(c.main) - val map = dedupMap.map { case (from, to) => - logger.debug(s"[Dedup] $from -> ${to.name}") - ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name)) + val map = dedupMap.map { + case (from, to) => + logger.debug(s"[Dedup] $from -> ${to.name}") + ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name)) } val moduleRenameMap = RenameMap() moduleRenameMap.recordAll(map) @@ -159,15 +167,19 @@ class DedupModules extends Transform with DependencyAPIMigration { // Build instanceify renaming map val instanceGraph = InstanceKeyGraph(c) val instanceify = RenameMap() - val moduleName2Index = c.modules.map(_.name).zipWithIndex.map { case (n, i) => - { - c.modules.size match { - case 0 => (n, 0.0) - case 1 => (n, 1.0) - case d => (n, i.toDouble / (d - 1)) + val moduleName2Index = c.modules + .map(_.name) + .zipWithIndex + .map { + case (n, i) => { + c.modules.size match { + case 0 => (n, 0.0) + case 1 => (n, 1.0) + case d => (n, i.toDouble / (d - 1)) + } } } - }.toMap + .toMap // get the ordered set of instances a module, includes new Deduped modules val getChildrenInstances = { @@ -182,56 +194,62 @@ class DedupModules extends Transform with DependencyAPIMigration { } val instanceNameMap: Map[OfModule, Map[Instance, Instance]] = { - dedupMap.map { case (oldName, dedupedMod) => - val key = OfModule(oldName) - val value = getChildrenInstances(oldName).zip(getChildrenInstances(dedupedMod.name)).map { - case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name) - }.toMap - key -> value + dedupMap.map { + case (oldName, dedupedMod) => + val key = OfModule(oldName) + val value = getChildrenInstances(oldName) + .zip(getChildrenInstances(dedupedMod.name)) + .map { + case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name) + } + .toMap + key -> value }.toMap } - val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { case mt@ModuleTarget(c, m) if dedupCliques(m).size > 1 => - dedupMap.get(m) match { - case None => Nil - case Some(module: DefModule) => - val paths = instanceGraph.findInstancesInHierarchy(m) - // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option - val newTargets = paths.map { path => - val root: IsModule = ct.module(c) - path.foldLeft(root -> root) { case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) => - if(mod == c) { - val mod = CircuitTarget(c).module(c) - mod -> mod - } else { - val enclosingMod = oldRelPath match { - case i: InstanceTarget => i.ofModule - case m: ModuleTarget => m.module - } - val instMap = instanceNameMap(OfModule(enclosingMod)) - val newInstName = instMap(Instance(name)).value - val old = oldRelPath.instOf(name, mod) - old -> newRelPath.instOf(newInstName, mod) + val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { + case mt @ ModuleTarget(c, m) if dedupCliques(m).size > 1 => + dedupMap.get(m) match { + case None => Nil + case Some(module: DefModule) => + val paths = instanceGraph.findInstancesInHierarchy(m) + // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option + val newTargets = paths.map { path => + val root: IsModule = ct.module(c) + path.foldLeft(root -> root) { + case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) => + if (mod == c) { + val mod = CircuitTarget(c).module(c) + mod -> mod + } else { + val enclosingMod = oldRelPath match { + case i: InstanceTarget => i.ofModule + case m: ModuleTarget => m.module + } + val instMap = instanceNameMap(OfModule(enclosingMod)) + val newInstName = instMap(Instance(name)).value + val old = oldRelPath.instOf(name, mod) + old -> newRelPath.instOf(newInstName, mod) + } } } - } - // Add all relative paths to referredModule to map to new instances - def addRecord(from: IsMember, to: IsMember): Unit = from match { - case x: ModuleTarget => - instanceify.record(x, to) - case x: IsComponent => - instanceify.record(x, to) - addRecord(x.stripHierarchy(1), to) - } - // Instanceify deduped Modules! - if (dedupCliques(module.name).size > 1) { - newTargets.foreach { case (from, to) => addRecord(from, to) } - } - // Return Deduped Results - if (newTargets.size == 1) { - Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m))) - } else Nil - } + // Add all relative paths to referredModule to map to new instances + def addRecord(from: IsMember, to: IsMember): Unit = from match { + case x: ModuleTarget => + instanceify.record(x, to) + case x: IsComponent => + instanceify.record(x, to) + addRecord(x.stripHierarchy(1), to) + } + // Instanceify deduped Modules! + if (dedupCliques(module.name).size > 1) { + newTargets.foreach { case (from, to) => addRecord(from, to) } + } + // Return Deduped Results + if (newTargets.size == 1) { + Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m))) + } else Nil + } case noDedups => Nil } @@ -242,6 +260,7 @@ class DedupModules extends Transform with DependencyAPIMigration { /** Utility functions for [[DedupModules]] */ object DedupModules extends LazyLogging { + /** Change's a module's internal signal names, types, infos, and modules. * @param rename Function to rename a signal. Called on declaration and references. * @param retype Function to retype a signal. Called on declaration, references, and subfields @@ -250,14 +269,16 @@ object DedupModules extends LazyLogging { * @param module Module to change internals * @return Changed Module */ - def changeInternals(rename: String=>String, - retype: String=>Type=>Type, - reinfo: Info=>Info, - renameOfModule: (String, String)=>String, - renameExps: Boolean = true - )(module: DefModule): DefModule = { + def changeInternals( + rename: String => String, + retype: String => Type => Type, + reinfo: Info => Info, + renameOfModule: (String, String) => String, + renameExps: Boolean = true + )(module: DefModule + ): DefModule = { def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe)) - def onExp(e: Expression): Expression = e match { + def onExp(e: Expression): Expression = e match { case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g) case WSubField(expr, n, tpe, kind) => val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n) @@ -266,12 +287,12 @@ object DedupModules extends LazyLogging { val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind) //TODO: renameMap.rename(e.serialize, finalExpr.serialize) finalExpr - case other => other map onExp + case other => other.map(onExp) } def onStmt(s: Statement): Statement = s match { case DefNode(info, name, value) => retype(name)(value.tpe) - if(renameExps) DefNode(reinfo(info), rename(name), onExp(value)) + if (renameExps) DefNode(reinfo(info), rename(name), onExp(value)) else DefNode(reinfo(info), rename(name), value) case WDefInstance(i, n, m, t) => val newmod = renameOfModule(n, m) @@ -283,12 +304,18 @@ object DedupModules extends LazyLogging { val oldType = MemPortUtils.memType(d) val newType = retype(d.name)(oldType) val index = oldType - .asInstanceOf[BundleType].fields.headOption - .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere( - { - case Field("data" | "wdata" | "rdata", _, _) => true - case _ => false - })) + .asInstanceOf[BundleType] + .fields + .headOption + .map( + _.tpe + .asInstanceOf[BundleType] + .fields + .indexWhere({ + case Field("data" | "wdata" | "rdata", _, _) => true + case _ => false + }) + ) val newDataType = index match { case Some(i) => //If index nonempty, then there exists a port @@ -299,15 +326,15 @@ object DedupModules extends LazyLogging { // associate it with the type of the memory (as the memory type is different than the datatype) retype(d.name + ";&*^$")(d.dataType) } - d.copy(dataType = newDataType) map rename map reinfo + d.copy(dataType = newDataType).map(rename).map(reinfo) case h: IsDeclaration => - val temp = h map rename map retype(h.name) map reinfo - if(renameExps) temp map onExp else temp + val temp = h.map(rename).map(retype(h.name)).map(reinfo) + if (renameExps) temp.map(onExp) else temp case other => - val temp = other map reinfo map onStmt - if(renameExps) temp map onExp else temp + val temp = other.map(reinfo).map(onStmt) + if (renameExps) temp.map(onExp) else temp } - module map onPort map onStmt + module.map(onPort).map(onStmt) } /** Dedup a module's instances based on dedup map @@ -321,11 +348,13 @@ object DedupModules extends LazyLogging { * @param renameMap Will be modified to keep track of renames in this function * @return fixed up module deduped instances */ - def dedupInstances(top: CircuitTarget, - originalModule: String, - moduleMap: Map[String, DefModule], - name2name: Map[String, String], - renameMap: RenameMap): DefModule = { + def dedupInstances( + top: CircuitTarget, + originalModule: String, + moduleMap: Map[String, DefModule], + name2name: Map[String, String], + renameMap: RenameMap + ): DefModule = { val module = moduleMap(originalModule) // If black box, return it (it has no instances) @@ -340,7 +369,8 @@ object DedupModules extends LazyLogging { } val typeMap = mutable.HashMap[String, Type]() def retype(name: String)(tpe: Type): Type = { - if (typeMap.contains(name)) typeMap(name) else { + if (typeMap.contains(name)) typeMap(name) + else { if (instanceModuleMap.contains(name)) { val newType = Utils.module_type(getNewModule(instanceModuleMap(name))) typeMap(name) = newType @@ -360,7 +390,7 @@ object DedupModules extends LazyLogging { def renameOfModule(instance: String, ofModule: String): String = { name2name(ofModule) } - changeInternals({n => n}, retype, {i => i}, renameOfModule)(module) + changeInternals({ n => n }, retype, { i => i }, renameOfModule)(module) } @tailrec @@ -415,10 +445,11 @@ object DedupModules extends LazyLogging { * @return A map from tag to names of modules with the same structure and * a RenameMap which maps Module names to their Tag. */ - def buildRTLTags(top: CircuitTarget, - moduleLinearization: Seq[DefModule], - noDedups: Set[String] - ): (collection.Map[String, collection.Set[String]], RenameMap) = { + def buildRTLTags( + top: CircuitTarget, + moduleLinearization: Seq[DefModule], + noDedups: Set[String] + ): (collection.Map[String, collection.Set[String]], RenameMap) = { // maps hash code to human readable tag val hashToTag = mutable.HashMap[ir.HashCode, String]() @@ -449,9 +480,9 @@ object DedupModules extends LazyLogging { moduleNameToTag(originalModule.name) = hashToTag(hash) } - val tag2all = hashToNames.map{ case (hash, names) => hashToTag(hash) -> names.toSet } + val tag2all = hashToNames.map { case (hash, names) => hashToTag(hash) -> names.toSet } val tagMap = RenameMap() - moduleNameToTag.foreach{ case (name, tag) => tagMap.record(top.module(name), top.module(tag)) } + moduleNameToTag.foreach { case (name, tag) => tagMap.record(top.module(name), top.module(tag)) } (tag2all, tagMap) } @@ -461,10 +492,12 @@ object DedupModules extends LazyLogging { * @param renameMap rename map to populate when deduping * @return Map of original Module name -> Deduped Module */ - def deduplicate(circuit: Circuit, - noDedups: Set[String], - previousDupResults: Map[String, String], - renameMap: RenameMap): Map[String, DefModule] = { + def deduplicate( + circuit: Circuit, + noDedups: Set[String], + previousDupResults: Map[String, String], + renameMap: RenameMap + ): Map[String, DefModule] = { val (moduleMap, moduleLinearization) = { val iGraph = InstanceKeyGraph(circuit) @@ -479,13 +512,14 @@ object DedupModules extends LazyLogging { val (tag2all, tagMap) = buildRTLTags(top, moduleLinearization, noDedups) // Set tag2name to be the best dedup module name - val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap + val moduleIndex = circuit.modules.zipWithIndex.map { case (m, i) => m.name -> i }.toMap // returns the module matching the circuit name or the module with lower index otherwise def order(l: String, r: String): String = { if (l == main) l else if (r == main) r - else if (moduleIndex(l) < moduleIndex(r)) l else r + else if (moduleIndex(l) < moduleIndex(r)) l + else r } // Maps a module's tag to its deduplicated module @@ -499,7 +533,7 @@ object DedupModules extends LazyLogging { tag2name(tag) = dedupName val dedupModule = moduleMap(dedupWithoutOldName) match { case e: ExtModule => e.copy(name = dedupName) - case e: Module => e.copy(name = dedupName) + case e: Module => e.copy(name = dedupName) } dedupName -> dedupModule }.toMap @@ -508,32 +542,32 @@ object DedupModules extends LazyLogging { val name2name = moduleMap.keysIterator.map { originalModule => tagMap.get(top.module(originalModule)) match { case Some(Seq(Target(_, Some(tag), Nil))) => originalModule -> tag2name(tag) - case None => originalModule -> originalModule - case other => throwInternalError(other.toString) + case None => originalModule -> originalModule + case other => throwInternalError(other.toString) } }.toMap // Build Remap for modules with deduped module references val dedupedName2module = tag2name.map { - case (tag, name) => name -> DedupModules.dedupInstances( - top, name, moduleMapWithOldNames, name2name, renameMap) + case (tag, name) => name -> DedupModules.dedupInstances(top, name, moduleMapWithOldNames, name2name, renameMap) } // Build map from original name to corresponding deduped module // It is important to flatMap before looking up the DefModules so that they aren't hashed val name2module: Map[String, DefModule] = tag2all.flatMap { case (tag, names) => names.map(_ -> tag) } - .mapValues(tag => dedupedName2module(tag2name(tag))) - .toMap + .mapValues(tag => dedupedName2module(tag2name(tag))) + .toMap // Build renameMap val indexedTargets = mutable.HashMap[String, IndexedSeq[ReferenceTarget]]() - name2module.foreach { case (originalName, depModule) => - if(originalName != depModule.name) { - val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule)) - val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName)) - computeRenameMap(fromSeq, toSeq, renameMap) - } + name2module.foreach { + case (originalName, depModule) => + if (originalName != depModule.name) { + val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule)) + val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName)) + computeRenameMap(fromSeq, toSeq, renameMap) + } } name2module @@ -549,18 +583,21 @@ object DedupModules extends LazyLogging { tpe } - changeInternals(rename, retype, {i => i}, {(x, y) => x}, renameExps = false)(m) + changeInternals(rename, retype, { i => i }, { (x, y) => x }, renameExps = false)(m) refs.toIndexedSeq } - def computeRenameMap(originalNames: IndexedSeq[ReferenceTarget], - dedupedNames: IndexedSeq[ReferenceTarget], - renameMap: RenameMap): Unit = { + def computeRenameMap( + originalNames: IndexedSeq[ReferenceTarget], + dedupedNames: IndexedSeq[ReferenceTarget], + renameMap: RenameMap + ): Unit = { originalNames.zip(dedupedNames).foreach { - case (o, d) => if (o.component != d.component || o.ref != d.ref) { - renameMap.record(o, d.copy(module = o.module)) - } + case (o, d) => + if (o.component != d.component || o.ref != d.ref) { + renameMap.record(o, d.copy(module = o.module)) + } } } diff --git a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala index a1e49d62..bfab31bf 100644 --- a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala +++ b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala @@ -33,7 +33,7 @@ object FixAddingNegativeLiterals { */ def fixupModule(m: DefModule): DefModule = { val namespace = Namespace(m) - m map fixupStatement(namespace) + m.map(fixupStatement(namespace)) } /** Returns a statement with fixed additions of negative literals @@ -43,8 +43,8 @@ object FixAddingNegativeLiterals { */ def fixupStatement(namespace: Namespace)(s: Statement): Statement = { val stmtBuffer = mutable.ListBuffer[Statement]() - val ret = s map fixupStatement(namespace) map fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer) - if(stmtBuffer.isEmpty) { + val ret = s.map(fixupStatement(namespace)).map(fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer)) + if (stmtBuffer.isEmpty) { ret } else { stmtBuffer += ret @@ -58,8 +58,7 @@ object FixAddingNegativeLiterals { * @param e expression to fixup * @return generated statements and the fixed expression */ - def fixupExpression(info: Info, namespace: Namespace) - (e: Expression): (Seq[Statement], Expression) = { + def fixupExpression(info: Info, namespace: Namespace)(e: Expression): (Seq[Statement], Expression) = { val stmtBuffer = mutable.ListBuffer[Statement]() val retExpr = fixupOnExpr(info, namespace, stmtBuffer)(e) (stmtBuffer.toList, retExpr) @@ -72,12 +71,16 @@ object FixAddingNegativeLiterals { * @param e expression to fixup * @return fixed expression */ - private def fixupOnExpr(info: Info, namespace: Namespace, stmtBuffer: mutable.ListBuffer[Statement]) - (e: Expression): Expression = { + private def fixupOnExpr( + info: Info, + namespace: Namespace, + stmtBuffer: mutable.ListBuffer[Statement] + )(e: Expression + ): Expression = { // Helper function to create the subtraction expression def fixupAdd(expr: Expression, litValue: BigInt, litWidth: BigInt): DoPrim = { - if(litValue == minNegValue(litWidth)) { + if (litValue == minNegValue(litWidth)) { val posLiteral = SIntLiteral(-litValue) assert(posLiteral.width.asInstanceOf[IntWidth].width - 1 == litWidth) val sub = DefNode(info, namespace.newTemp, setType(DoPrim(Sub, Seq(expr, posLiteral), Nil, UnknownType))) @@ -91,10 +94,10 @@ object FixAddingNegativeLiterals { } } - e map fixupOnExpr(info, namespace, stmtBuffer) match { - case DoPrim(Add, Seq(arg, lit@SIntLiteral(value, w@IntWidth(width))), Nil, t: SIntType) if value < 0 => + e.map(fixupOnExpr(info, namespace, stmtBuffer)) match { + case DoPrim(Add, Seq(arg, lit @ SIntLiteral(value, w @ IntWidth(width))), Nil, t: SIntType) if value < 0 => fixupAdd(arg, value, width) - case DoPrim(Add, Seq(lit@SIntLiteral(value, w@IntWidth(width)), arg), Nil, t: SIntType) if value < 0 => + case DoPrim(Add, Seq(lit @ SIntLiteral(value, w @ IntWidth(width)), arg), Nil, t: SIntType) if value < 0 => fixupAdd(arg, value, width) case other => other } diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala index cc5b3504..36e71470 100644 --- a/src/main/scala/firrtl/transforms/Flatten.scala +++ b/src/main/scala/firrtl/transforms/Flatten.scala @@ -7,7 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.annotations._ import scala.collection.mutable -import firrtl.passes.{InlineInstances,PassException} +import firrtl.passes.{InlineInstances, PassException} import firrtl.stage.Forms /** Tags an annotation to be consumed by this transform */ @@ -25,101 +25,114 @@ case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named */ class Flatten extends Transform with DependencyAPIMigration { - override def prerequisites = Forms.LowForm - override def optionalPrerequisites = Seq.empty - override def optionalPrerequisiteOf = Forms.LowEmitters + override def prerequisites = Forms.LowForm + override def optionalPrerequisites = Seq.empty + override def optionalPrerequisiteOf = Forms.LowEmitters override def invalidates(a: Transform) = false - val inlineTransform = new InlineInstances - - private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = - anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) { - case ((modNames, instNames), ann) => ann match { - case FlattenAnnotation(CircuitName(c)) => - (circuit.modules.collect { - case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) - }.toSet, instNames) - case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) - case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) - case _ => throw new PassException("Annotation must be a FlattenAnnotation") - } - } - - /** + val inlineTransform = new InlineInstances + + private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) = + anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) { + case ((modNames, instNames), ann) => + ann match { + case FlattenAnnotation(CircuitName(c)) => + ( + circuit.modules.collect { + case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c)) + }.toSet, + instNames + ) + case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames) + case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod)) + case _ => throw new PassException("Annotation must be a FlattenAnnotation") + } + } + + /** * Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and * by rewriting the original circuit to refer to the new modules that will be inlined later. * @return modified circuit and ModuleNames to inline */ - def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = { - val modMap = c.modules.map(m => m.name->m).toMap - val seedMods = mutable.Map.empty[String, String] - val newModDefs = mutable.Set.empty[DefModule] - val nsp = Namespace(c) - - /** + def duplicateSubCircuitsFromAnno( + c: Circuit, + mods: Set[ModuleName], + insts: Set[ComponentName] + ): (Circuit, Set[ModuleName]) = { + val modMap = c.modules.map(m => m.name -> m).toMap + val seedMods = mutable.Map.empty[String, String] + val newModDefs = mutable.Set.empty[DefModule] + val nsp = Namespace(c) + + /** * We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later. * It populates seedMods where we capture the mapping between the original module name of the instances came from annotation * to a new module name that we will create as a replica of the original one. * Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged. */ - def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { - case _: Block => x map rewriteMod(parent) - case WDefInstance(info, instName, moduleName, instTpe) => - if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main)))) - || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) { - val newModName = if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") - seedMods += moduleName -> newModName - WDefInstance(info, instName, newModName, instTpe) - } else x - case _ => x - } - - val modifMods = c.modules map { m => m map rewriteMod(m) } - - /** + def rewriteMod(parent: DefModule)(x: Statement): Statement = x match { + case _: Block => x.map(rewriteMod(parent)) + case WDefInstance(info, instName, moduleName, instTpe) => + if ( + insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main)))) + || mods.contains(ModuleName(parent.name, CircuitName(c.main))) + ) { + val newModName = + if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN") + seedMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + } else x + case _ => x + } + + val modifMods = c.modules.map { m => m.map(rewriteMod(m)) } + + /** * Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations). * Populates newModDefs, which are replicated modules used in the subcircuit that we create * by recursively traversing modules captured inside seedMods and replicating them */ - def recDupMods(mods: Map[String, String]): Unit = { - val replMods = mutable.Map.empty[String, String] - - def dupMod(x: Statement): Statement = x match { - case _: Block => x map dupMod - case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match { - case m: Module => - val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN") - replMods += moduleName -> newModName - WDefInstance(info, instName, newModName, instTpe) - case _ => x // Ignore extmodules - } - case _ => x - } - - def dupName(name: String): String = mods(name) - val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName } - - newModDefs ++= newMods - - if(replMods.size > 0) recDupMods(replMods.toMap) - - } - recDupMods(seedMods.toMap) - - //convert newly created modules to ModuleName for inlining next (outside this function) - val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) } - (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) - } - - override def execute(state: CircuitState): CircuitState = { - val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a } - annos match { - case Nil => state - case myAnnotations => - val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) - // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline - val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames) - inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations) - } - } + def recDupMods(mods: Map[String, String]): Unit = { + val replMods = mutable.Map.empty[String, String] + + def dupMod(x: Statement): Statement = x match { + case _: Block => x.map(dupMod) + case WDefInstance(info, instName, moduleName, instTpe) => + modMap(moduleName) match { + case m: Module => + val newModName = + if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN") + replMods += moduleName -> newModName + WDefInstance(info, instName, newModName, instTpe) + case _ => x // Ignore extmodules + } + case _ => x + } + + def dupName(name: String): String = mods(name) + val newMods = mods.map { case (origName, newName) => modMap(origName).map(dupMod).map(dupName) } + + newModDefs ++= newMods + + if (replMods.size > 0) recDupMods(replMods.toMap) + + } + recDupMods(seedMods.toMap) + + //convert newly created modules to ModuleName for inlining next (outside this function) + val modsToInline = newModDefs.map { m => ModuleName(m.name, CircuitName(c.main)) } + (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet) + } + + override def execute(state: CircuitState): CircuitState = { + val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a } + annos match { + case Nil => state + case myAnnotations => + val (modNames, instNames) = collectAnns(state.circuit, myAnnotations) + // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline + val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames) + inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations) + } + } } diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index a2399b5a..b582fe2a 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -119,7 +119,7 @@ object FlattenRegUpdate { def rec(e: Expression): (Info, Expression) = { val (info, expr) = kind(e) match { case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e)) - case _ => unwrap(e) + case _ => unwrap(e) } expr match { case Mux(cond, tval, fval, tpe) => @@ -128,16 +128,18 @@ object FlattenRegUpdate { val infox = combineInfos(info, tinfo, finfo) (infox, Mux(cond, tvalx, fvalx, tpe)) // Return the original expression to end flattening - case _ => unwrap(e) + case _ => unwrap(e) } } rec(start) } def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { - case reg @ DefRegister(_, rname, _,_, resetCond, _) => - assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero, - "Synchronous reset should have already been made explicit!") + case reg @ DefRegister(_, rname, _, _, resetCond, _) => + assert( + resetCond.tpe == AsyncResetType || resetCond == Utils.zero, + "Synchronous reset should have already been made explicit!" + ) val ref = WRef(reg) val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref)) val update = Connect(info, ref, rhs) @@ -145,7 +147,7 @@ object FlattenRegUpdate { reg // Remove connections to Registers so we preserve LowFirrtl single-connection semantics case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt - case other => other + case other => other } val bodyx = onStmt(mod.body) @@ -163,12 +165,14 @@ object FlattenRegUpdate { class FlattenRegUpdate extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform], - Dependency[LegalizeClocksTransform] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform], + Dependency[LegalizeClocksTransform] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -181,7 +185,7 @@ class FlattenRegUpdate extends Transform with DependencyAPIMigration { def execute(state: CircuitState): CircuitState = { val modulesx = state.circuit.modules.map { - case mod: Module => FlattenRegUpdate.flattenReg(mod) + case mod: Module => FlattenRegUpdate.flattenReg(mod) case ext: ExtModule => ext } state.copy(circuit = state.circuit.copy(modules = modulesx)) diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala index 166feba0..0db67f1e 100644 --- a/src/main/scala/firrtl/transforms/GroupComponents.scala +++ b/src/main/scala/firrtl/transforms/GroupComponents.scala @@ -10,7 +10,6 @@ import firrtl.stage.Forms import scala.collection.mutable - /** * Specifies a group of components, within a module, to pull out into their own module * Components that are only connected to a group's components will also be included @@ -21,8 +20,14 @@ import scala.collection.mutable * @param outputSuffix suggested suffix of any output ports of the new module * @param inputSuffix suggested suffix of any input ports of the new module */ -case class GroupAnnotation(components: Seq[ComponentName], newModule: String, newInstance: String, outputSuffix: Option[String] = None, inputSuffix: Option[String] = None) extends Annotation { - if(components.nonEmpty) { +case class GroupAnnotation( + components: Seq[ComponentName], + newModule: String, + newInstance: String, + outputSuffix: Option[String] = None, + inputSuffix: Option[String] = None) + extends Annotation { + if (components.nonEmpty) { require(components.forall(_.module == components.head.module), "All components must be in the same module.") require(components.forall(!_.name.contains('.')), "No components can be a subcomponent.") } @@ -35,7 +40,7 @@ case class GroupAnnotation(components: Seq[ComponentName], newModule: String, ne /* Only keeps components renamed to components */ def update(renames: RenameMap): Seq[Annotation] = { - val newComponents = components.flatMap{c => renames.get(c).getOrElse(Seq(c))}.collect { + val newComponents = components.flatMap { c => renames.get(c).getOrElse(Seq(c)) }.collect { case c: ComponentName => c } Seq(GroupAnnotation(newComponents, newModule, newInstance, outputSuffix, inputSuffix)) @@ -58,7 +63,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { } override def execute(state: CircuitState): CircuitState = { - val groups = state.annotations.collect {case g: GroupAnnotation => g} + val groups = state.annotations.collect { case g: GroupAnnotation => g } val module2group = groups.groupBy(_.currentModule) val mnamespace = Namespace(state.circuit) val newModules = state.circuit.modules.flatMap { @@ -74,13 +79,12 @@ class GroupComponents extends Transform with DependencyAPIMigration { val namespace = Namespace(m) val groupRoots = groups.map(_.components.map(_.name)) val totalSum = groupRoots.map(_.size).sum - val union = groupRoots.foldLeft(Set.empty[String]){(all, set) => all.union(set.toSet)} + val union = groupRoots.foldLeft(Set.empty[String]) { (all, set) => all.union(set.toSet) } - require(groupRoots.forall{_.forall{namespace.contains}}, "All names should be in this module") + require(groupRoots.forall { _.forall { namespace.contains } }, "All names should be in this module") require(totalSum == union.size, "No name can be in more than one group") require(groupRoots.forall(_.nonEmpty), "All groupRoots must by non-empty") - // Order of groups, according to their label. The label is the first root in the group val labelOrder = groups.collect({ case g: GroupAnnotation => g.components.head.name }) @@ -90,8 +94,8 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Group roots, by label // The label "" indicates the original module, and components belonging to that group will remain // in the original module (not get moved into a new module) - val label2group: Map[String, MSet[String]] = groups.collect{ - case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name):_*) + val label2group: Map[String, MSet[String]] = groups.collect { + case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name): _*) }.toMap + ("" -> mutable.Set("")) // Name of new module containing each group, by label @@ -105,7 +109,6 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Build set of components not in set val notSet = label2group.map { case (key, value) => key -> union.diff(value) } - // Get all dependencies between components val deps = getComponentConnectivity(m) @@ -114,13 +117,14 @@ class GroupComponents extends Transform with DependencyAPIMigration { // For each group (by label), add connectivity between nodes in set // Populate reachableNodes with reachability, where blacklist is their notSet - label2group.foreach { case (label, set) => - set.foreach { x => - deps.addPairWithEdge(label, x) - } - deps.reachableFrom(label, notSet(label)) foreach { node => - reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label - } + label2group.foreach { + case (label, set) => + set.foreach { x => + deps.addPairWithEdge(label, x) + } + deps.reachableFrom(label, notSet(label)).foreach { node => + reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label + } } // Unused nodes are not reachable from any group nor the root--add them to root group @@ -129,12 +133,13 @@ class GroupComponents extends Transform with DependencyAPIMigration { } // Add nodes who are reached by a single group, to that group - reachableNodes.foreach { case (node, membership) => - if(membership.size == 1) { - label2group(membership.head) += node - } else { - label2group("") += node - } + reachableNodes.foreach { + case (node, membership) => + if (membership.size == 1) { + label2group(membership.head) += node + } else { + label2group("") += node + } } applyGrouping(m, labelOrder, label2group, label2module, label2instance, label2annotation) @@ -150,19 +155,21 @@ class GroupComponents extends Transform with DependencyAPIMigration { * @param label2annotation annotation specifying the group, by label * @return new modules, including each group's module and the new split module */ - def applyGrouping( m: Module, - labelOrder: Seq[String], - label2group: Map[String, MSet[String]], - label2module: Map[String, String], - label2instance: Map[String, String], - label2annotation: Map[String, GroupAnnotation] - ): Seq[Module] = { + def applyGrouping( + m: Module, + labelOrder: Seq[String], + label2group: Map[String, MSet[String]], + label2module: Map[String, String], + label2instance: Map[String, String], + label2annotation: Map[String, GroupAnnotation] + ): Seq[Module] = { // Maps node to group val byNode = mutable.HashMap[String, String]() - label2group.foreach { case (group, nodes) => - nodes.foreach { node => - byNode(node) = group - } + label2group.foreach { + case (group, nodes) => + nodes.foreach { node => + byNode(node) = group + } } val groupNamespace = label2group.map { case (head, set) => head -> Namespace(set.toSeq) } @@ -180,7 +187,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { val portNames = groupPortNames(group) val suffix = d match { case Output => label2annotation(group).outputSuffix.getOrElse("") - case Input => label2annotation(group).inputSuffix.getOrElse("") + case Input => label2annotation(group).inputSuffix.getOrElse("") } val newName = groupNamespace(group).newName(source + suffix) val portName = portNames.getOrElseUpdate(source, newName) @@ -192,7 +199,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { val portName = addPort(group, exp, Output) val connectStatement = exp.tpe match { case AnalogType(_) => Attach(NoInfo, Seq(WRef(portName), exp)) - case _ => Connect(NoInfo, WRef(portName), exp) + case _ => Connect(NoInfo, WRef(portName), exp) } groupStatements(group) += connectStatement portName @@ -201,7 +208,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Given the sink is in a group, tidy up source references def inGroupFixExps(group: String, added: mutable.ArrayBuffer[Statement])(e: Expression): Expression = e match { case _: Literal => e - case _: DoPrim | _: Mux | _: ValidIf => e map inGroupFixExps(group, added) + case _: DoPrim | _: Mux | _: ValidIf => e.map(inGroupFixExps(group, added)) case otherExp: Expression => val wref = getWRef(otherExp) val source = wref.name @@ -238,10 +245,10 @@ class GroupComponents extends Transform with DependencyAPIMigration { // Given the sink is in the parent module, tidy up source references belonging to groups def inTopFixExps(e: Expression): Expression = e match { - case _: DoPrim | _: Mux | _: ValidIf => e map inTopFixExps + case _: DoPrim | _: Mux | _: ValidIf => e.map(inTopFixExps) case otherExp: Expression => val wref = getWRef(otherExp) - if(byNode(wref.name) != "") { + if (byNode(wref.name) != "") { // Get the name of source's group val otherGroup = byNode(wref.name) @@ -260,7 +267,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { case r: IsDeclaration if byNode(r.name) != "" => val topStmts = mutable.ArrayBuffer[Statement]() val group = byNode(r.name) - groupStatements(group) += r mapExpr inGroupFixExps(group, topStmts) + groupStatements(group) += r.mapExpr(inGroupFixExps(group, topStmts)) Block(topStmts.toSeq) case c: Connect if byNode(getWRef(c.loc).name) != "" => // Sink is in a group @@ -276,20 +283,26 @@ class GroupComponents extends Transform with DependencyAPIMigration { // TODO Attach if all are in a group? case _: IsDeclaration | _: Connect | _: Attach => // Sink is in Top - val ret = s mapExpr inTopFixExps + val ret = s.mapExpr(inTopFixExps) ret - case other => other map onStmt + case other => other.map(onStmt) } } - // Build datastructures - val newTopBody = Block(labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body))) + val newTopBody = Block( + labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body)) + ) val finalTopBody = Block(Utils.squashEmpty(newTopBody).asInstanceOf[Block].stmts.distinct) // For all group labels (not including the original module label), return a new Module. - val newModules = labelOrder.filter(_ != "") map { group => - Module(NoInfo, label2module(group), groupPorts(group).distinct.toSeq, Block(groupStatements(group).distinct.toSeq)) + val newModules = labelOrder.filter(_ != "").map { group => + Module( + NoInfo, + label2module(group), + groupPorts(group).distinct.toSeq, + Block(groupStatements(group).distinct.toSeq) + ) } Seq(m.copy(body = finalTopBody)) ++ newModules } @@ -298,7 +311,7 @@ class GroupComponents extends Transform with DependencyAPIMigration { case w: WRef => w case other => var w = WRef("") - other mapExpr { e => w = getWRef(e); e} + other.mapExpr { e => w = getWRef(e); e } w } @@ -317,25 +330,25 @@ class GroupComponents extends Transform with DependencyAPIMigration { bidirGraph.addPairWithEdge(sink.name, name) bidirGraph.addPairWithEdge(name, sink.name) w - case other => other map onExpr(sink) + case other => other.map(onExpr(sink)) } def onStmt(stmt: Statement): Unit = stmt match { case w: WDefInstance => case h: IsDeclaration => bidirGraph.addVertex(h.name) - h map onExpr(WRef(h.name)) + h.map(onExpr(WRef(h.name))) case Attach(_, exprs) => // Add edge between each expression - exprs.tail map onExpr(getWRef(exprs.head)) + exprs.tail.map(onExpr(getWRef(exprs.head))) case Connect(_, loc, expr) => onExpr(getWRef(loc))(expr) - case q @ Stop(_,_, clk, en) => + case q @ Stop(_, _, clk, en) => val simName = simNamespace.newTemp simulations(simName) = q - Seq(clk, en) map onExpr(WRef(simName)) + Seq(clk, en).map(onExpr(WRef(simName))) case q @ Print(_, _, args, clk, en) => val simName = simNamespace.newTemp simulations(simName) = q - (args :+ clk :+ en) map onExpr(WRef(simName)) + (args :+ clk :+ en).map(onExpr(WRef(simName))) case Block(stmts) => stmts.foreach(onStmt) case ignore @ (_: IsInvalid | EmptyStmt) => // do nothing case other => throw new Exception(s"Unexpected Statement $other") @@ -358,7 +371,7 @@ class GroupAndDedup extends GroupComponents { override def invalidates(a: Transform): Boolean = a match { case _: DedupModules => true - case _ => super.invalidates(a) + case _ => super.invalidates(a) } } diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala index dd073001..376382cc 100644 --- a/src/main/scala/firrtl/transforms/InferResets.scala +++ b/src/main/scala/firrtl/transforms/InferResets.scala @@ -7,9 +7,9 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.annotations.{ReferenceTarget, TargetToken} -import firrtl.Utils.{toTarget, throwInternalError} +import firrtl.Utils.{throwInternalError, toTarget} import firrtl.options.Dependency -import firrtl.passes.{Pass, PassException, InferTypes} +import firrtl.passes.{InferTypes, Pass, PassException} import firrtl.graph.MutableDiGraph import scala.collection.mutable @@ -83,14 +83,13 @@ object InferResets { // Vectors must all have the same type, so we only process Index 0 // If the subtype is an aggregate, there can be multiple of each index val ts = tokens.collect { case (TargetToken.Index(0) +: tail, tpe) => (tail, tpe) } - VectorTree(fromTokens(ts:_*)) + VectorTree(fromTokens(ts: _*)) // BundleTree case (TargetToken.Field(_) +: _, _) +: _ => val fields = - tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n } - .mapValues { ts => - fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }:_*) - }.toMap + tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }.mapValues { ts => + fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }: _*) + }.toMap BundleTree(fields) } } @@ -113,14 +112,16 @@ object InferResets { class InferResets extends Transform with DependencyAPIMigration { override def prerequisites = - Seq( Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows), - Dependency[passes.InferWidths] ) ++ stage.Forms.WorkingIR + Seq( + Dependency(passes.ResolveKinds), + Dependency(passes.InferTypes), + Dependency(passes.ResolveFlows), + Dependency[passes.InferWidths] + ) ++ stage.Forms.WorkingIR override def invalidates(a: Transform): Boolean = a match { case _: checks.CheckResets | passes.CheckTypes => true - case _ => false + case _ => false } import InferResets._ @@ -138,7 +139,7 @@ class InferResets extends Transform with DependencyAPIMigration { val mod = instMap(target.ref) val port = target.component.head match { case TargetToken.Field(name) => name - case bad => Utils.throwInternalError(s"Unexpected token $bad") + case bad => Utils.throwInternalError(s"Unexpected token $bad") } target.copy(module = mod, ref = port, component = target.component.tail) case _ => target @@ -148,17 +149,18 @@ class InferResets extends Transform with DependencyAPIMigration { // Mark driver of a ResetType leaf def markResetDriver(lhs: Expression, rhs: Expression): Unit = { val con = Utils.flow(lhs) match { - case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs)) + case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs)) case SourceFlow if rhs.tpe == ResetType => Some((rhs, lhs)) // If sink is not ResetType, do nothing - case _ => None + case _ => None } - con.foreach { case (loc, exp) => - val driver = exp.tpe match { - case ResetType => TargetDriver(makeTarget(exp)) - case tpe => TypeDriver(tpe, () => makeTarget(exp)) - } - map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver + con.foreach { + case (loc, exp) => + val driver = exp.tpe match { + case ResetType => TargetDriver(makeTarget(exp)) + case tpe => TypeDriver(tpe, () => makeTarget(exp)) + } + map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver } } stmt match { @@ -227,7 +229,7 @@ class InferResets extends Transform with DependencyAPIMigration { private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = { val graph = new MutableDiGraph[Node] val asyncNode = Typ(AsyncResetType) - val syncNode = Typ(Utils.BoolType) + val syncNode = Typ(Utils.BoolType) for ((target, drivers) <- map) { val v = Var(target) drivers.foreach { @@ -247,7 +249,7 @@ class InferResets extends Transform with DependencyAPIMigration { // do the actual inference, the check is simply if syncNode is reachable from asyncNode graph.addPairWithEdge(v, u) case InvalidDriver => - graph.addVertex(v) // Must be in the graph or won't be inferred + graph.addVertex(v) // Must be in the graph or won't be inferred } } val async = graph.reachableFrom(asyncNode) @@ -257,7 +259,7 @@ class InferResets extends Transform with DependencyAPIMigration { case (a, _) if a.contains(syncNode) => throw InferResetsException(graph.path(asyncNode, syncNode)) case (a, s) => (a.view.collect { case Var(t) => t -> asyncNode.tpe } ++ - s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap + s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap } } } @@ -265,34 +267,40 @@ class InferResets extends Transform with DependencyAPIMigration { private def fixupType(tpe: Type, tree: TypeTree): Type = (tpe, tree) match { case (BundleType(fields), BundleTree(map)) => val fieldsx = - fields.map(f => map.get(f.name) match { - case Some(t) => f.copy(tpe = fixupType(f.tpe, t)) - case None => f - }) + fields.map(f => + map.get(f.name) match { + case Some(t) => f.copy(tpe = fixupType(f.tpe, t)) + case None => f + } + ) BundleType(fieldsx) case (VectorType(vtpe, size), VectorTree(t)) => VectorType(fixupType(vtpe, t), size) case (_, GroundTree(t)) => t - case x => throw new Exception(s"Error! Unexpected pair $x") + case x => throw new Exception(s"Error! Unexpected pair $x") } // Assumes all ReferenceTargets are in the same module private def makeDeclMap(map: Map[ReferenceTarget, Type]): Map[String, TypeTree] = - map.groupBy(_._1.ref).mapValues { ts => - TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }:_*) - }.toMap + map + .groupBy(_._1.ref) + .mapValues { ts => + TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }: _*) + } + .toMap private def implPort(map: Map[String, TypeTree])(port: Port): Port = - map.get(port.name) - .map(tree => port.copy(tpe = fixupType(port.tpe, tree))) - .getOrElse(port) + map + .get(port.name) + .map(tree => port.copy(tpe = fixupType(port.tpe, tree))) + .getOrElse(port) private def implStmt(map: Map[String, TypeTree])(stmt: Statement): Statement = stmt.map(implStmt(map)) match { case decl: IsDeclaration if map.contains(decl.name) => val tree = map(decl.name) decl match { - case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree)) - case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree)) + case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree)) + case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree)) // TODO Can this really happen? case mem: DefMemory => mem.copy(dataType = fixupType(mem.dataType, tree)) case other => other @@ -303,10 +311,13 @@ class InferResets extends Transform with DependencyAPIMigration { private def implement(c: Circuit, map: Map[ReferenceTarget, Type]): Circuit = { val modMaps = map.groupBy(_._1.module) def onMod(mod: DefModule): DefModule = { - modMaps.get(mod.name).map { tmap => - val declMap = makeDeclMap(tmap) - mod.map(implPort(declMap)).map(implStmt(declMap)) - }.getOrElse(mod) + modMaps + .get(mod.name) + .map { tmap => + val declMap = makeDeclMap(tmap) + mod.map(implPort(declMap)).map(implStmt(declMap)) + } + .getOrElse(mod) } c.map(onMod) } diff --git a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala index 515bf407..100b598f 100644 --- a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala +++ b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala @@ -6,7 +6,7 @@ package transforms import firrtl.ir._ import firrtl.Mappers._ import firrtl.options.Dependency -import firrtl.PrimOps.{Bits, Head, Tail, Shr} +import firrtl.PrimOps.{Bits, Head, Shr, Tail} import firrtl.Utils.{isBitExtract, isTemp} import firrtl.WrappedExpression._ @@ -19,8 +19,8 @@ object InlineBitExtractionsTransform { // Note that this can have false negatives but MUST NOT have false positives. private def isSimpleExpr(expr: Expression): Boolean = expr match { case _: WRef | _: Literal | _: WSubField => true - case DoPrim(op, args, _,_) if isBitExtract(op) => args.forall(isSimpleExpr) - case _ => false + case DoPrim(op, args, _, _) if isBitExtract(op) => args.forall(isSimpleExpr) + case _ => false } // replace Head/Tail/Shr with Bits for easier back-to-back Bits Extractions @@ -28,12 +28,12 @@ object InlineBitExtractionsTransform { case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) => val msb = bitWidth(rhs.head.tpe) - 1 val lsb = bitWidth(rhs.head.tpe) - c.head - DoPrim(Bits, rhs, Seq(msb,lsb), tpe) + DoPrim(Bits, rhs, Seq(msb, lsb), tpe) case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) => val msb = bitWidth(rhs.head.tpe) - c.head - 1 - DoPrim(Bits, rhs, Seq(msb,0), tpe) + DoPrim(Bits, rhs, Seq(msb, 0), tpe) case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) => - DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe)-1, c.head), tpe) + DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe) - 1, c.head), tpe) case _ => expr // Not a candidate } @@ -49,26 +49,28 @@ object InlineBitExtractionsTransform { */ def onExpr(netlist: Netlist)(expr: Expression): Expression = { expr.map(onExpr(netlist)) match { - case e @ WRef(name, _,_,_) => - netlist.get(we(e)) - .filter(isBitExtract) - .getOrElse(e) + case e @ WRef(name, _, _, _) => + netlist + .get(we(e)) + .filter(isBitExtract) + .getOrElse(e) // replace back-to-back Bits Extractions case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) => ival.head match { case of @ DoPrim(rop, rhs, rc, rtpe) if isSimpleExpr(of) => (lop, rop) match { - case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head min rc.head), ltpe) + case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head.min(rc.head)), ltpe) case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe) - case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe) - case (_,_) => (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match { - case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) => - DoPrim(Bits, rhs, Seq(lmsb+rlsb,llsb+rlsb), ltpe) - case (_,_) => lhs // Not a candidate - } + case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe) + case (_, _) => + (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match { + case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) => + DoPrim(Bits, rhs, Seq(lmsb + rlsb, llsb + rlsb), ltpe) + case (_, _) => lhs // Not a candidate + } } - case _ => lhs // Not a candidate - } + case _ => lhs // Not a candidate + } case other => other // Not a candidate } } @@ -97,9 +99,11 @@ object InlineBitExtractionsTransform { class InlineBitExtractionsTransform extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala index 3dac938e..0efc0727 100644 --- a/src/main/scala/firrtl/transforms/InlineCasts.scala +++ b/src/main/scala/firrtl/transforms/InlineCasts.scala @@ -8,7 +8,7 @@ import firrtl.Mappers._ import firrtl.PrimOps.Pad import firrtl.options.Dependency -import firrtl.Utils.{isCast, isBitExtract, NodeMap} +import firrtl.Utils.{isBitExtract, isCast, NodeMap} object InlineCastsTransform { @@ -17,8 +17,8 @@ object InlineCastsTransform { // Note that this can have false negatives but MUST NOT have false positives private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match { case _: WRef | _: Literal | _: WSubField => castSeen - case DoPrim(op, args, _,_) if isCast(op) => args.forall(isSimpleCast(true)) - case _ => false + case DoPrim(op, args, _, _) if isCast(op) => args.forall(isSimpleCast(true)) + case _ => false } /** Recursively replace [[WRef]]s with new [[firrtl.ir.Expression Expression]]s @@ -31,17 +31,20 @@ object InlineCastsTransform { def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match { // Anything that may generate a part-select should not be inlined! case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr - case e => e.map(onExpr(replace)) match { - case e @ WRef(name, _,_,_) => - replace.get(name) - .filter(isSimpleCast(castSeen=false)) - .getOrElse(e) - case e @ DoPrim(op, Seq(WRef(name, _,_,_)), _,_) if isCast(op) => - replace.get(name) - .map(value => e.copy(args = Seq(value))) - .getOrElse(e) - case other => other // Not a candidate - } + case e => + e.map(onExpr(replace)) match { + case e @ WRef(name, _, _, _) => + replace + .get(name) + .filter(isSimpleCast(castSeen = false)) + .getOrElse(e) + case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) => + replace + .get(name) + .map(value => e.copy(args = Seq(value))) + .getOrElse(e) + case other => other // Not a candidate + } } /** Inline casts in a Statement @@ -69,11 +72,13 @@ object InlineCastsTransform { class InlineCastsTransform extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[PropagatePresetAnnotations] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[PropagatePresetAnnotations] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/LegalizeClocks.scala b/src/main/scala/firrtl/transforms/LegalizeClocks.scala index f439fdc9..248775d9 100644 --- a/src/main/scala/firrtl/transforms/LegalizeClocks.scala +++ b/src/main/scala/firrtl/transforms/LegalizeClocks.scala @@ -18,8 +18,8 @@ object LegalizeClocksTransform { // Currently only looks for literals nested within casts private def illegalClockExpr(expr: Expression): Boolean = expr match { case _: Literal => true - case DoPrim(op, args, _,_) if isCast(op) => args.exists(illegalClockExpr) - case _ => false + case DoPrim(op, args, _, _) if isCast(op) => args.exists(illegalClockExpr) + case _ => false } /** Legalize Clocks in a Statement @@ -66,11 +66,13 @@ object LegalizeClocksTransform { class LegalizeClocksTransform extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform] ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/LegalizeReductions.scala b/src/main/scala/firrtl/transforms/LegalizeReductions.scala index 2e60aae7..33a10349 100644 --- a/src/main/scala/firrtl/transforms/LegalizeReductions.scala +++ b/src/main/scala/firrtl/transforms/LegalizeReductions.scala @@ -6,17 +6,16 @@ import firrtl.Mappers._ import firrtl.options.Dependency import firrtl.Utils.BoolType - object LegalizeAndReductionsTransform { private def allOnesOfType(tpe: Type): Literal = tpe match { case UIntType(width @ IntWidth(x)) => UIntLiteral((BigInt(1) << x.toInt) - 1, width) - case SIntType(width) => SIntLiteral(-1, width) + case SIntType(width) => SIntLiteral(-1, width) } def onExpr(expr: Expression): Expression = expr.map(onExpr) match { - case DoPrim(PrimOps.Andr, Seq(arg), _,_) if bitWidth(arg.tpe) > 64 => + case DoPrim(PrimOps.Andr, Seq(arg), _, _) if bitWidth(arg.tpe) > 64 => DoPrim(PrimOps.Eq, Seq(arg, allOnesOfType(arg.tpe)), Seq(), BoolType) case other => other } @@ -35,8 +34,7 @@ class LegalizeAndReductionsTransform extends Transform with DependencyAPIMigrati override def prerequisites = firrtl.stage.Forms.WorkingIR ++ - Seq( Dependency(passes.CheckTypes), - Dependency(passes.CheckWidths)) + Seq(Dependency(passes.CheckTypes), Dependency(passes.CheckWidths)) override def optionalPrerequisites = Nil diff --git a/src/main/scala/firrtl/transforms/ManipulateNames.scala b/src/main/scala/firrtl/transforms/ManipulateNames.scala index f15b546f..d0b12e66 100644 --- a/src/main/scala/firrtl/transforms/ManipulateNames.scala +++ b/src/main/scala/firrtl/transforms/ManipulateNames.scala @@ -57,8 +57,9 @@ sealed trait ManipulateNamesListAnnotation[A <: ManipulateNames[_]] extends Mult * @note $noteLocalTargets */ case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]]( - targets: Seq[Seq[Target]], - transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] { + targets: Seq[Seq[Target]], + transform: Dependency[A]) + extends ManipulateNamesListAnnotation[A] { override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a) @@ -77,8 +78,9 @@ case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]]( * @note $noteLocalTargets */ case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]]( - targets: Seq[Seq[Target]], - transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] { + targets: Seq[Seq[Target]], + transform: Dependency[A]) + extends ManipulateNamesListAnnotation[A] { override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a) @@ -94,19 +96,21 @@ case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]]( * @param oldTargets the old targets */ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]]( - targets: Seq[Seq[Target]], - transform: Dependency[A], - oldTargets: Seq[Seq[Target]]) extends MultiTargetAnnotation { + targets: Seq[Seq[Target]], + transform: Dependency[A], + oldTargets: Seq[Seq[Target]]) + extends MultiTargetAnnotation { override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a) override def update(renames: RenameMap) = { val (targetsx, oldTargetsx) = targets.zip(oldTargets).foldLeft((Seq.empty[Seq[Target]], Seq.empty[Seq[Target]])) { - case ((accT, accO), (t, o)) => t.flatMap(renames(_)) match { - /* If the target was deleted, delete the old target */ - case tx if tx.isEmpty => (accT, accO) - case tx => (Seq(tx) ++ accT, Seq(o) ++ accO) - } + case ((accT, accO), (t, o)) => + t.flatMap(renames(_)) match { + /* If the target was deleted, delete the old target */ + case tx if tx.isEmpty => (accT, accO) + case tx => (Seq(tx) ++ accT, Seq(o) ++ accO) + } } targetsx match { /* If all targets were deleted, delete the annotation */ @@ -117,9 +121,13 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]]( /** Return [[firrtl.RenameMap RenameMap]] from old targets to new targets */ def toRenameMap: RenameMap = { - val m = oldTargets.zip(targets).flatMap { - case (a, b) => a.map(_ -> b) - }.toMap.asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]] + val m = oldTargets + .zip(targets) + .flatMap { + case (a, b) => a.map(_ -> b) + } + .toMap + .asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]] RenameMap.create(m) } @@ -132,25 +140,28 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]]( * @param allow a function that returns true if a [[firrtl.annotations.Target Target]] should be renamed */ private class RenameDataStructure( - circuit: ir.Circuit, + circuit: ir.Circuit, val renames: RenameMap, - val block: Target => Boolean, - val allow: Target => Boolean) { + val block: Target => Boolean, + val allow: Target => Boolean) { /** A mapping of targets to associated namespaces */ val namespaces: mutable.HashMap[CompleteTarget, Namespace] = mutable.HashMap(CircuitTarget(circuit.main) -> Namespace(circuit)) - /** Wraps a HashMap to provide better error messages when accessing a non-existing element */ + /** Wraps a HashMap to provide better error messages when accessing a non-existing element */ class InstanceHashMap { type Key = ReferenceTarget type Value = Either[ReferenceTarget, InstanceTarget] private val m = mutable.HashMap[Key, Value]() - def apply(key: ReferenceTarget): Value = m.getOrElse(key, { - throw new FirrtlUserException( - s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems. - | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin) - }) + def apply(key: ReferenceTarget): Value = m.getOrElse( + key, { + throw new FirrtlUserException( + s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems. + | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin + ) + } + ) def update(key: Key, value: Value): Unit = m.update(key, value) } @@ -165,17 +176,17 @@ private class RenameDataStructure( /** Transform for manipulate all the names in a FIRRTL circuit. * @tparam A the type of the child transform */ -abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Transform with DependencyAPIMigration { +abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transform with DependencyAPIMigration { /** A function used to manipulate a name in a FIRRTL circuit */ def manipulate: (String, Namespace) => Option[String] - override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes)) - override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty + override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes)) + override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisiteOf: Seq[TransformDependency] = Forms.LowEmitters override def invalidates(a: Transform) = a match { case _: analyses.GetNamespace => true - case _ => false + case _ => false } /** Compute a new name for some target and record the rename if the new name differs. If the top module or the circuit @@ -192,27 +203,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans case a if r.skip(a) => (name, None) /* Circuit renaming */ - case a@ CircuitTarget(b) => manipulate(b, r.namespaces(a)) match { - case Some(str) => (str, Some(a.copy(circuit = str))) - case None => (b, None) - } + case a @ CircuitTarget(b) => + manipulate(b, r.namespaces(a)) match { + case Some(str) => (str, Some(a.copy(circuit = str))) + case None => (b, None) + } /* Module renaming for non-top modules */ - case a@ ModuleTarget(_, b) => manipulate(b, r.namespaces(a.circuitTarget)) match { - case Some(str) => (str, Some(a.copy(module = str))) - case None => (b, None) - } + case a @ ModuleTarget(_, b) => + manipulate(b, r.namespaces(a.circuitTarget)) match { + case Some(str) => (str, Some(a.copy(module = str))) + case None => (b, None) + } /* Instance renaming */ - case a@ InstanceTarget(_, _, Nil, b, c) => manipulate(b, r.namespaces(a.moduleTarget)) match { - case Some(str) => (str, Some(a.copy(instance = str))) - case None => (b, None) - } + case a @ InstanceTarget(_, _, Nil, b, c) => + manipulate(b, r.namespaces(a.moduleTarget)) match { + case Some(str) => (str, Some(a.copy(instance = str))) + case None => (b, None) + } /* Rename either a module component or a memory */ - case a@ ReferenceTarget(_, _, _, b, Nil) => manipulate(b, r.namespaces(a.moduleTarget)) match { - case Some(str) => (str, Some(a.copy(ref = str))) - case None => (b, None) - } + case a @ ReferenceTarget(_, _, _, b, Nil) => + manipulate(b, r.namespaces(a.moduleTarget)) match { + case Some(str) => (str, Some(a.copy(ref = str))) + case None => (b, None) + } /* Rename an instance port or a memory reader/writer/readwriter */ - case a@ ReferenceTarget(_, _, _, b, (token@ TargetToken.Field(c)) :: Nil) => + case a @ ReferenceTarget(_, _, _, b, (token @ TargetToken.Field(c)) :: Nil) => val ref = r.instanceMap(a.moduleTarget.ref(b)) match { case Right(inst) => inst.ofModuleTarget case Left(mem) => mem @@ -224,8 +239,8 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans } /* Record the optional rename. If the circuit was renamed, also rename the top module. If the top module was * renamed, also rename the circuit. */ - ax.foreach( - axx => target match { + ax.foreach(axx => + target match { case c: CircuitTarget => r.renames.rename(target, r.renames(axx)) r.renames.rename(c.module(c.circuit), CircuitTarget(namex).module(namex)) @@ -252,21 +267,26 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans r.renames.underlying.get(t) match { case Some(ax) if ax.size == 1 => ax match { - case Seq(foo: CircuitTarget) => foo.name - case Seq(foo: ModuleTarget) => foo.module - case Seq(foo: InstanceTarget) => foo.instance - case Seq(foo: ReferenceTarget) => foo.tokens.last match { - case TargetToken.Ref(value) => value - case TargetToken.Field(value) => value - case _ => Utils.throwInternalError( - s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field' + case Seq(foo: CircuitTarget) => foo.name + case Seq(foo: ModuleTarget) => foo.module + case Seq(foo: InstanceTarget) => foo.instance + case Seq(foo: ReferenceTarget) => + foo.tokens.last match { + case TargetToken.Ref(value) => value + case TargetToken.Field(value) => value + case _ => + Utils.throwInternalError( + s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field' | This is indicative of a circuit that has not been run through LowerTypes.""", - Some(new MatchError(foo.serialize))) - } + Some(new MatchError(foo.serialize)) + ) + } } - case s@ Some(ax) => Utils.throwInternalError( - s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""", - Some(new MatchError(s))) + case s @ Some(ax) => + Utils.throwInternalError( + s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""", + Some(new MatchError(s)) + ) case None => name } @@ -280,27 +300,34 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans /* A reference to something inside this module */ case w: WRef => w.copy(name = maybeRename(w.name, r, Target.asTarget(t)(w))) /* This is either the subfield of an instance or a subfield of a memory reader/writer/readwriter */ - case w@ WSubField(expr, ref, _, _) => expr match { - /* This is an instance */ - case we@ WRef(inst, _, _, _) => - val tx = Target.asTarget(t)(we) - val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match { - case Right(a) => (a.ofModuleTarget.ref(ref), a) - case a@ Left(ref) => throw new FirrtlUserException( - s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'. - | This is indicative of a circuit that has not been run through LowerTypes.""", new MatchError(a)) - } - w.copy(we.copy(name=maybeRename(inst, r, iTarget)), name=maybeRename(ref, r, rTarget)) - /* This is a reader/writer/readwriter */ - case ws@ WSubField(expr, port, _, _) => expr match { - /* This is the memory. */ - case wr@ WRef(mem, _, _, _) => - w.copy( - expr=ws.copy( - expr=wr.copy(name=maybeRename(mem, r, t.ref(mem))), - name=maybeRename(port, r, t.ref(mem).field(port)))) + case w @ WSubField(expr, ref, _, _) => + expr match { + /* This is an instance */ + case we @ WRef(inst, _, _, _) => + val tx = Target.asTarget(t)(we) + val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match { + case Right(a) => (a.ofModuleTarget.ref(ref), a) + case a @ Left(ref) => + throw new FirrtlUserException( + s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'. + | This is indicative of a circuit that has not been run through LowerTypes.""", + new MatchError(a) + ) + } + w.copy(we.copy(name = maybeRename(inst, r, iTarget)), name = maybeRename(ref, r, rTarget)) + /* This is a reader/writer/readwriter */ + case ws @ WSubField(expr, port, _, _) => + expr match { + /* This is the memory. */ + case wr @ WRef(mem, _, _, _) => + w.copy( + expr = ws.copy( + expr = wr.copy(name = maybeRename(mem, r, t.ref(mem))), + name = maybeRename(port, r, t.ref(mem).field(port)) + ) + ) + } } - } case e => e.map(onExpression(_: ir.Expression, r, t)) } @@ -310,30 +337,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans * and readwriters. */ private def onStatement(s: ir.Statement, r: RenameDataStructure, t: ModuleTarget): ir.Statement = s match { - case decl: ir.IsDeclaration => decl match { - case decl@ WDefInstance(_, inst, mod, _) => - val modx = maybeRename(mod, r, t.circuitTarget.module(mod)) - val instx = doRename(inst, r, t.instOf(inst, mod)) - r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod)) - decl.copy(name = instx, module = modx) - case decl: ir.DefMemory => - val namex = doRename(decl.name, r, t.ref(decl.name)) - val tx = t.ref(decl.name) - r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters) - r.instanceMap(tx) = Left(tx) - decl - .copy( - name = namex, - readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))), - writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))), - readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw))) - ) - .map(onExpression(_: ir.Expression, r, t)) - case decl => - decl - .map(doRename(_: String, r, t.ref(decl.name))) - .map(onExpression(_: ir.Expression, r, t)) - } + case decl: ir.IsDeclaration => + decl match { + case decl @ WDefInstance(_, inst, mod, _) => + val modx = maybeRename(mod, r, t.circuitTarget.module(mod)) + val instx = doRename(inst, r, t.instOf(inst, mod)) + r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod)) + decl.copy(name = instx, module = modx) + case decl: ir.DefMemory => + val namex = doRename(decl.name, r, t.ref(decl.name)) + val tx = t.ref(decl.name) + r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters) + r.instanceMap(tx) = Left(tx) + decl + .copy( + name = namex, + readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))), + writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))), + readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw))) + ) + .map(onExpression(_: ir.Expression, r, t)) + case decl => + decl + .map(doRename(_: String, r, t.ref(decl.name))) + .map(onExpression(_: ir.Expression, r, t)) + } case s => s .map(onStatement(_: ir.Statement, r, t)) @@ -362,7 +390,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans */ val onName: String => String = t.circuit match { case `main` => maybeRename(_, r, moduleTarget) - case _ => doRename(_, r, moduleTarget) + case _ => doRename(_, r, moduleTarget) } m @@ -380,11 +408,11 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans * @return the circuit with manipulated names */ def run( - c: ir.Circuit, + c: ir.Circuit, renames: RenameMap, - block: Target => Boolean, - allow: Target => Boolean) - : ir.Circuit = { + block: Target => Boolean, + allow: Target => Boolean + ): ir.Circuit = { val t = CircuitTarget(c.main) /* If the circuit is a skip, return the original circuit. Otherwise, walk all the modules and rename them. Rename the @@ -427,8 +455,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans .toMap /* Replace the old modules making sure that they are still in the same order */ - c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))), - main = mainx) + c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))), main = mainx) } } @@ -436,18 +463,20 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans def execute(state: CircuitState): CircuitState = { val block = state.annotations.collect { - case ManipulateNamesBlocklistAnnotation(targetSeq, t) => t.getObject match { - case _: A => targetSeq - case _ => Nil - } + case ManipulateNamesBlocklistAnnotation(targetSeq, t) => + t.getObject match { + case _: A => targetSeq + case _ => Nil + } }.flatten.flatten.toSet val allow = { val allowx = state.annotations.collect { - case ManipulateNamesAllowlistAnnotation(targetSeq, t) => t.getObject match { - case _: A => targetSeq - case _ => Nil - } + case ManipulateNamesAllowlistAnnotation(targetSeq, t) => + t.getObject match { + case _: A => targetSeq + case _ => Nil + } }.flatten.flatten allowx match { @@ -461,17 +490,19 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans val annotationsx = state.annotations.flatMap { /* Consume blocklist annotations */ - case foo@ ManipulateNamesBlocklistAnnotation(_, t) => t.getObject match { - case _: A => None - case _ => Some(foo) - } + case foo @ ManipulateNamesBlocklistAnnotation(_, t) => + t.getObject match { + case _: A => None + case _ => Some(foo) + } /* Convert allowlist annotations to result annotations */ - case foo@ ManipulateNamesAllowlistAnnotation(a, t) => + case foo @ ManipulateNamesAllowlistAnnotation(a, t) => t.getObject match { - case _: A => (a, a.map(_.map(renames(_)).flatten)) match { - case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a)) - } - case _ => Some(foo) + case _: A => + (a, a.map(_.map(renames(_)).flatten)) match { + case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a)) + } + case _ => Some(foo) } case a => Some(a) } diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala index ff44afec..5532d0f0 100644 --- a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala +++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala @@ -1,4 +1,3 @@ - package firrtl package transforms @@ -34,17 +33,19 @@ trait DontTouchAllTargets extends HasDontTouches { self: Annotation => * DCE treats the component as a top-level sink of the circuit */ case class DontTouchAnnotation(target: ReferenceTarget) - extends SingleTargetAnnotation[ReferenceTarget] with DontTouchAllTargets { + extends SingleTargetAnnotation[ReferenceTarget] + with DontTouchAllTargets { def targets = Seq(target) def duplicate(n: ReferenceTarget) = this.copy(n) } object DontTouchAnnotation { - class DontTouchNotFoundException(module: String, component: String) extends PassException( - s"""|Target marked dontTouch ($module.$component) not found! - |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then - |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin - ) + class DontTouchNotFoundException(module: String, component: String) + extends PassException( + s"""|Target marked dontTouch ($module.$component) not found! + |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then + |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin + ) def errorNotFound(module: String, component: String) = throw new DontTouchNotFoundException(module, component) @@ -58,7 +59,6 @@ object DontTouchAnnotation { * * @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted */ -case class OptimizableExtModuleAnnotation(target: ModuleName) extends - SingleTargetAnnotation[ModuleName] { +case class OptimizableExtModuleAnnotation(target: ModuleName) extends SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(n) } diff --git a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala index da803837..97db0219 100644 --- a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala +++ b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala @@ -11,8 +11,10 @@ import firrtl.options.Dependency import scala.collection.mutable object PropagatePresetAnnotations { - val advice = "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset." - case class TreeCleanUpOrphanException(message: String) extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice") + val advice = + "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset." + case class TreeCleanUpOrphanException(message: String) + extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice") } /** Propagate PresetAnnotations to all children of targeted AsyncResets @@ -39,9 +41,11 @@ object PropagatePresetAnnotations { class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic]) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic] + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized @@ -52,7 +56,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { import PropagatePresetAnnotations._ private type TargetSet = mutable.HashSet[ReferenceTarget] - private type TargetMap = mutable.HashMap[ReferenceTarget,String] + private type TargetMap = mutable.HashMap[ReferenceTarget, String] private type TargetSetMap = mutable.HashMap[ReferenceTarget, TargetSet] private val toCleanUp = new TargetSet() @@ -71,7 +75,11 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { * @param presetAnnos all the annotations * @return updated annotations */ - private def propagate(cs: CircuitState, presetAnnos: Seq[PresetAnnotation], otherAnnos: Seq[Annotation]): AnnotationSeq = { + private def propagate( + cs: CircuitState, + presetAnnos: Seq[PresetAnnotation], + otherAnnos: Seq[Annotation] + ): AnnotationSeq = { val presets = presetAnnos.groupBy(_.target) // store all annotated asyncreset references val asyncToAnnotate = new TargetSet() @@ -85,34 +93,34 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { val circuitTarget = CircuitTarget(cs.circuit.main) /* - * WALK I PHASE 1 FUNCTIONS - */ + * WALK I PHASE 1 FUNCTIONS + */ /* Walk current module - * - process ports - * - store connections & entry points for PHASE 2 - * - process statements - * - Instances => record local instances for cross module AsyncReset Tree Buidling - * - Registers => store AsyncReset bound registers for PHASE 2 - * - Wire => store AsyncReset Connections & entry points for PHASE 2 - * - Connect => store AsyncReset Connections & entry points for PHASE 2 - * - * @param m module - */ + * - process ports + * - store connections & entry points for PHASE 2 + * - process statements + * - Instances => record local instances for cross module AsyncReset Tree Buidling + * - Registers => store AsyncReset bound registers for PHASE 2 + * - Wire => store AsyncReset Connections & entry points for PHASE 2 + * - Connect => store AsyncReset Connections & entry points for PHASE 2 + * + * @param m module + */ def processModule(m: DefModule): Unit = { val moduleTarget = circuitTarget.module(m.name) val localInstances = new TargetMap() /* Recursively process a given type - * Recursive on Bundle and Vector Type only - * Store Register and Connections for AsyncResetType - * @param tpe [[Type]] to be processed - * @param target [[ReferenceTarget]] associated to the tpe - * @param all Boolean indicating whether all subelements of the current - * tpe should also be stored as Annotated AsyncReset entry points - */ + * Recursive on Bundle and Vector Type only + * Store Register and Connections for AsyncResetType + * @param tpe [[Type]] to be processed + * @param target [[ReferenceTarget]] associated to the tpe + * @param all Boolean indicating whether all subelements of the current + * tpe should also be stored as Annotated AsyncReset entry points + */ def processType(tpe: Type, target: ReferenceTarget, all: Boolean): Unit = { - if(tpe == AsyncResetType){ + if (tpe == AsyncResetType) { asyncRegMap(target) = new TargetSet() asyncCoMap(target) = new TargetSet() if (presets.contains(target) || all) { @@ -121,14 +129,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } else { tpe match { case b: BundleType => - b.fields.foreach{ - (x: Field) => - val tar = target.field(x.name) - processType(x.tpe, tar, (presets.contains(tar) || all)) + b.fields.foreach { (x: Field) => + val tar = target.field(x.name) + processType(x.tpe, tar, (presets.contains(tar) || all)) } case v: VectorType => - for(i <- 0 until v.size) { + for (i <- 0 until v.size) { val tar = target.index(i) processType(v.tpe, tar, (presets.contains(tar) || all)) } @@ -143,19 +150,19 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } /* Recursively search for the ReferenceTarget of a given Expression - * @param e Targeted Expression - * @param ta Local ReferenceTarget of the Targeted Expression - * @return a ReferenceTarget in case of success, a GenericTarget otherwise - * @throws [[InternalError]] on unexpected recursive path return results - */ - def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false) : Target = { + * @param e Targeted Expression + * @param ta Local ReferenceTarget of the Targeted Expression + * @return a ReferenceTarget in case of success, a GenericTarget otherwise + * @throws [[InternalError]] on unexpected recursive path return results + */ + def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false): Target = { e match { case w: WRef => moduleTarget.ref(w.name) case w: WSubField => getRef(w.expr, ta, annoCo) match { case rt: ReferenceTarget => - if(localInstances.contains(rt)){ - val remote_ref = circuitTarget.module(localInstances(rt)) + if (localInstances.contains(rt)) { + val remote_ref = circuitTarget.module(localInstances(rt)) if (annoCo) asyncCoMap(ta) += rt.field(w.name) remote_ref.ref(w.name) @@ -163,7 +170,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { rt.field(w.name) } case remote_target => remote_target - } + } case w: WSubIndex => getRef(w.expr, ta, annoCo) match { case remote_target: ReferenceTarget => @@ -179,7 +186,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def processRegister(r: DefRegister): Unit = { getRef(r.reset, moduleTarget.ref(r.name), false) match { - case rt : ReferenceTarget => + case rt: ReferenceTarget => if (asyncRegMap.contains(rt)) { asyncRegMap(rt) += moduleTarget.ref(r.name) } @@ -189,12 +196,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } def processConnect(c: Connect): Unit = { - getRef(c.expr, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match { + getRef(c.expr, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match { case rhs: ReferenceTarget => if (presets.contains(rhs) || asyncRegMap.contains(rhs)) { getRef(c.loc, rhs, true) match { - case lhs : ReferenceTarget => - if(asyncRegMap.contains(rhs)){ + case lhs: ReferenceTarget => + if (asyncRegMap.contains(rhs)) { asyncRegMap(rhs) += lhs } else { asyncToAnnotate += lhs @@ -211,10 +218,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { val target = moduleTarget.ref(n.name) processType(n.value.tpe, target, presets.contains(target)) - getRef(n.value, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match { + getRef(n.value, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match { case rhs: ReferenceTarget => if (presets.contains(rhs) || asyncRegMap.contains(rhs)) { - if(asyncRegMap.contains(rhs)){ + if (asyncRegMap.contains(rhs)) { asyncRegMap(rhs) += target } else { asyncToAnnotate += target @@ -227,18 +234,18 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def processStatements(statement: Statement): Unit = { statement match { - case i : WDefInstance => + case i: WDefInstance => localInstances(moduleTarget.ref(i.name)) = i.module - case r : DefRegister => processRegister(r) - case w : DefWire => processWire(w) - case n : DefNode => processNode(n) - case c : Connect => processConnect(c) - case s => s.foreachStmt(processStatements) + case r: DefRegister => processRegister(r) + case w: DefWire => processWire(w) + case n: DefNode => processNode(n) + case c: Connect => processConnect(c) + case s => s.foreachStmt(processStatements) } } def processPorts(port: Port): Unit = { - if(port.tpe == AsyncResetType){ + if (port.tpe == AsyncResetType) { val target = moduleTarget.ref(port.name) asyncRegMap(target) = new TargetSet() asyncCoMap(target) = new TargetSet() @@ -263,17 +270,17 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { /** Annotate a given target and all its children according to the asyncCoMap */ def annotateCo(ta: ReferenceTarget): Unit = { - if (asyncCoMap.contains(ta)){ + if (asyncCoMap.contains(ta)) { toCleanUp += ta - asyncCoMap(ta) foreach( (t: ReferenceTarget) => { + asyncCoMap(ta).foreach((t: ReferenceTarget) => { toCleanUp += t }) } } /** Annotate all registers somehow connected to the orignal annotated async reset */ - def annotateRegSet(set: TargetSet) : Unit = { - set foreach ( (ta: ReferenceTarget) => { + def annotateRegSet(set: TargetSet): Unit = { + set.foreach((ta: ReferenceTarget) => { annotateCo(ta) if (asyncRegMap.contains(ta)) { annotateRegSet(asyncRegMap(ta)) @@ -287,8 +294,8 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { * Walk AsyncReset Trees with all Annotated AsyncReset as entry points * Annotate all leaf registers and intermediate wires, nodes, connectors along the way */ - def annotateAsyncSet(set: TargetSet) : Unit = { - set foreach ((t: ReferenceTarget) => { + def annotateAsyncSet(set: TargetSet): Unit = { + set.foreach((t: ReferenceTarget) => { annotateCo(t) if (asyncRegMap.contains(t)) annotateRegSet(asyncRegMap(t)) @@ -300,7 +307,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { */ cs.circuit.foreachModule(processModule) // PHASE 1 : Initialize - annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate + annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate otherAnnos ++ newAnnos } @@ -312,21 +319,21 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { * Clean-up useless reset tree (not relying on DCE) * Disconnect preset registers from their reset tree */ - private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq) : Circuit = { - val presetRegs = annos.collect {case a : PresetRegAnnotation => a}.groupBy(_.target) + private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq): Circuit = { + val presetRegs = annos.collect { case a: PresetRegAnnotation => a }.groupBy(_.target) val circuitTarget = CircuitTarget(circuit.main) def processModule(m: DefModule): DefModule = { val moduleTarget = circuitTarget.module(m.name) val localInstances = new TargetMap() - def getRef(e: Expression) : Target = { + def getRef(e: Expression): Target = { e match { case w: WRef => moduleTarget.ref(w.name) case w: WSubField => getRef(w.expr) match { case rt: ReferenceTarget => - if(localInstances.contains(rt)){ + if (localInstances.contains(rt)) { circuitTarget.module(localInstances(rt)).ref(w.name) } else { rt.field(w.name) @@ -341,14 +348,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { case DoPrim(op, args, _, _) => op match { case AsInterval | AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset => getRef(args.head) - case _ => Target(None, None, Seq.empty) + case _ => Target(None, None, Seq.empty) } case _ => Target(None, None, Seq.empty) } } - - def processRegister(r: DefRegister) : DefRegister = { + def processRegister(r: DefRegister): DefRegister = { if (presetRegs.contains(moduleTarget.ref(r.name))) { r.copy(reset = UIntLiteral(0)) } else { @@ -356,7 +362,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } } - def processWire(w: DefWire) : Statement = { + def processWire(w: DefWire): Statement = { if (toCleanUp.contains(moduleTarget.ref(w.name))) { EmptyStmt } else { @@ -364,12 +370,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } } - def processNode(n: DefNode) : Statement = { + def processNode(n: DefNode): Statement = { if (toCleanUp.contains(moduleTarget.ref(n.name))) { EmptyStmt } else { getRef(n.value) match { - case rt : ReferenceTarget if(toCleanUp.contains(rt)) => + case rt: ReferenceTarget if (toCleanUp.contains(rt)) => throw TreeCleanUpOrphanException(s"Orphan (${moduleTarget.ref(n.name)}) the way.") case _ => n } @@ -380,7 +386,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { getRef(c.expr) match { case rhs: ReferenceTarget if (toCleanUp.contains(rhs)) => getRef(c.loc) match { - case lhs : ReferenceTarget if(!toCleanUp.contains(lhs)) => + case lhs: ReferenceTarget if (!toCleanUp.contains(lhs)) => throw TreeCleanUpOrphanException(s"Orphan ${lhs} connected deleted node $rhs.") case _ => EmptyStmt } @@ -388,7 +394,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { } } - def processInstance(i: WDefInstance) : WDefInstance = { + def processInstance(i: WDefInstance): WDefInstance = { localInstances(moduleTarget.ref(i.name)) = i.module val tpe = i.tpe match { case b: BundleType => @@ -401,12 +407,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def processStatements(statement: Statement): Statement = { statement match { - case i : WDefInstance => processInstance(i) - case r : DefRegister => processRegister(r) - case w : DefWire => processWire(w) - case n : DefNode => processNode(n) - case c : Connect => processConnect(c) - case s => s.mapStmt(processStatements) + case i: WDefInstance => processInstance(i) + case r: DefRegister => processRegister(r) + case w: DefWire => processWire(w) + case n: DefNode => processNode(n) + case c: Connect => processConnect(c) + case s => s.mapStmt(processStatements) } } @@ -422,10 +428,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration { def execute(state: CircuitState): CircuitState = { // Collect all user-defined PresetAnnotation - val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true ; case _ => false } + val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true; case _ => false } // No PresetAnnotation => no need to walk the IR - if (presets.isEmpty){ + if (presets.isEmpty) { state } else { // PHASE I - Propagate diff --git a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala index 840a3d99..ae3bc693 100644 --- a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala +++ b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala @@ -21,10 +21,11 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames { * @return Some name if a rename occurred, None otherwise * @note prefix uniqueness is not respected */ - override def manipulate = (n: String, ns: Namespace) => keywords.contains(n) match { - case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords)) - case false => None - } + override def manipulate = (n: String, ns: Namespace) => + keywords.contains(n) match { + case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords)) + case false => None + } } @@ -32,14 +33,16 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames { class VerilogRename extends RemoveKeywordCollisions(v_keywords) { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals], - Dependency[ReplaceTruncatingArithmetic], - Dependency[InlineBitExtractionsTransform], - Dependency[InlineCastsTransform], - Dependency[LegalizeClocksTransform], - Dependency[FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup) ) + Seq( + Dependency[BlackBoxSourceHelper], + Dependency[FixAddingNegativeLiterals], + Dependency[ReplaceTruncatingArithmetic], + Dependency[InlineBitExtractionsTransform], + Dependency[InlineCastsTransform], + Dependency[LegalizeClocksTransform], + Dependency[FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup) + ) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala index 6b3a9d07..8736e21b 100644 --- a/src/main/scala/firrtl/transforms/RemoveReset.scala +++ b/src/main/scala/firrtl/transforms/RemoveReset.scala @@ -18,8 +18,7 @@ import scala.collection.{immutable, mutable} object RemoveReset extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize) ) + Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize)) override def optionalPrerequisites = Seq.empty @@ -58,7 +57,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { reg.copy(reset = Utils.zero, init = WRef(reg)) case reg @ DefRegister(_, rname, _, _, Utils.zero, _) => reg.copy(init = WRef(reg)) // canonicalize - case reg @ DefRegister(info , rname, _, _, reset, init) if reset.tpe != AsyncResetType => + case reg @ DefRegister(info, rname, _, _, reset, init) if reset.tpe != AsyncResetType => // Add register reset to map resets(rname) = Reset(reset, init, info) reg.copy(reset = Utils.zero, init = WRef(reg)) @@ -68,7 +67,7 @@ object RemoveReset extends Transform with DependencyAPIMigration { // Use reg source locator for mux enable and true value since that's where they're defined val infox = MultiInfo(reset.info, reset.info, info) Connect(infox, ref, Mux(reset.cond, reset.value, expr, muxType)) - case other => other map onStmt + case other => other.map(onStmt) } } m.map(onStmt) diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index f692e513..31fa3b6f 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -8,11 +8,11 @@ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression._ -import firrtl.graph.{MutableDiGraph, CyclicException} +import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import scala.collection.mutable -import scala.util.{Try, Success, Failure} +import scala.util.{Failure, Success, Try} /** Replace wires with nodes in a legal, flow-forward order * @@ -23,11 +23,13 @@ import scala.util.{Try, Success, Failure} class RemoveWires extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.MidForm ++ - Seq( Dependency(passes.LowerTypes), - Dependency(passes.Legalize), - Dependency(passes.ResolveKinds), - Dependency(transforms.RemoveReset), - Dependency[transforms.CheckCombLoops] ) + Seq( + Dependency(passes.LowerTypes), + Dependency(passes.Legalize), + Dependency(passes.ResolveKinds), + Dependency(transforms.RemoveReset), + Dependency[transforms.CheckCombLoops] + ) override def optionalPrerequisites = Seq(Dependency[checks.CheckResets]) @@ -35,7 +37,7 @@ class RemoveWires extends Transform with DependencyAPIMigration { override def invalidates(a: Transform) = a match { case passes.ResolveKinds => true - case _ => false + case _ => false } // Extract all expressions that are references to a Node, Wire, or Reg @@ -44,7 +46,7 @@ class RemoveWires extends Transform with DependencyAPIMigration { val refs = mutable.ArrayBuffer.empty[WRef] def rec(e: Expression): Expression = { e match { - case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref + case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec) case _ => // Do nothing } @@ -57,7 +59,8 @@ class RemoveWires extends Transform with DependencyAPIMigration { // Transform netlist into DefNodes private def getOrderedNodes( netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], - regInfo: mutable.Map[WrappedExpression, DefRegister]): Try[Seq[Statement]] = { + regInfo: mutable.Map[WrappedExpression, DefRegister] + ): Try[Seq[Statement]] = { val digraph = new MutableDiGraph[WrappedExpression] for ((sink, (exprs, _)) <- netlist) { digraph.addVertex(sink) @@ -106,21 +109,22 @@ class RemoveWires extends Transform with DependencyAPIMigration { case reg: DefRegister => val resetDep = reg.reset.tpe match { case AsyncResetType => Some(reg.reset) - case _ => None + case _ => None } val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself regInfo(we(WRef(reg))) = reg netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info) case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires decls += decl - case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match { - case WireKind => - // Be sure to pad the rhs since nodes get their type from the rhs - val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) - val dinfo = wireInfo(lhs) - netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) - case _ => otherStmts += con // Other connections just pass through - } + case con @ Connect(cinfo, lhs, rhs) => + kind(lhs) match { + case WireKind => + // Be sure to pad the rhs since nodes get their type from the rhs + val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe) + val dinfo = wireInfo(lhs) + netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo)) + case _ => otherStmts += con // Other connections just pass through + } case invalid @ IsInvalid(info, expr) => kind(expr) match { case WireKind => @@ -146,8 +150,10 @@ class RemoveWires extends Transform with DependencyAPIMigration { // If we hit a CyclicException, just abort removing wires case Failure(c: CyclicException) => val problematicNode = c.node - logger.warn(s"Cycle found in module $name, " + - s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode") + logger.warn( + s"Cycle found in module $name, " + + s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode" + ) mod case Failure(other) => throw other } @@ -155,7 +161,6 @@ class RemoveWires extends Transform with DependencyAPIMigration { } } - def execute(state: CircuitState): CircuitState = state.copy(circuit = state.circuit.map(onModule)) } diff --git a/src/main/scala/firrtl/transforms/RenameModules.scala b/src/main/scala/firrtl/transforms/RenameModules.scala index d37f8c39..16fd655a 100644 --- a/src/main/scala/firrtl/transforms/RenameModules.scala +++ b/src/main/scala/firrtl/transforms/RenameModules.scala @@ -44,7 +44,7 @@ class RenameModules extends Transform with DependencyAPIMigration { moduleOrder.foreach(collectNameMapping(namespace.get, nameMappings)) val modulesx = state.circuit.modules.map { - case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings) + case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings) case ext: ExtModule => ext } diff --git a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala index a93087b9..14c84b91 100644 --- a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala +++ b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala @@ -80,8 +80,7 @@ object ReplaceTruncatingArithmetic { class ReplaceTruncatingArithmetic extends Transform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[BlackBoxSourceHelper], - Dependency[FixAddingNegativeLiterals] ) + Seq(Dependency[BlackBoxSourceHelper], Dependency[FixAddingNegativeLiterals]) override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala index a056c7da..7790d060 100644 --- a/src/main/scala/firrtl/transforms/SimplifyMems.scala +++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala @@ -33,12 +33,13 @@ class SimplifyMems extends Transform with DependencyAPIMigration { def onExpr(e: Expression): Expression = e.map(onExpr) match { case wr @ WRef(name, _, MemKind, _) if memAdapters.contains(name) => wr.copy(kind = WireKind) - case e => e + case e => e } def simplifyMem(mem: DefMemory): Statement = { val adapterDecl = DefWire(mem.info, mem.name, memType(mem)) - val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType)) + val simpleMemDecl = + mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType)) val oldRT = mTarget.ref(mem.name) val adapterConnects = memType(simpleMemDecl).fields.flatMap { case Field(pName, Flip, pType: BundleType) => @@ -63,8 +64,10 @@ class SimplifyMems extends Transform with DependencyAPIMigration { def canSimplify(mem: DefMemory) = mem.dataType match { case at: AggregateType => - val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask"))) - val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask"))) + val wMasks = + mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask"))) + val rwMasks = + mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask"))) (wMasks ++ rwMasks).flatten.isEmpty case _ => false } diff --git a/src/main/scala/firrtl/transforms/TopWiring.scala b/src/main/scala/firrtl/transforms/TopWiring.scala index f5a5e2a3..b35fed22 100644 --- a/src/main/scala/firrtl/transforms/TopWiring.scala +++ b/src/main/scala/firrtl/transforms/TopWiring.scala @@ -4,7 +4,7 @@ package TopWiring import firrtl._ import firrtl.ir._ -import firrtl.passes.{InferTypes, LowerTypes, ResolveKinds, ResolveFlows, ExpandConnects} +import firrtl.passes.{ExpandConnects, InferTypes, LowerTypes, ResolveFlows, ResolveKinds} import firrtl.annotations._ import firrtl.Mappers._ import firrtl.analyses.InstanceKeyGraph @@ -13,22 +13,21 @@ import firrtl.options.Dependency import collection.mutable -/** Annotation for optional output files, and what directory to put those files in (absolute path) **/ -case class TopWiringOutputFilesAnnotation(dirName: String, - outputFunction: (String,Seq[((ComponentName, Type, Boolean, - Seq[String],String), Int)], - CircuitState) => CircuitState) extends NoTargetAnnotation +/** Annotation for optional output files, and what directory to put those files in (absolute path) * */ +case class TopWiringOutputFilesAnnotation( + dirName: String, + outputFunction: (String, Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], + CircuitState) => CircuitState) + extends NoTargetAnnotation /** Annotation for indicating component to be wired, and what prefix to add to the ports that are generated */ -case class TopWiringAnnotation(target: ComponentName, prefix: String) extends - SingleTargetAnnotation[ComponentName] { +case class TopWiringAnnotation(target: ComponentName, prefix: String) extends SingleTargetAnnotation[ComponentName] { def duplicate(n: ComponentName) = this.copy(target = n) } - /** Punch out annotated ports out to the toplevel of the circuit. - This also has an option to pass a function as a parmeter to generate - custom output files as a result of the additional ports + * This also has an option to pass a function as a parmeter to generate + * custom output files as a result of the additional ports * @note This *does* work for deduped modules */ class TopWiringTransform extends Transform with DependencyAPIMigration { @@ -39,116 +38,133 @@ class TopWiringTransform extends Transform with DependencyAPIMigration { override def invalidates(a: Transform): Boolean = a match { case InferTypes | ResolveKinds | ResolveFlows | ExpandConnects => true - case _ => false + case _ => false } type InstPath = Seq[String] /** Get the names of the targets that need to be wired */ private def getSourceNames(state: CircuitState): Map[ComponentName, String] = { - state.annotations.collect { case TopWiringAnnotation(srcname,prefix) => - (srcname -> prefix) }.toMap.withDefaultValue("") + state.annotations.collect { + case TopWiringAnnotation(srcname, prefix) => + (srcname -> prefix) + }.toMap.withDefaultValue("") } - /** Get the names of the modules which include the targets that need to be wired */ private def getSourceModNames(state: CircuitState): Seq[String] = { - state.annotations.collect { case TopWiringAnnotation(ComponentName(_,ModuleName(srcmodname, _)),_) => srcmodname } + state.annotations.collect { case TopWiringAnnotation(ComponentName(_, ModuleName(srcmodname, _)), _) => srcmodname } } - - /** Get the Type of each wire to be connected * * Find the definition of each wire in sourceList, and get the type and whether or not it's a port * Update the results in sourceMap */ - private def getSourceTypes(sourceList: Map[ComponentName, String], - sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], - currentmodule: ModuleName, state: CircuitState)(s: Statement): Statement = s match { + private def getSourceTypes( + sourceList: Map[ComponentName, String], + sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], + currentmodule: ModuleName, + state: CircuitState + )(s: Statement + ): Statement = s match { // If target wire, add name and size to to sourceMap case w: IsDeclaration => if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) { - val (isport, tpe, prefix) = w match { - case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name,currentmodule))) - case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") - } - sourceMap.get(currentmodule.name) match { - case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) => - sourceMap.update(currentmodule.name, xs :+( - (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) )) - case None => - sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule), - tpe, isport ,Seq[String](w.name), prefix)) - } + val (isport, tpe, prefix) = w match { + case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name, currentmodule))) + case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") + } + sourceMap.get(currentmodule.name) match { + case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) => + sourceMap.update( + currentmodule.name, + xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)) + ) + case None => + sourceMap(currentmodule.name) = Seq( + (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix) + ) + } } w // Return argument unchanged (ok because DefWire has no Statement children) // If not, apply to all children Statement - case _ => s map getSourceTypes(sourceList, sourceMap, currentmodule, state) + case _ => s.map(getSourceTypes(sourceList, sourceMap, currentmodule, state)) } - - /** Get the Type of each port to be connected * * Similar to getSourceTypes, but specifically for ports since they are not found in statements. * Find the definition of each port in sourceList, and get the type and whether or not it's a port * Update the results in sourceMap */ - private def getSourceTypesPorts(sourceList: Map[ComponentName, String], sourceMap: mutable.Map[String, - Seq[(ComponentName, Type, Boolean, InstPath, String)]], - currentmodule: ModuleName, state: CircuitState)(s: Port): CircuitState = s match { + private def getSourceTypesPorts( + sourceList: Map[ComponentName, String], + sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], + currentmodule: ModuleName, + state: CircuitState + )(s: Port + ): CircuitState = s match { // If target port, add name and size to to sourceMap case w: IsDeclaration => if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) { - val (isport, tpe, prefix) = w match { - case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule))) - case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") - } - sourceMap.get(currentmodule.name) match { - case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) => - sourceMap.update(currentmodule.name, xs :+( - (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) )) - case None => - sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule), - tpe, isport ,Seq[String](w.name), prefix)) - } + val (isport, tpe, prefix) = w match { + case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule))) + case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}") + } + sourceMap.get(currentmodule.name) match { + case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) => + sourceMap.update( + currentmodule.name, + xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)) + ) + case None => + sourceMap(currentmodule.name) = Seq( + (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix) + ) + } } state // Return argument unchanged (ok because DefWire has no Statement children) // If not, apply to all children Statement case _ => state } - /** Create a map of Module name to target wires under this module * * These paths are relative but cross module (they refer down through instance hierarchy) */ - private def getSourcesMap(state: CircuitState): Map[String,Seq[(ComponentName, Type, Boolean, InstPath, String)]] = { + private def getSourcesMap(state: CircuitState): Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] = { val sSourcesModNames = getSourceModNames(state) val sSourcesNames = getSourceNames(state) val instGraph = firrtl.analyses.InstanceKeyGraph(state.circuit) - val cMap = instGraph.getChildInstances.map{ case (m, wdis) => - (m -> wdis.map{ case wdi => (wdi.name, wdi.module) }.toSeq) }.toMap + val cMap = instGraph.getChildInstances.map { + case (m, wdis) => + (m -> wdis.map { case wdi => (wdi.name, wdi.module) }.toSeq) + }.toMap val topSort = instGraph.moduleOrder.reverse // Map of component name to relative instance paths that result in a debug wire val sourcemods: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] = mutable.Map(sSourcesModNames.map(_ -> Seq()): _*) - state.circuit.modules.foreach { m => m map - getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) } - state.circuit.modules.foreach { m => m.ports.foreach { - p => Seq(p) map - getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) }} + state.circuit.modules.foreach { m => + m.map(getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state)) + } + state.circuit.modules.foreach { m => + m.ports.foreach { p => + Seq(p).map( + getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state) + ) + } + } for (mod <- topSort) { - val seqChildren: Seq[(ComponentName,Type,Boolean,InstPath,String)] = cMap(mod.name).flatMap { + val seqChildren: Seq[(ComponentName, Type, Boolean, InstPath, String)] = cMap(mod.name).flatMap { case (inst, module) => - sourcemods.get(module).map( _.map { case (a,b,c,path,p) => (a,b,c, inst +: path, p)}) + sourcemods.get(module).map(_.map { case (a, b, c, path, p) => (a, b, c, inst +: path, p) }) }.flatten if (seqChildren.nonEmpty) { sourcemods(mod.name) = sourcemods.getOrElse(mod.name, Seq()) ++ seqChildren @@ -158,108 +174,113 @@ class TopWiringTransform extends Transform with DependencyAPIMigration { sourcemods.toMap } - - /** Process a given DefModule * * For Modules that contain or are in the parent hierarchy to modules containing target wires * 1. Add ports for each target wire this module is parent to * 2. Connect these ports to ports of instances that are parents to some number of target wires */ - private def onModule(sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], - portnamesmap : mutable.Map[String,String], - instgraph : firrtl.analyses.InstanceKeyGraph, - namespacemap : Map[String, Namespace]) - (module: DefModule): DefModule = { + private def onModule( + sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]], + portnamesmap: mutable.Map[String, String], + instgraph: firrtl.analyses.InstanceKeyGraph, + namespacemap: Map[String, Namespace] + )(module: DefModule + ): DefModule = { val namespace = namespacemap(module.name) sources.get(module.name) match { case Some(p) => - val newPorts = p.map{ case (ComponentName(cname,_), tpe, _ , path, prefix) => { - val newportname = portnamesmap.get(prefix + path.mkString("_")) match { - case Some(pn) => pn - case None => { - val npn = namespace.newName(prefix + path.mkString("_")) - portnamesmap(prefix + path.mkString("_")) = npn - npn - } + val newPorts = p.map { + case (ComponentName(cname, _), tpe, _, path, prefix) => { + val newportname = portnamesmap.get(prefix + path.mkString("_")) match { + case Some(pn) => pn + case None => { + val npn = namespace.newName(prefix + path.mkString("_")) + portnamesmap(prefix + path.mkString("_")) = npn + npn } - Port(NoInfo, newportname, Output, tpe) - } } + } + Port(NoInfo, newportname, Output, tpe) + } + } // Add connections to Module val childInstances = instgraph.getChildInstances.toMap module match { case m: Module => - val connections: Seq[Connect] = p.map { case (ComponentName(cname,_), _, _ , path, prefix) => + val connections: Seq[Connect] = p.map { + case (ComponentName(cname, _), _, _, path, prefix) => val modRef = portnamesmap.get(prefix + path.mkString("_")) match { - case Some(pn) => WRef(pn) - case None => { - portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_")) - WRef(portnamesmap(prefix + path.mkString("_"))) - } + case Some(pn) => WRef(pn) + case None => { + portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_")) + WRef(portnamesmap(prefix + path.mkString("_"))) + } } path.size match { - case 1 => { - val leafRef = WRef(path.head.mkString("")) - Connect(NoInfo, modRef, leafRef) - } - case _ => { - val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match { - case Some(ipn) => ipn - case None => { - val instmod = childInstances(module.name).collectFirst { - case wdi if wdi.name == path.head => wdi.module}.get - val instnamespace = namespacemap(instmod) - portnamesmap(prefix + path.tail.mkString("_")) = - instnamespace.newName(prefix + path.tail.mkString("_")) - portnamesmap(prefix + path.tail.mkString("_")) - } - } - val instRef = WSubField(WRef(path.head), instportname) - Connect(NoInfo, modRef, instRef) + case 1 => { + val leafRef = WRef(path.head.mkString("")) + Connect(NoInfo, modRef, leafRef) + } + case _ => { + val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match { + case Some(ipn) => ipn + case None => { + val instmod = childInstances(module.name).collectFirst { + case wdi if wdi.name == path.head => wdi.module + }.get + val instnamespace = namespacemap(instmod) + portnamesmap(prefix + path.tail.mkString("_")) = + instnamespace.newName(prefix + path.tail.mkString("_")) + portnamesmap(prefix + path.tail.mkString("_")) + } + } + val instRef = WSubField(WRef(path.head), instportname) + Connect(NoInfo, modRef, instRef) } } } - m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections )) + m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections)) case e: ExtModule => e.copy(ports = e.ports ++ newPorts) - } + } case None => module // unchanged if no paths } } - /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */ - def topWiringDummyOutputFilesFunction(dir: String, - mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], - state: CircuitState): CircuitState = { - state + /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */ + def topWiringDummyOutputFilesFunction( + dir: String, + mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], + state: CircuitState + ): CircuitState = { + state } - def execute(state: CircuitState): CircuitState = { - val outputTuples: Seq[(String, - (String,Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], - CircuitState) => CircuitState)] = state.annotations.collect { - case TopWiringOutputFilesAnnotation(td,of) => (td, of) } + val outputTuples: Seq[ + (String, (String, Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], CircuitState) => CircuitState) + ] = state.annotations.collect { + case TopWiringOutputFilesAnnotation(td, of) => (td, of) + } // Do actual work of this transform val sources = getSourcesMap(state) val (nstate, nmappings) = if (sources.nonEmpty) { - val portnamesmap: mutable.Map[String,String] = mutable.Map() + val portnamesmap: mutable.Map[String, String] = mutable.Map() val instgraph = InstanceKeyGraph(state.circuit) - val namespacemap = state.circuit.modules.map{ case m => (m.name -> Namespace(m)) }.toMap - val modulesx = state.circuit.modules map onModule(sources, portnamesmap, instgraph, namespacemap) + val namespacemap = state.circuit.modules.map { case m => (m.name -> Namespace(m)) }.toMap + val modulesx = state.circuit.modules.map(onModule(sources, portnamesmap, instgraph, namespacemap)) val newCircuit = state.circuit.copy(modules = modulesx) val mappings = sources(state.circuit.main).zipWithIndex val annosx = state.annotations.filter { case _: TopWiringAnnotation => false - case _ => true + case _ => true } (state.copy(circuit = newCircuit, annotations = annosx), mappings) - } - else { (state, List.empty) } + } else { (state, List.empty) } //Generate output files based on the mapping. outputTuples.map { case (dir, outputfunction) => outputfunction(dir, nmappings, nstate) } nstate diff --git a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala index 7370fcfb..cdbee495 100644 --- a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala +++ b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala @@ -1,4 +1,3 @@ - package firrtl.transforms.formal import firrtl.ir.{Circuit, Formal, Statement, Verification} @@ -7,7 +6,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, Transform} import firrtl.annotations.NoTargetAnnotation import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption} - /** * Assert Submodule Assumptions * @@ -16,12 +14,13 @@ import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption} * overly restrictive assume in a child module can prevent the model checker * from searching valid inputs and states in the parent module. */ -class AssertSubmoduleAssumptions extends Transform - with RegisteredTransform - with DependencyAPIMigration - with PreservesAll[Transform] { +class AssertSubmoduleAssumptions + extends Transform + with RegisteredTransform + with DependencyAPIMigration + with PreservesAll[Transform] { - override def prerequisites: Seq[TransformDependency] = Seq.empty + override def prerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisiteOf: Seq[TransformDependency] = firrtl.stage.Forms.MidEmitters @@ -29,9 +28,10 @@ class AssertSubmoduleAssumptions extends Transform val options = Seq( new ShellOption[Unit]( longOption = "no-asa", - toAnnotationSeq = (_: Unit) => Seq( - DontAssertSubmoduleAssumptionsAnnotation), - helpText = "Disable assert submodule assumptions" ) ) + toAnnotationSeq = (_: Unit) => Seq(DontAssertSubmoduleAssumptionsAnnotation), + helpText = "Disable assert submodule assumptions" + ) + ) def assertAssumption(s: Statement): Statement = s match { case Verification(Formal.Assume, info, clk, cond, en, msg) => @@ -50,8 +50,7 @@ class AssertSubmoduleAssumptions extends Transform } def execute(state: CircuitState): CircuitState = { - val noASA = state.annotations.contains( - DontAssertSubmoduleAssumptionsAnnotation) + val noASA = state.annotations.contains(DontAssertSubmoduleAssumptionsAnnotation) if (noASA) { logger.info("Skipping assert submodule assumptions") state diff --git a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala index ddead331..5928c79c 100644 --- a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala +++ b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala @@ -14,10 +14,8 @@ import firrtl.options.Dependency object ConvertAsserts extends Transform with DependencyAPIMigration { override def prerequisites = Nil override def optionalPrerequisites = Nil - override def optionalPrerequisiteOf = Seq( - Dependency[VerilogEmitter], - Dependency[MinimumVerilogEmitter], - Dependency[RemoveVerificationStatements]) + override def optionalPrerequisiteOf = + Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter], Dependency[RemoveVerificationStatements]) override def invalidates(a: Transform): Boolean = false @@ -28,7 +26,7 @@ object ConvertAsserts extends Transform with DependencyAPIMigration { val stop = Stop(i, 1, clk, gatedNPred) msg match { case StringLit("") => stop - case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop) + case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop) } case s => s.mapStmt(convertAsserts) } diff --git a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala index 72890c07..1e6d2c72 100644 --- a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala +++ b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala @@ -1,4 +1,3 @@ - package firrtl.transforms.formal import firrtl.ir.{Circuit, EmptyStmt, Statement, Verification} @@ -6,7 +5,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, MinimumVerilogEmitter, Tran import firrtl.options.{Dependency, PreservesAll, StageUtils} import firrtl.stage.TransformManager.TransformDependency - /** * Remove Verification Statements * @@ -14,15 +12,12 @@ import firrtl.stage.TransformManager.TransformDependency * This is intended to be required by the Verilog emitter to ensure compatibility * with the Verilog 2001 standard. */ -class RemoveVerificationStatements extends Transform - with DependencyAPIMigration - with PreservesAll[Transform] { +class RemoveVerificationStatements extends Transform with DependencyAPIMigration with PreservesAll[Transform] { - override def prerequisites: Seq[TransformDependency] = Seq.empty + override def prerequisites: Seq[TransformDependency] = Seq.empty override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency(ConvertAsserts)) override def optionalPrerequisiteOf: Seq[TransformDependency] = - Seq( Dependency[VerilogEmitter], - Dependency[MinimumVerilogEmitter]) + Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter]) private var removedCounter = 0 @@ -43,11 +38,13 @@ class RemoveVerificationStatements extends Transform def execute(state: CircuitState): CircuitState = { val newState = state.copy(circuit = run(state.circuit)) if (removedCounter > 0) { - StageUtils.dramaticWarning(s"$removedCounter verification statements " + - "(assert, assume or cover) " + - "were removed when compiling to Verilog because the basic Verilog " + - "standard does not support them. If this was not intended, compile " + - "to System Verilog instead using the `-X sverilog` compiler flag.") + StageUtils.dramaticWarning( + s"$removedCounter verification statements " + + "(assert, assume or cover) " + + "were removed when compiling to Verilog because the basic Verilog " + + "standard does not support them. If this was not intended, compile " + + "to System Verilog instead using the `-X sverilog` compiler flag." + ) } newState } diff --git a/src/main/scala/firrtl/traversals/Foreachers.scala b/src/main/scala/firrtl/traversals/Foreachers.scala index fdb02399..dee74d63 100644 --- a/src/main/scala/firrtl/traversals/Foreachers.scala +++ b/src/main/scala/firrtl/traversals/Foreachers.scala @@ -15,19 +15,19 @@ object Foreachers { } private object StmtForMagnet { implicit def forStmt(f: Statement => Unit): StmtForMagnet = new StmtForMagnet { - def foreach(stmt: Statement): Unit = stmt foreachStmt f + def foreach(stmt: Statement): Unit = stmt.foreachStmt(f) } implicit def forExp(f: Expression => Unit): StmtForMagnet = new StmtForMagnet { - def foreach(stmt: Statement): Unit = stmt foreachExpr f + def foreach(stmt: Statement): Unit = stmt.foreachExpr(f) } implicit def forType(f: Type => Unit): StmtForMagnet = new StmtForMagnet { - def foreach(stmt: Statement) : Unit = stmt foreachType f + def foreach(stmt: Statement): Unit = stmt.foreachType(f) } implicit def forString(f: String => Unit): StmtForMagnet = new StmtForMagnet { - def foreach(stmt: Statement): Unit = stmt foreachString f + def foreach(stmt: Statement): Unit = stmt.foreachString(f) } implicit def forInfo(f: Info => Unit): StmtForMagnet = new StmtForMagnet { - def foreach(stmt: Statement): Unit = stmt foreachInfo f + def foreach(stmt: Statement): Unit = stmt.foreachInfo(f) } } implicit class StmtForeach(val _stmt: Statement) extends AnyVal { @@ -41,13 +41,13 @@ object Foreachers { } private object ExprForMagnet { implicit def forExpr(f: Expression => Unit): ExprForMagnet = new ExprForMagnet { - def foreach(expr: Expression): Unit = expr foreachExpr f + def foreach(expr: Expression): Unit = expr.foreachExpr(f) } implicit def forType(f: Type => Unit): ExprForMagnet = new ExprForMagnet { - def foreach(expr: Expression): Unit = expr foreachType f + def foreach(expr: Expression): Unit = expr.foreachType(f) } implicit def forWidth(f: Width => Unit): ExprForMagnet = new ExprForMagnet { - def foreach(expr: Expression): Unit = expr foreachWidth f + def foreach(expr: Expression): Unit = expr.foreachWidth(f) } } implicit class ExprForeach(val _expr: Expression) extends AnyVal { @@ -60,10 +60,10 @@ object Foreachers { } private object TypeForMagnet { implicit def forType(f: Type => Unit): TypeForMagnet = new TypeForMagnet { - def foreach(tpe: Type): Unit = tpe foreachType f + def foreach(tpe: Type): Unit = tpe.foreachType(f) } implicit def forWidth(f: Width => Unit): TypeForMagnet = new TypeForMagnet { - def foreach(tpe: Type): Unit = tpe foreachWidth f + def foreach(tpe: Type): Unit = tpe.foreachWidth(f) } } implicit class TypeForeach(val _tpe: Type) extends AnyVal { @@ -76,16 +76,16 @@ object Foreachers { } private object ModuleForMagnet { implicit def forStmt(f: Statement => Unit): ModuleForMagnet = new ModuleForMagnet { - def foreach(module: DefModule): Unit = module foreachStmt f + def foreach(module: DefModule): Unit = module.foreachStmt(f) } implicit def forPorts(f: Port => Unit): ModuleForMagnet = new ModuleForMagnet { - def foreach(module: DefModule): Unit = module foreachPort f + def foreach(module: DefModule): Unit = module.foreachPort(f) } implicit def forString(f: String => Unit): ModuleForMagnet = new ModuleForMagnet { - def foreach(module: DefModule): Unit = module foreachString f + def foreach(module: DefModule): Unit = module.foreachString(f) } implicit def forInfo(f: Info => Unit): ModuleForMagnet = new ModuleForMagnet { - def foreach(module: DefModule): Unit = module foreachInfo f + def foreach(module: DefModule): Unit = module.foreachInfo(f) } } implicit class ModuleForeach(val _module: DefModule) extends AnyVal { @@ -98,13 +98,13 @@ object Foreachers { } private object CircuitForMagnet { implicit def forModules(f: DefModule => Unit): CircuitForMagnet = new CircuitForMagnet { - def foreach(circuit: Circuit): Unit = circuit foreachModule f + def foreach(circuit: Circuit): Unit = circuit.foreachModule(f) } implicit def forString(f: String => Unit): CircuitForMagnet = new CircuitForMagnet { - def foreach(circuit: Circuit): Unit = circuit foreachString f + def foreach(circuit: Circuit): Unit = circuit.foreachString(f) } implicit def forInfo(f: Info => Unit): CircuitForMagnet = new CircuitForMagnet { - def foreach(circuit: Circuit): Unit = circuit foreachInfo f + def foreach(circuit: Circuit): Unit = circuit.foreachInfo(f) } } implicit class CircuitForeach(val _circuit: Circuit) extends AnyVal { diff --git a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala index 1557bb0c..2ac5b035 100644 --- a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala +++ b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala @@ -14,6 +14,7 @@ import firrtl.FileUtils import scala.sys.process.{ProcessBuilder, ProcessLogger, _} object BackendCompilationUtilities extends LazyLogging { + /** Parent directory for tests */ lazy val TestDirectory = new File("test_run_dir") @@ -69,12 +70,7 @@ object BackendCompilationUtilities extends LazyLogging { * @return true if compiler completed successfully */ def firrtlToVerilog(prefix: String, dir: File): ProcessBuilder = { - Process( - Seq("firrtl", - "-i", s"$prefix.fir", - "-o", s"$prefix.v", - "-X", "verilog"), - dir) + Process(Seq("firrtl", "-i", s"$prefix.fir", "-o", s"$prefix.v", "-X", "verilog"), dir) } /** Generates a Verilator invocation to convert Verilog sources to C++ @@ -103,11 +99,11 @@ object BackendCompilationUtilities extends LazyLogging { * @param extraCmdLineArgs list of additional command line arguments */ def verilogToCpp( - dutFile: String, - dir: File, - vSources: Seq[File], - cppHarness: File, - suppressVcd: Boolean = false, + dutFile: String, + dir: File, + vSources: Seq[File], + cppHarness: File, + suppressVcd: Boolean = false, resourceFileName: String = firrtl.transforms.BlackBoxSourceHelper.defaultFileListName, extraCmdLineArgs: Seq[String] = Seq.empty ): ProcessBuilder = { @@ -116,10 +112,9 @@ object BackendCompilationUtilities extends LazyLogging { val list_file = new File(dir, resourceFileName) val blackBoxVerilogList = { - if(list_file.exists()) { + if (list_file.exists()) { Seq("-f", list_file.getAbsolutePath) - } - else { + } else { Seq.empty[String] } } @@ -128,37 +123,39 @@ object BackendCompilationUtilities extends LazyLogging { // If it's in the main .f resource file, don't explicitly include it on the command line. // Build a set of canonical file paths to use as a filter to exclude already included additional Verilog sources. val blackBoxHelperFiles: Set[String] = { - if(list_file.exists()) { + if (list_file.exists()) { FileUtils.getLines(list_file).toSet - } - else { + } else { Set.empty } } val vSourcesFiltered = vSources.filterNot(f => blackBoxHelperFiles.contains(f.getCanonicalPath)) val command = Seq( "verilator", - "--cc", s"${dir.getAbsolutePath}/$dutFile.v" + "--cc", + s"${dir.getAbsolutePath}/$dutFile.v" ) ++ extraCmdLineArgs ++ blackBoxVerilogList ++ vSourcesFiltered.flatMap(file => Seq("-v", file.getCanonicalPath)) ++ - Seq("--assert", - "-Wno-fatal", - "-Wno-WIDTH", - "-Wno-STMTDLY" - ) ++ - { if(suppressVcd) { Seq.empty } else { Seq("--trace")} } ++ + Seq("--assert", "-Wno-fatal", "-Wno-WIDTH", "-Wno-STMTDLY") ++ { + if (suppressVcd) { Seq.empty } + else { Seq("--trace") } + } ++ Seq( "-O1", - "--top-module", topModule, + "--top-module", + topModule, "+define+TOP_TYPE=V" + dutFile, s"+define+PRINTF_COND=!$topModule.reset", s"+define+STOP_COND=!$topModule.reset", "-CFLAGS", s"""-Wno-undefined-bool-conversion -O1 -DTOP_TYPE=V$dutFile -DVL_USER_FINISH -include V$dutFile.h""", - "-Mdir", dir.getAbsolutePath, - "--exe", cppHarness.getAbsolutePath) + "-Mdir", + dir.getAbsolutePath, + "--exe", + cppHarness.getAbsolutePath + ) logger.info(s"${command.mkString(" ")}") command } @@ -167,17 +164,20 @@ object BackendCompilationUtilities extends LazyLogging { Seq("make", "-C", dir.toString, "-j", "-f", s"V$prefix.mk", s"V$prefix") def executeExpectingFailure( - prefix: String, - dir: File, - assertionMsg: String = ""): Boolean = { + prefix: String, + dir: File, + assertionMsg: String = "" + ): Boolean = { var triggered = false val assertionMessageSupplied = assertionMsg != "" val e = Process(s"./V$prefix", dir) ! - ProcessLogger(line => { - triggered = triggered || (assertionMessageSupplied && line.contains(assertionMsg)) - logger.info(line) - }, - logger.warn(_)) + ProcessLogger( + line => { + triggered = triggered || (assertionMessageSupplied && line.contains(assertionMsg)) + logger.info(line) + }, + logger.warn(_) + ) // Fail if a line contained an assertion or if we get a non-zero exit code // or, we get a SIGABRT (assertion failure) and we didn't provide a specific assertion message triggered || (e != 0 && (e != 134 || !assertionMessageSupplied)) @@ -201,10 +201,7 @@ object BackendCompilationUtilities extends LazyLogging { * @param timesteps the maximum number of timesteps for Yosys equivalence * checking to consider */ - def yosysExpectSuccess(customTop: String, - referenceTop: String, - testDir: File, - timesteps: Int = 1): Boolean = { + def yosysExpectSuccess(customTop: String, referenceTop: String, testDir: File, timesteps: Int = 1): Boolean = { !yosysExpectFailure(customTop, referenceTop, testDir, timesteps) } @@ -222,31 +219,26 @@ object BackendCompilationUtilities extends LazyLogging { * @param timesteps the maximum number of timesteps for Yosys equivalence * checking to consider */ - def yosysExpectFailure(customTop: String, - referenceTop: String, - testDir: File, - timesteps: Int = 1): Boolean = { + def yosysExpectFailure(customTop: String, referenceTop: String, testDir: File, timesteps: Int = 1): Boolean = { val scriptFileName = s"${testDir.getAbsolutePath}/yosys_script" val yosysScriptWriter = new PrintWriter(scriptFileName) - yosysScriptWriter.write( - s"""read_verilog ${testDir.getAbsolutePath}/$customTop.v - |prep -flatten -top $customTop; proc; opt; memory - |design -stash custom - |read_verilog ${testDir.getAbsolutePath}/$referenceTop.v - |prep -flatten -top $referenceTop; proc; opt; memory - |design -stash reference - |design -copy-from custom -as custom $customTop - |design -copy-from reference -as reference $referenceTop - |equiv_make custom reference equiv - |hierarchy -top equiv - |prep -flatten -top equiv - |clean -purge - |equiv_simple -seq $timesteps - |equiv_induct -seq $timesteps - |equiv_status -assert - """ - .stripMargin) + yosysScriptWriter.write(s"""read_verilog ${testDir.getAbsolutePath}/$customTop.v + |prep -flatten -top $customTop; proc; opt; memory + |design -stash custom + |read_verilog ${testDir.getAbsolutePath}/$referenceTop.v + |prep -flatten -top $referenceTop; proc; opt; memory + |design -stash reference + |design -copy-from custom -as custom $customTop + |design -copy-from reference -as reference $referenceTop + |equiv_make custom reference equiv + |hierarchy -top equiv + |prep -flatten -top equiv + |clean -purge + |equiv_simple -seq $timesteps + |equiv_induct -seq $timesteps + |equiv_status -assert + """.stripMargin) yosysScriptWriter.close() val resultFileName = testDir.getAbsolutePath + "/yosys_results" @@ -258,28 +250,32 @@ object BackendCompilationUtilities extends LazyLogging { @deprecated("use object BackendCompilationUtilities", "1.3") trait BackendCompilationUtilities extends LazyLogging { lazy val TestDirectory = BackendCompilationUtilities.TestDirectory - def timeStamp: String = BackendCompilationUtilities.timeStamp + def timeStamp: String = BackendCompilationUtilities.timeStamp def loggingProcessLogger: ProcessLogger = BackendCompilationUtilities.loggingProcessLogger - def copyResourceToFile(name: String, file: File): Unit = BackendCompilationUtilities.copyResourceToFile(name, file) + def copyResourceToFile(name: String, file: File): Unit = BackendCompilationUtilities.copyResourceToFile(name, file) def createTestDirectory(testName: String): File = BackendCompilationUtilities.createTestDirectory(testName) - def makeHarness(template: String => String, post: String)(f: File): File = BackendCompilationUtilities.makeHarness(template, post)(f) - def firrtlToVerilog(prefix: String, dir: File): ProcessBuilder = BackendCompilationUtilities.firrtlToVerilog(prefix, dir) + def makeHarness(template: String => String, post: String)(f: File): File = + BackendCompilationUtilities.makeHarness(template, post)(f) + def firrtlToVerilog(prefix: String, dir: File): ProcessBuilder = + BackendCompilationUtilities.firrtlToVerilog(prefix, dir) def verilogToCpp( - dutFile: String, - dir: File, - vSources: Seq[File], - cppHarness: File, - suppressVcd: Boolean = false, - resourceFileName: String = firrtl.transforms.BlackBoxSourceHelper.defaultFileListName - ): ProcessBuilder = { + dutFile: String, + dir: File, + vSources: Seq[File], + cppHarness: File, + suppressVcd: Boolean = false, + resourceFileName: String = firrtl.transforms.BlackBoxSourceHelper.defaultFileListName + ): ProcessBuilder = { BackendCompilationUtilities.verilogToCpp(dutFile, dir, vSources, cppHarness, suppressVcd, resourceFileName) } def cppToExe(prefix: String, dir: File): ProcessBuilder = BackendCompilationUtilities.cppToExe(prefix, dir) def executeExpectingFailure( - prefix: String, - dir: File, - assertionMsg: String = ""): Boolean = { + prefix: String, + dir: File, + assertionMsg: String = "" + ): Boolean = { BackendCompilationUtilities.executeExpectingFailure(prefix, dir, assertionMsg) } - def executeExpectingSuccess(prefix: String, dir: File): Boolean = BackendCompilationUtilities.executeExpectingSuccess(prefix, dir) + def executeExpectingSuccess(prefix: String, dir: File): Boolean = + BackendCompilationUtilities.executeExpectingSuccess(prefix, dir) } diff --git a/src/main/scala/firrtl/util/ClassUtils.scala b/src/main/scala/firrtl/util/ClassUtils.scala index 1b388035..34ff60fc 100644 --- a/src/main/scala/firrtl/util/ClassUtils.scala +++ b/src/main/scala/firrtl/util/ClassUtils.scala @@ -1,18 +1,20 @@ package firrtl.util object ClassUtils { + /** Determine if a named class is loaded. * * @param name - name of the class: "foo.bar" or "org.foo.bar" * @return true if the class has been loaded (is accessible), false otherwise. */ def isClassLoaded(name: String): Boolean = { - val found = try { - Class.forName(name, false, getClass.getClassLoader) != null - } catch { - case e: ClassNotFoundException => false - case x: Throwable => throw x - } + val found = + try { + Class.forName(name, false, getClass.getClassLoader) != null + } catch { + case e: ClassNotFoundException => false + case x: Throwable => throw x + } // println(s"isClassLoaded: %s $name".format(if (found) "found" else "didn't find")) found } diff --git a/src/main/scala/logger/Logger.scala b/src/main/scala/logger/Logger.scala index 9cf645fa..e002db92 100644 --- a/src/main/scala/logger/Logger.scala +++ b/src/main/scala/logger/Logger.scala @@ -4,7 +4,7 @@ package logger import java.io.{ByteArrayOutputStream, File, FileOutputStream, PrintStream} -import firrtl.{ExecutionOptionsManager, AnnotationSeq} +import firrtl.{AnnotationSeq, ExecutionOptionsManager} import firrtl.options.Viewer.view import logger.phases.{AddDefaults, Checks} @@ -38,7 +38,7 @@ object LogLevel extends Enumeration { case "info" => LogLevel.Info case "debug" => LogLevel.Debug case "trace" => LogLevel.Trace - case level => throw new Exception(s"Unknown LogLevel '$level'") + case level => throw new Exception(s"Unknown LogLevel '$level'") } } @@ -58,8 +58,8 @@ private class LoggerState { val classLevels = new scala.collection.mutable.HashMap[String, LogLevel.Value] val classToLevelCache = new scala.collection.mutable.HashMap[String, LogLevel.Value] var logClassNames = false - var stream: PrintStream = System.out - var fromInvoke: Boolean = false // this is used to not have invokes re-create run-state + var stream: PrintStream = System.out + var fromInvoke: Boolean = false // this is used to not have invokes re-create run-state var stringBufferOption: Option[Logger.OutputCaptor] = None override def toString: String = { @@ -137,10 +137,9 @@ object Logger { @deprecated("Use makescope(opts: FirrtlOptions)", "1.2") def makeScope[A](args: Array[String] = Array.empty)(codeBlock: => A): A = { val executionOptionsManager = new ExecutionOptionsManager("logger") - if(executionOptionsManager.parse(args)) { + if (executionOptionsManager.parse(args)) { makeScope(executionOptionsManager)(codeBlock) - } - else { + } else { throw new Exception(s"logger invoke failed to parse args ${args.mkString(", ")}") } } @@ -154,10 +153,9 @@ object Logger { def makeScope[A](options: AnnotationSeq)(codeBlock: => A): A = { val runState: LoggerState = { val newRunState = updatableLoggerState.value.getOrElse(new LoggerState) - if(newRunState.fromInvoke) { + if (newRunState.fromInvoke) { newRunState - } - else { + } else { val forcedNewRunState = new LoggerState forcedNewRunState.fromInvoke = true forcedNewRunState @@ -179,39 +177,41 @@ object Logger { */ private def testPackageNameMatch(className: String, level: LogLevel.Value): Option[Boolean] = { val classLevels = state.classLevels - if(classLevels.isEmpty) return None + if (classLevels.isEmpty) return None // If this class name in cache just use that value - val levelForThisClassName = state.classToLevelCache.getOrElse(className, { - // otherwise break up the class name in to full package path as list and find most specific entry you can - val packageNameList = className.split("""\.""").toList - /* - * start with full class path, lopping off from the tail until nothing left - */ - def matchPathToFindLevel(packageList: List[String]): LogLevel.Value = { - if(packageList.isEmpty) { - LogLevel.None + val levelForThisClassName = state.classToLevelCache.getOrElse( + className, { + // otherwise break up the class name in to full package path as list and find most specific entry you can + val packageNameList = className.split("""\.""").toList + /* + * start with full class path, lopping off from the tail until nothing left + */ + def matchPathToFindLevel(packageList: List[String]): LogLevel.Value = { + if (packageList.isEmpty) { + LogLevel.None + } else { + val partialName = packageList.mkString(".") + val level = classLevels.getOrElse( + partialName, { + matchPathToFindLevel(packageList.reverse.tail.reverse) + } + ) + level + } } - else { - val partialName = packageList.mkString(".") - val level = classLevels.getOrElse(partialName, { - matchPathToFindLevel(packageList.reverse.tail.reverse) - }) - level - } - } - val levelSpecified = matchPathToFindLevel(packageNameList) - if(levelSpecified != LogLevel.None) { - state.classToLevelCache(className) = levelSpecified + val levelSpecified = matchPathToFindLevel(packageNameList) + if (levelSpecified != LogLevel.None) { + state.classToLevelCache(className) = levelSpecified + } + levelSpecified } - levelSpecified - }) + ) - if(levelForThisClassName != LogLevel.None) { + if (levelForThisClassName != LogLevel.None) { Some(levelForThisClassName >= level) - } - else { + } else { None } } @@ -226,19 +226,20 @@ object Logger { */ private def showMessage(level: LogLevel.Value, className: String, message: => String): Unit = { def logIt(): Unit = { - if(state.logClassNames) { + if (state.logClassNames) { state.stream.println(s"[$level:$className] $message") - } - else { + } else { state.stream.println(message) } } testPackageNameMatch(className, level) match { - case Some(true) => logIt() + case Some(true) => logIt() case Some(false) => case None => - if((state.globalLevel == LogLevel.None && level == LogLevel.Error) || - (state.globalLevel != LogLevel.None && state.globalLevel >= level)) { + if ( + (state.globalLevel == LogLevel.None && level == LogLevel.Error) || + (state.globalLevel != LogLevel.None && state.globalLevel >= level) + ) { logIt() } } @@ -247,6 +248,7 @@ object Logger { def getGlobalLevel: LogLevel.Value = { state.globalLevel } + /** * This resets everything in the current Logger environment, including the destination * use this with caution. Unexpected things can happen @@ -309,7 +311,7 @@ object Logger { def clearStringBuffer(): Unit = { state.stringBufferOption match { case Some(x) => x.byteArrayOutputStream.reset() - case None => + case None => } } @@ -360,16 +362,16 @@ object Logger { */ def setOptions(inputAnnotations: AnnotationSeq): Unit = { val annotations = - Seq( new AddDefaults, Checks ) - .foldLeft(inputAnnotations)((a, p) => p.transform(a)) + Seq(new AddDefaults, Checks) + .foldLeft(inputAnnotations)((a, p) => p.transform(a)) val lopts = view[LoggerOptions](annotations) state.globalLevel = (state.globalLevel, lopts.globalLogLevel) match { case (LogLevel.None, LogLevel.None) => LogLevel.None - case (x, LogLevel.None) => x - case (LogLevel.None, x) => x - case (_, x) => x - case _ => LogLevel.Error + case (x, LogLevel.None) => x + case (LogLevel.None, x) => x + case (_, x) => x + case _ => LogLevel.Error } setClassLogLevels(lopts.classLogLevels) @@ -386,6 +388,7 @@ object Logger { * @param containerClass passed in from the LazyLogging trait in order to provide class level logging granularity */ class Logger(containerClass: String) { + /** * Log message at Error level * @param message message generator to be invoked if level is right @@ -393,6 +396,7 @@ class Logger(containerClass: String) { def error(message: => String): Unit = { Logger.showMessage(LogLevel.Error, containerClass, message) } + /** * Log message at Warn level * @param message message generator to be invoked if level is right @@ -400,6 +404,7 @@ class Logger(containerClass: String) { def warn(message: => String): Unit = { Logger.showMessage(LogLevel.Warn, containerClass, message) } + /** * Log message at Inof level * @param message message generator to be invoked if level is right @@ -407,6 +412,7 @@ class Logger(containerClass: String) { def info(message: => String): Unit = { Logger.showMessage(LogLevel.Info, containerClass, message) } + /** * Log message at Debug level * @param message message generator to be invoked if level is right @@ -414,6 +420,7 @@ class Logger(containerClass: String) { def debug(message: => String): Unit = { Logger.showMessage(LogLevel.Debug, containerClass, message) } + /** * Log message at Trace level * @param message message generator to be invoked if level is right diff --git a/src/main/scala/logger/LoggerAnnotations.scala b/src/main/scala/logger/LoggerAnnotations.scala index f4dc6b38..b345d617 100644 --- a/src/main/scala/logger/LoggerAnnotations.scala +++ b/src/main/scala/logger/LoggerAnnotations.scala @@ -5,7 +5,6 @@ package logger import firrtl.annotations.{Annotation, NoTargetAnnotation} import firrtl.options.{HasShellOptions, ShellOption} - /** An annotation associated with a Logger command line option */ sealed trait LoggerOption { this: Annotation => } @@ -14,7 +13,9 @@ sealed trait LoggerOption { this: Annotation => } * - if unset, a [[LogLevelAnnotation]] with the default log level will be emitted * @param level the level of logging */ -case class LogLevelAnnotation(globalLogLevel: LogLevel.Value = LogLevel.Warn) extends NoTargetAnnotation with LoggerOption +case class LogLevelAnnotation(globalLogLevel: LogLevel.Value = LogLevel.Warn) + extends NoTargetAnnotation + with LoggerOption object LogLevelAnnotation extends HasShellOptions { @@ -24,7 +25,9 @@ object LogLevelAnnotation extends HasShellOptions { toAnnotationSeq = (a: String) => Seq(LogLevelAnnotation(LogLevel(a))), helpText = s"Set global logging verbosity (default: ${new LoggerOptions().globalLogLevel}", shortOption = Some("ll"), - helpValueName = Some("{error|warn|info|debug|trace}") ) ) + helpValueName = Some("{error|warn|info|debug|trace}") + ) + ) } @@ -33,20 +36,26 @@ object LogLevelAnnotation extends HasShellOptions { * @param name the class name to log * @param level the verbosity level */ -case class ClassLogLevelAnnotation(className: String, level: LogLevel.Value) extends NoTargetAnnotation with LoggerOption +case class ClassLogLevelAnnotation(className: String, level: LogLevel.Value) + extends NoTargetAnnotation + with LoggerOption object ClassLogLevelAnnotation extends HasShellOptions { val options = Seq( new ShellOption[Seq[String]]( longOption = "class-log-level", - toAnnotationSeq = (a: Seq[String]) => a.map { aa => - val className :: levelName :: _ = aa.split(":").toList - val level = LogLevel(levelName) - ClassLogLevelAnnotation(className, level) }, + toAnnotationSeq = (a: Seq[String]) => + a.map { aa => + val className :: levelName :: _ = aa.split(":").toList + val level = LogLevel(levelName) + ClassLogLevelAnnotation(className, level) + }, helpText = "Set per-class logging verbosity", shortOption = Some("cll"), - helpValueName = Some("<FullClassName:{error|warn|info|debug|trace}>...") ) ) + helpValueName = Some("<FullClassName:{error|warn|info|debug|trace}>...") + ) + ) } @@ -63,7 +72,9 @@ object LogFileAnnotation extends HasShellOptions { longOption = "log-file", toAnnotationSeq = (a: String) => Seq(LogFileAnnotation(Some(a))), helpText = "Log to a file instead of STDOUT", - helpValueName = Some("<file>") ) ) + helpValueName = Some("<file>") + ) + ) } @@ -77,6 +88,8 @@ case object LogClassNamesAnnotation extends NoTargetAnnotation with LoggerOption longOption = "log-class-names", toAnnotationSeq = (a: Unit) => Seq(LogClassNamesAnnotation), helpText = "Show class names and log level in logging output", - shortOption = Some("lcn") ) ) + shortOption = Some("lcn") + ) + ) } diff --git a/src/main/scala/logger/LoggerOptions.scala b/src/main/scala/logger/LoggerOptions.scala index 299382f0..6cc745b9 100644 --- a/src/main/scala/logger/LoggerOptions.scala +++ b/src/main/scala/logger/LoggerOptions.scala @@ -9,23 +9,25 @@ package logger * @param logToFile if true, log to a file * @param logClassNames indicates logging verbosity on a class-by-class basis */ -class LoggerOptions private [logger] ( - val globalLogLevel: LogLevel.Value = LogLevelAnnotation().globalLogLevel, +class LoggerOptions private[logger] ( + val globalLogLevel: LogLevel.Value = LogLevelAnnotation().globalLogLevel, val classLogLevels: Map[String, LogLevel.Value] = Map.empty, - val logClassNames: Boolean = false, - val logFileName: Option[String] = None) { + val logClassNames: Boolean = false, + val logFileName: Option[String] = None) { - private [logger] def copy( - globalLogLevel: LogLevel.Value = globalLogLevel, + private[logger] def copy( + globalLogLevel: LogLevel.Value = globalLogLevel, classLogLevels: Map[String, LogLevel.Value] = classLogLevels, - logClassNames: Boolean = logClassNames, - logFileName: Option[String] = logFileName): LoggerOptions = { + logClassNames: Boolean = logClassNames, + logFileName: Option[String] = logFileName + ): LoggerOptions = { new LoggerOptions( globalLogLevel = globalLogLevel, classLogLevels = classLogLevels, logClassNames = logClassNames, - logFileName = logFileName) + logFileName = logFileName + ) } diff --git a/src/main/scala/logger/phases/AddDefaults.scala b/src/main/scala/logger/phases/AddDefaults.scala index 660de579..ec673637 100644 --- a/src/main/scala/logger/phases/AddDefaults.scala +++ b/src/main/scala/logger/phases/AddDefaults.scala @@ -5,10 +5,10 @@ package logger.phases import firrtl.AnnotationSeq import firrtl.options.Phase -import logger.{LoggerOption, LogLevelAnnotation} +import logger.{LogLevelAnnotation, LoggerOption} /** Add default logger [[Annotation]]s */ -private [logger] class AddDefaults extends Phase { +private[logger] class AddDefaults extends Phase { override def prerequisites = Seq.empty override def optionalPrerequisiteOf = Seq.empty @@ -20,12 +20,12 @@ private [logger] class AddDefaults extends Phase { */ def transform(annotations: AnnotationSeq): AnnotationSeq = { var ll = true - annotations.collect{ case a: LoggerOption => a }.map{ + annotations.collect { case a: LoggerOption => a }.map { case _: LogLevelAnnotation => ll = false - case _ => + case _ => } annotations ++ - (if (ll) Seq(LogLevelAnnotation()) else Seq() ) + (if (ll) Seq(LogLevelAnnotation()) else Seq()) } } diff --git a/src/main/scala/logger/phases/Checks.scala b/src/main/scala/logger/phases/Checks.scala index e945fa98..0109c7ad 100644 --- a/src/main/scala/logger/phases/Checks.scala +++ b/src/main/scala/logger/phases/Checks.scala @@ -6,12 +6,13 @@ import firrtl.AnnotationSeq import firrtl.annotations.Annotation import firrtl.options.{Dependency, Phase} -import logger.{LogLevelAnnotation, LogFileAnnotation, LoggerException} +import logger.{LogFileAnnotation, LogLevelAnnotation, LoggerException} import scala.collection.mutable /** Check that an [[firrtl.AnnotationSeq AnnotationSeq]] has all necessary [[firrtl.annotations.Annotation Annotation]]s - * for a [[Logger]] */ + * for a [[Logger]] + */ object Checks extends Phase { override def prerequisites = Seq(Dependency[AddDefaults]) @@ -26,20 +27,22 @@ object Checks extends Phase { */ def transform(annotations: AnnotationSeq): AnnotationSeq = { val ll, lf = mutable.ListBuffer[Annotation]() - annotations.foreach( - _ match { - case a: LogLevelAnnotation => ll += a - case a: LogFileAnnotation => lf += a - case _ => }) + annotations.foreach(_ match { + case a: LogLevelAnnotation => ll += a + case a: LogFileAnnotation => lf += a + case _ => + }) if (ll.size > 1) { - val l = ll.map{ case LogLevelAnnotation(x) => x } + val l = ll.map { case LogLevelAnnotation(x) => x } throw new LoggerException( s"""|At most one log level can be specified, but found '${l.mkString(", ")}' specified via: - | - an option or annotation: -ll, --log-level, LogLevelAnnotation""".stripMargin )} + | - an option or annotation: -ll, --log-level, LogLevelAnnotation""".stripMargin + ) + } if (lf.size > 1) { - throw new LoggerException( - s"""|At most one log file can be specified, but found ${lf.size} combinations of: - | - an options or annotation: -ltf, --log-to-file, --log-file, LogFileAnnotation""".stripMargin )} + throw new LoggerException(s"""|At most one log file can be specified, but found ${lf.size} combinations of: + | - an options or annotation: -ltf, --log-to-file, --log-file, LogFileAnnotation""".stripMargin) + } annotations } diff --git a/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala index 48427af8..72f50461 100644 --- a/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala +++ b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala @@ -4,9 +4,9 @@ package tutorial package lesson1 // Compiler Infrastructure -import firrtl.{Transform, LowForm, CircuitState, Utils} +import firrtl.{CircuitState, LowForm, Transform, Utils} // Firrtl IR classes -import firrtl.ir.{DefModule, Statement, Expression, Mux} +import firrtl.ir.{DefModule, Expression, Mux, Statement} // Map functions import firrtl.Mappers._ // Scala's mutable collections @@ -26,11 +26,11 @@ class Ledger { private val modules = mutable.Set[String]() private val moduleMuxMap = mutable.Map[String, Int]() def foundMux(): Unit = moduleName match { - case None => sys.error("Module name not defined in Ledger!") + case None => sys.error("Module name not defined in Ledger!") case Some(name) => moduleMuxMap(name) = moduleMuxMap.getOrElse(name, 0) + 1 } def getModuleName: String = moduleName match { - case None => Utils.error("Module name not defined in Ledger!") + case None => Utils.error("Module name not defined in Ledger!") case Some(name) => name } def setModuleName(myName: String): Unit = { @@ -38,9 +38,9 @@ class Ledger { moduleName = Some(myName) } def serialize: String = { - modules map { myName => + modules.map { myName => s"$myName => ${moduleMuxMap.getOrElse(myName, 0)} muxes!" - } mkString "\n" + }.mkString("\n") } } @@ -68,8 +68,10 @@ class Ledger { * - https://github.com/ucb-bar/firrtl/wiki/Common-Pass-Idioms */ class AnalyzeCircuit extends Transform { + /** Requires the [[firrtl.ir.Circuit Circuit]] form to be "low" */ def inputForm = LowForm + /** Indicates the output [[firrtl.ir.Circuit Circuit]] form to be "low" */ def outputForm = LowForm @@ -88,7 +90,7 @@ class AnalyzeCircuit extends Transform { * - "map" - classic functional programming concept * - discard the returned new [[firrtl.ir.Circuit Circuit]] because circuit is unmodified */ - circuit map walkModule(ledger) + circuit.map(walkModule(ledger)) // Print our ledger println(ledger.serialize) @@ -106,7 +108,7 @@ class AnalyzeCircuit extends Transform { * - return the new [[firrtl.ir.DefModule DefModule]] (in this case, its identical to m) * - if m does not contain [[firrtl.ir.Statement Statement]], map returns m. */ - m map walkStatement(ledger) + m.map(walkStatement(ledger)) } /** Deeply visits every [[firrtl.ir.Statement Statement]] and [[firrtl.ir.Expression Expression]] in s. */ @@ -116,13 +118,13 @@ class AnalyzeCircuit extends Transform { * - discard the new [[firrtl.ir.Statement Statement]] (in this case, its identical to s) * - if s does not contain [[firrtl.ir.Expression Expression]], map returns s. */ - s map walkExpression(ledger) + s.map(walkExpression(ledger)) /* Execute the function walkStatement(ledger) on every [[firrtl.ir.Statement Statement]] in s. * - return the new [[firrtl.ir.Statement Statement]] (in this case, its identical to s) * - if s does not contain [[firrtl.ir.Statement Statement]], map returns s. */ - s map walkStatement(ledger) + s.map(walkStatement(ledger)) } /** Deeply visits every [[firrtl.ir.Expression Expression]] in e. @@ -135,7 +137,7 @@ class AnalyzeCircuit extends Transform { * - return the new [[firrtl.ir.Expression Expression]] (in this case, its identical to e) * - if s does not contain [[firrtl.ir.Expression Expression]], map returns e. */ - val visited = e map walkExpression(ledger) + val visited = e.map(walkExpression(ledger)) visited match { // If e is a [[firrtl.ir.Mux Mux]], increment our ledger and return e. diff --git a/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala b/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala index 523be723..11b4519c 100644 --- a/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala +++ b/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala @@ -4,9 +4,9 @@ package tutorial package lesson2 // Compiler Infrastructure -import firrtl.{Transform, LowForm, CircuitState} +import firrtl.{CircuitState, LowForm, Transform} // Firrtl IR classes -import firrtl.ir.{DefModule, Statement, Expression, Mux, DefInstance} +import firrtl.ir.{DefInstance, DefModule, Expression, Mux, Statement} // Map functions import firrtl.Mappers._ // Scala's mutable collections @@ -27,7 +27,7 @@ class Ledger { private val moduleMuxMap = mutable.Map[String, Int]() private val moduleInstanceMap = mutable.Map[String, Seq[String]]() def getModuleName: String = moduleName match { - case None => sys.error("Module name not defined in Ledger!") + case None => sys.error("Module name not defined in Ledger!") case Some(name) => name } def setModuleName(myName: String): Unit = { @@ -47,14 +47,14 @@ class Ledger { private def countMux(myName: String): Int = { val myMuxes = moduleMuxMap.getOrElse(myName, 0) val myInstanceMuxes = - moduleInstanceMap.getOrElse(myName, Nil).foldLeft(0) { - (total, name) => total + countMux(name) + moduleInstanceMap.getOrElse(myName, Nil).foldLeft(0) { (total, name) => + total + countMux(name) } myMuxes + myInstanceMuxes } // Display recursive total of muxes def serialize: String = { - modules map { myName => s"$myName => ${countMux(myName)} muxes!" } mkString "\n" + modules.map { myName => s"$myName => ${countMux(myName)} muxes!" }.mkString("\n") } } @@ -76,7 +76,6 @@ class Ledger { * - Kind -> ExpKind * - Flow -> UnknownFlow * - Type -> UnknownType - * */ class AnalyzeCircuit extends Transform { def inputForm = LowForm @@ -88,7 +87,7 @@ class AnalyzeCircuit extends Transform { val circuit = state.circuit // Execute the function walkModule(ledger) on all [[DefModule]] in circuit - circuit map walkModule(ledger) + circuit.map(walkModule(ledger)) // Print our ledger println(ledger.serialize) @@ -103,13 +102,13 @@ class AnalyzeCircuit extends Transform { ledger.setModuleName(m.name) // Execute the function walkStatement(ledger) on every [[Statement]] in m. - m map walkStatement(ledger) + m.map(walkStatement(ledger)) } // Deeply visits every [[Statement]] and [[Expression]] in s. def walkStatement(ledger: Ledger)(s: Statement): Statement = { // Map the functions walkStatement(ledger) and walkExpression(ledger) - val visited = s map walkStatement(ledger) map walkExpression(ledger) + val visited = s.map(walkStatement(ledger)).map(walkExpression(ledger)) visited match { case DefInstance(info, name, module, tpe) => ledger.foundInstance(module) @@ -122,7 +121,7 @@ class AnalyzeCircuit extends Transform { def walkExpression(ledger: Ledger)(e: Expression): Expression = { // Execute the function walkExpression(ledger) on every [[Expression]] in e, // then handle if a [[Mux]]. - e map walkExpression(ledger) match { + e.map(walkExpression(ledger)) match { case mux: Mux => ledger.foundMux() mux diff --git a/src/test/scala/firrtl/JsonProtocolSpec.scala b/src/test/scala/firrtl/JsonProtocolSpec.scala index 7d04e9fc..cc7591cb 100644 --- a/src/test/scala/firrtl/JsonProtocolSpec.scala +++ b/src/test/scala/firrtl/JsonProtocolSpec.scala @@ -4,7 +4,13 @@ package firrtlTests import org.json4s._ -import firrtl.annotations.{NoTargetAnnotation, JsonProtocol, InvalidAnnotationJSONException, HasSerializationHints, Annotation} +import firrtl.annotations.{ + Annotation, + HasSerializationHints, + InvalidAnnotationJSONException, + JsonProtocol, + NoTargetAnnotation +} import org.scalatest.flatspec.AnyFlatSpec object JsonProtocolTestClasses { @@ -13,12 +19,16 @@ object JsonProtocolTestClasses { case class ChildA(foo: Int) extends Parent case class ChildB(bar: String) extends Parent case class PolymorphicParameterAnnotation(param: Parent) extends NoTargetAnnotation - case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) extends NoTargetAnnotation with HasSerializationHints { + case class PolymorphicParameterAnnotationWithTypeHints(param: Parent) + extends NoTargetAnnotation + with HasSerializationHints { def typeHints = Seq(param.getClass) } case class TypeParameterizedAnnotation[T](param: T) extends NoTargetAnnotation - case class TypeParameterizedAnnotationWithTypeHints[T](param: T) extends NoTargetAnnotation with HasSerializationHints { + case class TypeParameterizedAnnotationWithTypeHints[T](param: T) + extends NoTargetAnnotation + with HasSerializationHints { def typeHints = Seq(param.getClass) } } @@ -51,11 +61,11 @@ class JsonProtocolSpec extends AnyFlatSpec { "Annotations with non-primitive type parameters" should "not serialize and deserialize without type hints" in { val anno = TypeParameterizedAnnotation(ChildA(1)) val deserAnno = serializeAndDeserialize(anno) - assert (anno != deserAnno) + assert(anno != deserAnno) } it should "serialize and deserialize with type hints" in { val anno = TypeParameterizedAnnotationWithTypeHints(ChildA(1)) val deserAnno = serializeAndDeserialize(anno) - assert (anno == deserAnno) + assert(anno == deserAnno) } } diff --git a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala index 599b4e52..ca30b60b 100644 --- a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala +++ b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala @@ -8,7 +8,7 @@ import firrtl.options.Dependency import org.scalatest.flatspec.AnyFlatSpec class SymbolTableSpec extends AnyFlatSpec { - behavior of "SymbolTable" + behavior.of("SymbolTable") private val src = """circuit m: @@ -50,9 +50,20 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms("r").tpe == ir.SIntType(ir.IntWidth(4)) && syms("r").kind == firrtl.RegKind) val mType = firrtl.passes.MemPortUtils.memType( // only dataType, depth and reader, writer, readwriter properties affect the data type - ir.DefMemory(ir.NoInfo, "???", ir.UIntType(ir.IntWidth(8)), 32, 10, 10, Seq("r"), Seq(), Seq(), ir.ReadUnderWrite.New) + ir.DefMemory( + ir.NoInfo, + "???", + ir.UIntType(ir.IntWidth(8)), + 32, + 10, + 10, + Seq("r"), + Seq(), + Seq(), + ir.ReadUnderWrite.New + ) ) - assert(syms("m") .tpe == mType && syms("m").kind == firrtl.MemKind) + assert(syms("m").tpe == mType && syms("m").kind == firrtl.MemKind) } it should "find all declarations in module m after InferTypes" in { @@ -69,7 +80,7 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms("i").tpe == iType && syms("i").kind == firrtl.InstanceKind) } - behavior of "WithSeq" + behavior.of("WithSeq") it should "preserve declaration order" in { val c = firrtl.Parser.parse(src) @@ -79,7 +90,7 @@ class SymbolTableSpec extends AnyFlatSpec { assert(syms.getSymbols.map(_.name) == Seq("clk", "x", "y", "z", "a", "i", "r", "m")) } - behavior of "ModuleTypesSymbolTable" + behavior.of("ModuleTypesSymbolTable") it should "derive the module type from the module types map" in { val c = firrtl.Parser.parse(src) diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala index 015ac4a9..f7ce9914 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala @@ -20,10 +20,16 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { sys.signals.head.e.toString } - def primop(signed: Boolean, op: String, resWidth: Int, inWidth: Seq[Int], consts: Seq[Int] = List(), - resAlwaysUnsigned: Boolean = false): String = { - val tpe = if(signed) "SInt" else "UInt" - val resTpe = if(resAlwaysUnsigned) "UInt" else tpe + def primop( + signed: Boolean, + op: String, + resWidth: Int, + inWidth: Seq[Int], + consts: Seq[Int] = List(), + resAlwaysUnsigned: Boolean = false + ): String = { + val tpe = if (signed) "SInt" else "UInt" + val resTpe = if (resAlwaysUnsigned) "UInt" else tpe val inTpes = inWidth.map(w => s"$tpe<$w>") primop(op, s"$resTpe<$resWidth>", inTpes, consts) } @@ -52,16 +58,24 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { it should "correctly translate the `div` primitive operation" in { // division is a little bit more complicated because the result of division by zero is undefined - assert(primop(false, "div", 8, List(8, 8)) == - "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))") - assert(primop(false, "div", 8, List(8, 4)) == - "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))") + assert( + primop(false, "div", 8, List(8, 8)) == + "ite(eq(i1, 8'b0), RANDOM.res, udiv(i0, i1))" + ) + assert( + primop(false, "div", 8, List(8, 4)) == + "ite(eq(i1, 4'b0), RANDOM.res, udiv(i0, zext(i1, 4)))" + ) // signed division increases result width by 1 - assert(primop(true, "div", 8, List(7, 7)) == - "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))") - assert(primop(true, "div", 8, List(7, 4)) - == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))") + assert( + primop(true, "div", 8, List(7, 7)) == + "ite(eq(i1, 7'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 1)))" + ) + assert( + primop(true, "div", 8, List(7, 4)) + == "ite(eq(i1, 4'b0), RANDOM.res, sdiv(sext(i0, 1), sext(i1, 4)))" + ) } it should "correctly translate the `rem` primitive operation" in { @@ -134,15 +148,19 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { it should "correctly translate the `dshl` primitive operation" in { assert(primop(false, "dshl", 31, List(16, 4)) == "logical_shift_left(zext(i0, 15), zext(i1, 27))") assert(primop(false, "dshl", 19, List(16, 2)) == "logical_shift_left(zext(i0, 3), zext(i1, 17))") - assert(primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == - "logical_shift_left(sext(i0, 3), zext(i1, 17))") + assert( + primop("dshl", "SInt<19>", List("SInt<16>", "UInt<2>"), List()) == + "logical_shift_left(sext(i0, 3), zext(i1, 17))" + ) } it should "correctly translate the `dshr` primitive operation" in { assert(primop(false, "dshr", 16, List(16, 4)) == "logical_shift_right(i0, zext(i1, 12))") assert(primop(false, "dshr", 16, List(16, 2)) == "logical_shift_right(i0, zext(i1, 14))") - assert(primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == - "arithmetic_shift_right(i0, zext(i1, 14))") + assert( + primop("dshr", "SInt<16>", List("SInt<16>", "UInt<2>"), List()) == + "arithmetic_shift_right(i0, zext(i1, 14))" + ) } it should "correctly translate the `cvt` primitive operation" in { @@ -197,15 +215,15 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { } it should "correctly translate the `bits` primitive operation" in { - assert(primop(false, "bits", 1, List(4), List(2,2)) == "i0[2]") - assert(primop(false, "bits", 2, List(4), List(2,1)) == "i0[2:1]") - assert(primop(false, "bits", 1, List(4), List(2,1)) == "i0[2:1][0]") - assert(primop(false, "bits", 3, List(4), List(2,1)) == "zext(i0[2:1], 1)") + assert(primop(false, "bits", 1, List(4), List(2, 2)) == "i0[2]") + assert(primop(false, "bits", 2, List(4), List(2, 1)) == "i0[2:1]") + assert(primop(false, "bits", 1, List(4), List(2, 1)) == "i0[2:1][0]") + assert(primop(false, "bits", 3, List(4), List(2, 1)) == "zext(i0[2:1], 1)") - assert(primop(true, "bits", 1, List(4), List(2,2), resAlwaysUnsigned = true) == "i0[2]") - assert(primop(true, "bits", 2, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1]") - assert(primop(true, "bits", 1, List(4), List(2,1), resAlwaysUnsigned = true) == "i0[2:1][0]") - assert(primop(true, "bits", 3, List(4), List(2,1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") + assert(primop(true, "bits", 1, List(4), List(2, 2), resAlwaysUnsigned = true) == "i0[2]") + assert(primop(true, "bits", 2, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1]") + assert(primop(true, "bits", 1, List(4), List(2, 1), resAlwaysUnsigned = true) == "i0[2:1][0]") + assert(primop(true, "bits", 3, List(4), List(2, 1), resAlwaysUnsigned = true) == "zext(i0[2:1], 1)") } it should "correctly translate the `head` primitive operation" in { @@ -221,4 +239,4 @@ private class FirrtlExpressionSemanticsSpec extends SMTBackendBaseSpec { assert(primop(false, "tail", 4, List(5), List(1)) == "i0[3:0]") assert(primop(false, "tail", 2, List(5), List(3)) == "i0[1:0]") } -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala index ca7974c5..b41313e3 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlModuleToTransitionSystemSpec.scala @@ -5,8 +5,7 @@ package firrtl.backends.experimental.smt import firrtl.{MemoryArrayInit, MemoryScalarInit, Utils} private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { - behavior of "ModuleToTransitionSystem.run" - + behavior.of("ModuleToTransitionSystem.run") it should "model registers as state" in { // if a signal is invalid, it could take on an arbitrary value in that cycle @@ -42,39 +41,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { private def memCircuit(depth: Int = 32) = s"""circuit m: - | module m: - | input reset : UInt<1> - | input clock : Clock - | input addr : UInt<${Utils.getUIntWidth(depth)}> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => $depth - | reader => r - | writer => w - | read-latency => 0 - | write-latency => 1 - | read-under-write => new - | - | m.w.clk <= clock - | m.w.mask <= UInt(1) - | m.w.en <= UInt(1) - | m.w.data <= in - | m.w.addr <= addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= addr - | - |""".stripMargin + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<${Utils.getUIntWidth(depth)}> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => $depth + | reader => r + | writer => w + | read-latency => 0 + | write-latency => 1 + | read-under-write => new + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin it should "model memories as state" in { val sys = toSys(memCircuit()) - assert(sys.signals.length == 9-2+1, "9 connects - 2 clock connects + 1 combinatorial read port") + assert(sys.signals.length == 9 - 2 + 1, "9 connects - 2 clock connects + 1 combinatorial read port") val sig = sys.signals.map(s => s.name -> s.e).toMap @@ -140,40 +139,39 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { it should "support memories with registered read port" in { def src(readUnderWrite: String) = s"""circuit m: - | module m: - | input reset : UInt<1> - | input clock : Clock - | input addr : UInt<5> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => 32 - | reader => r - | writer => w1, w2 - | read-latency => 1 - | write-latency => 1 - | read-under-write => $readUnderWrite - | - | m.w1.clk <= clock - | m.w1.mask <= UInt(1) - | m.w1.en <= UInt(1) - | m.w1.data <= in - | m.w1.addr <= addr - | m.w2.clk <= clock - | m.w2.mask <= UInt(1) - | m.w2.en <= UInt(1) - | m.w2.data <= in - | m.w2.addr <= addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= addr - | - |""".stripMargin - + | module m: + | input reset : UInt<1> + | input clock : Clock + | input addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w1, w2 + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w1.clk <= clock + | m.w1.mask <= UInt(1) + | m.w1.en <= UInt(1) + | m.w1.data <= in + | m.w1.addr <= addr + | m.w2.clk <= clock + | m.w2.mask <= UInt(1) + | m.w2.en <= UInt(1) + | m.w2.data <= in + | m.w2.addr <= addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= addr + | + |""".stripMargin val oldValue = toSys(src("old")) val oldMData = oldValue.states.find(_.sym.name == "m.r.data").get @@ -186,9 +184,11 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { val undefinedMData = undefinedValue.states.find(_.sym.name == "m.r.data").get assert(undefinedMData.sym.toString == "m.r.data") val undefined = "RANDOM.m_r_read_under_write_undefined" - assert(undefinedMData.next.get.toString == - s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", - "randomize result if there is a write") + assert( + undefinedMData.next.get.toString == + s"ite(or(eq(m.r.addr, m.w1.addr), eq(m.r.addr, m.w2.addr)), $undefined, m[m.r.addr])", + "randomize result if there is a write" + ) } it should "support memories with potential write-write conflicts" in { @@ -228,7 +228,6 @@ private class FirrtlModuleToTransitionSystemSpec extends SMTBackendBaseSpec { | |""".stripMargin - val sys = toSys(src) val m = sys.states.find(_.sym.name == "m").get diff --git a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala index 6bfb5437..209279fd 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/SMTBackendBaseSpec.scala @@ -3,7 +3,7 @@ package firrtl.backends.experimental.smt import firrtl.annotations.Annotation -import firrtl.{MemoryInitValue, ir} +import firrtl.{ir, MemoryInitValue} import firrtl.stage.{Forms, TransformManager} import org.scalatest.flatspec.AnyFlatSpec @@ -16,8 +16,12 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec { compiler.runTransform(firrtl.CircuitState(c, annos)).circuit } - protected def toSys(src: String, mod: String = "m", presetRegs: Set[String] = Set(), - memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = { + protected def toSys( + src: String, + mod: String = "m", + presetRegs: Set[String] = Set(), + memInit: Map[String, MemoryInitValue] = Map() + ): TransitionSystem = { val circuit = compile(src) val module = circuit.modules.find(_.name == mod).get.asInstanceOf[ir.Module] // println(module.serialize) @@ -35,4 +39,4 @@ private abstract class SMTBackendBaseSpec extends AnyFlatSpec { protected def toSMTLibStr(src: String, mod: String = "m"): String = toSMTLib(src, mod).mkString("\n") + "\n" -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala index 4c6901ea..e7c8d534 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/AsyncResetSpec.scala @@ -10,9 +10,10 @@ import firrtl.stage.RunFirrtlTransformAnnotation class AsyncResetSpec extends EndToEndSMTBaseSpec { def annos(name: String) = Seq( RunFirrtlTransformAnnotation(Dependency[StutteringClockTransform]), - GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock"))) + GlobalClockAnnotation(CircuitTarget(name).module(name).ref("global_clock")) + ) - "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs(RequiresZ3) in { + "a module with asynchronous reset" should "allow a register to change between clock edges" taggedAs (RequiresZ3) in { def in(resetType: String) = s"""circuit AsyncReset00: | module AsyncReset00: @@ -39,8 +40,8 @@ class AsyncResetSpec extends EndToEndSMTBaseSpec { | ; can the value of r change without the count changing? | assert(global_clock, or(not(eq(count, past_count)), eq(r, past_r)), past_valid, "count = past(count) |-> r = past(r)") |""".stripMargin - test(in("AsyncReset"), MCFail(1), kmax=2, annos=annos("AsyncReset00")) - test(in("UInt<1>"), MCSuccess, kmax=2, annos=annos("AsyncReset00")) + test(in("AsyncReset"), MCFail(1), kmax = 2, annos = annos("AsyncReset00")) + test(in("UInt<1>"), MCSuccess, kmax = 2, annos = annos("AsyncReset00")) } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala index 2227719b..974d2e81 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/EndToEndSMTSpec.scala @@ -16,24 +16,23 @@ import org.scalatest.matchers.must.Matchers import scala.sys.process._ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { - "we" should "check if Z3 is available" taggedAs(RequiresZ3) in { + "we" should "check if Z3 is available" taggedAs (RequiresZ3) in { val log = ProcessLogger(_ => (), logger.warn(_)) val ret = Process(Seq("which", "z3")).run(log).exitValue() - if(ret != 0) { - logger.error( - """The z3 SMT-Solver seems not to be installed. - |You can exclude the end-to-end smt backend tests which rely on z3 like this: - |sbt testOnly -- -l RequiresZ3 - |""".stripMargin) + if (ret != 0) { + logger.error("""The z3 SMT-Solver seems not to be installed. + |You can exclude the end-to-end smt backend tests which rely on z3 like this: + |sbt testOnly -- -l RequiresZ3 + |""".stripMargin) } assert(ret == 0) } - "Z3" should "be available in version 4" taggedAs(RequiresZ3) in { + "Z3" should "be available in version 4" taggedAs (RequiresZ3) in { assert(Z3ModelChecker.getZ3Version.startsWith("4.")) } - "a simple combinatorial check" should "pass" taggedAs(RequiresZ3) in { + "a simple combinatorial check" should "pass" taggedAs (RequiresZ3) in { val in = """circuit CC00: | module CC00: @@ -45,7 +44,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in, MCSuccess) } - "a simple combinatorial check" should "fail immediately" taggedAs(RequiresZ3) in { + "a simple combinatorial check" should "fail immediately" taggedAs (RequiresZ3) in { val in = """circuit CC01: | module CC01: @@ -57,7 +56,7 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in, MCFail(0)) } - "adding the right assumption" should "make a test pass" taggedAs(RequiresZ3) in { + "adding the right assumption" should "make a test pass" taggedAs (RequiresZ3) in { val in0 = """circuit CC01: | module CC01: @@ -75,8 +74,8 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { | assert(c, neq(add(a, b), UInt(0)), UInt(1), "a + b != 0") | assume(c, neq(a, UInt(0)), UInt(1), "a != 0") |""".stripMargin - test(in0, MCFail(0)) - test(in1, MCSuccess) + test(in0, MCFail(0)) + test(in1, MCSuccess) val in2 = """circuit CC01: @@ -91,20 +90,20 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { test(in2, MCFail(0)) } - "a register connected to preset reset" should "be initialized with the reset value" taggedAs(RequiresZ3) in { + "a register connected to preset reset" should "be initialized with the reset value" taggedAs (RequiresZ3) in { def in(rEq: Int) = s"""circuit Preset00: - | module Preset00: - | input c: Clock - | input preset: AsyncReset - | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) - | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") - |""".stripMargin + | module Preset00: + | input c: Clock + | input preset: AsyncReset + | reg r: UInt<4>, c with: (reset => (preset, UInt(3))) + | assert(c, eq(r, UInt($rEq)), UInt(1), "r = $rEq") + |""".stripMargin test(in(3), MCSuccess, kmax = 1) test(in(2), MCFail(0)) } - "a register's initial value" should "should not change" taggedAs(RequiresZ3) in { + "a register's initial value" should "should not change" taggedAs (RequiresZ3) in { val in = """circuit Preset00: | module Preset00: @@ -127,24 +126,29 @@ class EndToEndSMTSpec extends EndToEndSMTBaseSpec with LazyLogging { abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { def test(src: String, expected: MCResult, kmax: Int = 0, clue: String = "", annos: Seq[Annotation] = Seq()): Unit = { expected match { - case MCFail(k) => assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") + case MCFail(k) => + assert(kmax >= k, s"Please set a kmax that includes the expected failing step! ($kmax < $expected)") case _ => } val fir = firrtl.Parser.parse(src) val name = fir.main val testDir = BackendCompilationUtilities.createTestDirectory("EndToEndSMT." + name) // we automagically add a preset annotation if an input called preset exists - val presetAnno = if(!src.contains("input preset")) { None } else { + val presetAnno = if (!src.contains("input preset")) { None } + else { Some(PresetAnnotation(CircuitTarget(name).module(name).ref("preset"))) } - val res = (new FirrtlStage).execute(Array(), Seq( - LogLevelAnnotation(LogLevel.Error), // silence warnings for tests - RunFirrtlTransformAnnotation(new SMTLibEmitter), - RunFirrtlTransformAnnotation(new Btor2Emitter), - FirrtlCircuitAnnotation(fir), - TargetDirAnnotation(testDir.getAbsolutePath) - ) ++ presetAnno ++ annos) - assert(res.collectFirst{ case _: OutputFileAnnotation => true }.isDefined) + val res = (new FirrtlStage).execute( + Array(), + Seq( + LogLevelAnnotation(LogLevel.Error), // silence warnings for tests + RunFirrtlTransformAnnotation(new SMTLibEmitter), + RunFirrtlTransformAnnotation(new Btor2Emitter), + FirrtlCircuitAnnotation(fir), + TargetDirAnnotation(testDir.getAbsolutePath) + ) ++ presetAnno ++ annos + ) + assert(res.collectFirst { case _: OutputFileAnnotation => true }.isDefined) val r = Z3ModelChecker.bmc(testDir, name, kmax) assert(r == expected, clue + "\n" + s"$testDir") } @@ -153,7 +157,7 @@ abstract class EndToEndSMTBaseSpec extends AnyFlatSpec with Matchers { /** Minimal implementation of a Z3 based bounded model checker. * A more complete version of this with better use feedback should eventually be provided by a * chisel3 formal verification library. Do not use this implementation outside of the firrtl test suite! - * */ + */ private object Z3ModelChecker extends LazyLogging { def getZ3Version: String = { val (out, ret) = executeCmd("-version") @@ -164,14 +168,15 @@ private object Z3ModelChecker extends LazyLogging { } def bmc(testDir: File, main: String, kmax: Int): MCResult = { - assert(kmax >=0 && kmax < 50, "Trying to keep kmax in a reasonable range.") + assert(kmax >= 0 && kmax < 50, "Trying to keep kmax in a reasonable range.") val smtFile = new File(testDir, main + ".smt2") val header = read(smtFile) val steps = (0 to kmax).map(k => new File(testDir, main + s"_step$k.smt2")).zipWithIndex - steps.foreach { case (f,k) => - writeStep(f, main, header, k) - val success = executeStep(f.getAbsolutePath) - if(!success) return MCFail(k) + steps.foreach { + case (f, k) => + writeStep(f, main, header, k) + val success = executeStep(f.getAbsolutePath) + if (!success) return MCFail(k) } MCSuccess } @@ -200,21 +205,22 @@ private object Z3ModelChecker extends LazyLogging { private def step(main: String, k: Int): Iterable[String] = { // define all states (0 to k).map(ii => s"(declare-fun s$ii () $main$StateTpe)") ++ - // assert that init holds in state 0 - List(s"(assert ($main$Init s0))") ++ - // assert transition relation - (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii+1}))") ++ - // assert that assumptions hold in all states - (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ - // assert that assertions hold for all but last state - (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ - // check to see if we can violate the assertions in the last state - List(s"(assert (not ($main$Asserts s$k)))") + // assert that init holds in state 0 + List(s"(assert ($main$Init s0))") ++ + // assert transition relation + (0 until k).map(ii => s"(assert ($main$Transition s$ii s${ii + 1}))") ++ + // assert that assumptions hold in all states + (0 to k).map(ii => s"(assert ($main$Assumes s$ii))") ++ + // assert that assertions hold for all but last state + (0 until k).map(ii => s"(assert ($main$Asserts s$ii))") ++ + // check to see if we can violate the assertions in the last state + List(s"(assert (not ($main$Asserts s$k)))") } private def read(f: File): Iterable[String] = { val source = scala.io.Source.fromFile(f) - try source.getLines().toVector finally source.close() + try source.getLines().toVector + finally source.close() } // the following suffixes have to match the ones in [[SMTTransitionSystemEncoder]] diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala index 10de9cda..61e1f0f8 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/MemorySpec.scala @@ -9,43 +9,43 @@ class MemorySpec extends EndToEndSMTBaseSpec { registeredTestMem(name, cmds.split("\n"), readUnderWrite) private def registeredTestMem(name: String, cmds: Iterable[String], readUnderWrite: String): String = s"""circuit $name: - | module $name: - | input reset : UInt<1> - | input clock : Clock - | input preset: AsyncReset - | input write_addr : UInt<5> - | input read_addr : UInt<5> - | input in : UInt<8> - | output out : UInt<8> - | - | mem m: - | data-type => UInt<8> - | depth => 32 - | reader => r - | writer => w - | read-latency => 1 - | write-latency => 1 - | read-under-write => $readUnderWrite - | - | m.w.clk <= clock - | m.w.mask <= UInt(1) - | m.w.en <= UInt(1) - | m.w.data <= in - | m.w.addr <= write_addr - | - | m.r.clk <= clock - | m.r.en <= UInt(1) - | out <= m.r.data - | m.r.addr <= read_addr - | - | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) - | cycle <= add(cycle, UInt(1)) - | node past_valid = geq(cycle, UInt(1)) - | - | ${cmds.mkString("\n ")} - |""".stripMargin + | module $name: + | input reset : UInt<1> + | input clock : Clock + | input preset: AsyncReset + | input write_addr : UInt<5> + | input read_addr : UInt<5> + | input in : UInt<8> + | output out : UInt<8> + | + | mem m: + | data-type => UInt<8> + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => $readUnderWrite + | + | m.w.clk <= clock + | m.w.mask <= UInt(1) + | m.w.en <= UInt(1) + | m.w.data <= in + | m.w.addr <= write_addr + | + | m.r.clk <= clock + | m.r.en <= UInt(1) + | out <= m.r.data + | m.r.addr <= read_addr + | + | reg cycle: UInt<8>, clock with: (reset => (preset, UInt(0))) + | cycle <= add(cycle, UInt(1)) + | node past_valid = geq(cycle, UInt(1)) + | + | ${cmds.mkString("\n ")} + |""".stripMargin - "Registered test memory" should "return written data after two cycles" taggedAs(RequiresZ3) in { + "Registered test memory" should "return written data after two cycles" taggedAs (RequiresZ3) in { val cmds = """node past_past_valid = geq(cycle, UInt(2)) |reg past_in: UInt<8>, clock @@ -85,23 +85,29 @@ class MemorySpec extends EndToEndSMTBaseSpec { |""".stripMargin private def m(num: Int) = CircuitTarget(s"Mem0$num").module(s"Mem0$num").ref("m") - "read-only memory" should "always return 0" taggedAs(RequiresZ3) in { - test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax=2, - annos=Seq(MemoryScalarInitAnnotation(m(1), 0))) + "read-only memory" should "always return 0" taggedAs (RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(0))", 1), MCSuccess, kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(1), 0))) } - "read-only memory" should "not always return 1" taggedAs(RequiresZ3) in { - test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax=2, - annos=Seq(MemoryScalarInitAnnotation(m(2), 0))) + "read-only memory" should "not always return 1" taggedAs (RequiresZ3) in { + test(readOnlyMem("eq(out, UInt(1))", 2), MCFail(0), kmax = 2, annos = Seq(MemoryScalarInitAnnotation(m(2), 0))) } - "read-only memory" should "always return 1 or 2" taggedAs(RequiresZ3) in { - test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), MCSuccess, kmax=2, - annos=Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1)))) + "read-only memory" should "always return 1 or 2" taggedAs (RequiresZ3) in { + test( + readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 3), + MCSuccess, + kmax = 2, + annos = Seq(MemoryArrayInitAnnotation(m(3), Seq(1, 2, 2, 1))) + ) } - "read-only memory" should "not always return 1 or 2 or 3" taggedAs(RequiresZ3) in { - test(readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), MCFail(0), kmax=2, - annos=Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3)))) + "read-only memory" should "not always return 1 or 2 or 3" taggedAs (RequiresZ3) in { + test( + readOnlyMem("or(eq(out, UInt(1)), eq(out, UInt(2)))", 4), + MCFail(0), + kmax = 2, + annos = Seq(MemoryArrayInitAnnotation(m(4), Seq(1, 2, 2, 3))) + ) } } diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala index cbf194dd..8ece0e23 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/SMTCompilationTest.scala @@ -13,15 +13,17 @@ import scala.sys.process.{Process, ProcessLogger} /** compiles the regression tests to SMTLib and parses the result with z3 */ class SMTCompilationTest extends AnyFlatSpec with LazyLogging { - it should "generate valid SMTLib for AddNot" taggedAs(RequiresZ3) in { compileAndParse("AddNot") } - it should "generate valid SMTLib for FPU" taggedAs(RequiresZ3) in { compileAndParse("FPU") } + it should "generate valid SMTLib for AddNot" taggedAs (RequiresZ3) in { compileAndParse("AddNot") } + it should "generate valid SMTLib for FPU" taggedAs (RequiresZ3) in { compileAndParse("FPU") } // we get a stack overflow in Scala 2.11 because of a deeply nested and(...) expression in the sequencer - it should "generate valid SMTLib for HwachaSequencer" taggedAs(RequiresZ3) ignore { compileAndParse("HwachaSequencer") } - it should "generate valid SMTLib for ICache" taggedAs(RequiresZ3) in { compileAndParse("ICache") } - it should "generate valid SMTLib for Ops" taggedAs(RequiresZ3) in { compileAndParse("Ops") } + it should "generate valid SMTLib for HwachaSequencer" taggedAs (RequiresZ3) ignore { + compileAndParse("HwachaSequencer") + } + it should "generate valid SMTLib for ICache" taggedAs (RequiresZ3) in { compileAndParse("ICache") } + it should "generate valid SMTLib for Ops" taggedAs (RequiresZ3) in { compileAndParse("Ops") } // TODO: enable Rob test once we support more than 2 write ports on a memory - it should "generate valid SMTLib for Rob" taggedAs(RequiresZ3) ignore { compileAndParse("Rob") } - it should "generate valid SMTLib for RocketCore" taggedAs(RequiresZ3) in { compileAndParse("RocketCore") } + it should "generate valid SMTLib for Rob" taggedAs (RequiresZ3) ignore { compileAndParse("Rob") } + it should "generate valid SMTLib for RocketCore" taggedAs (RequiresZ3) in { compileAndParse("RocketCore") } private def compileAndParse(name: String): Unit = { val testDir = BackendCompilationUtilities.createTestDirectory(name + "-smt") @@ -29,14 +31,18 @@ class SMTCompilationTest extends AnyFlatSpec with LazyLogging { BackendCompilationUtilities.copyResourceToFile(s"/regress/${name}.fir", inputFile) val args = Array( - "-ll", "error", // surpress warnings to keep test output clean - "--target-dir", testDir.toString, - "-i", inputFile.toString, - "-E", "experimental-smt2" + "-ll", + "error", // surpress warnings to keep test output clean + "--target-dir", + testDir.toString, + "-i", + inputFile.toString, + "-E", + "experimental-smt2" // "-fct", "firrtl.backends.experimental.smt.StutteringClockTransform" ) val res = (new FirrtlStage).execute(args, Seq()) - val fileName = res.collectFirst{ case OutputFileAnnotation(file) => file }.get + val fileName = res.collectFirst { case OutputFileAnnotation(file) => file }.get val smtFile = testDir.toString + "/" + fileName + ".smt2" val log = ProcessLogger(_ => (), logger.error(_)) diff --git a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala index 8fa80b4c..8682c2ce 100644 --- a/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala +++ b/src/test/scala/firrtl/backends/experimental/smt/end2end/UndefinedFirrtlSpec.scala @@ -5,27 +5,26 @@ package firrtl.backends.experimental.smt.end2end /** undefined values in firrtl are modelled as fresh auxiliary variables (inputs) */ class UndefinedFirrtlSpec extends EndToEndSMTBaseSpec { - "division by zero" should "result in an arbitrary value" taggedAs(RequiresZ3) in { + "division by zero" should "result in an arbitrary value" taggedAs (RequiresZ3) in { // the SMTLib spec defines the result of division by zero to be all 1s // https://cs.nyu.edu/pipermail/smt-lib/2015/000977.html def in(dEq: Int) = - s"""circuit CC00: - | module CC00: - | input c: Clock - | input a: UInt<2> - | input b: UInt<2> - | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") - | node d = div(a, b) - | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") - |""".stripMargin + s"""circuit CC00: + | module CC00: + | input c: Clock + | input a: UInt<2> + | input b: UInt<2> + | assume(c, eq(b, UInt(0)), UInt(1), "b = 0") + | node d = div(a, b) + | assert(c, eq(d, UInt($dEq)), UInt(1), "d = $dEq") + |""".stripMargin // we try to assert that (d = a / 0) is any fixed value which should be false (0 until 4).foreach { ii => test(in(ii), MCFail(0), 0, s"d = a / 0 = $ii") } } // TODO: rem should probably also be undefined, but the spec isn't 100% clear here - - "invalid signals" should "have an arbitrary values" taggedAs(RequiresZ3) in { + "invalid signals" should "have an arbitrary values" taggedAs (RequiresZ3) in { def in(aEq: Int) = s"""circuit CC00: | module CC00: diff --git a/src/test/scala/firrtl/ir/StructuralHashSpec.scala b/src/test/scala/firrtl/ir/StructuralHashSpec.scala index 17fe0b84..c4622939 100644 --- a/src/test/scala/firrtl/ir/StructuralHashSpec.scala +++ b/src/test/scala/firrtl/ir/StructuralHashSpec.scala @@ -6,11 +6,11 @@ import firrtl.PrimOps._ import org.scalatest.flatspec.AnyFlatSpec class StructuralHashSpec extends AnyFlatSpec { - private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n) - private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c) + private def hash(n: DefModule): HashCode = StructuralHash.sha256(n, n => n) + private def hash(c: Circuit): HashCode = StructuralHash.sha256Node(c) private def hash(e: Expression): HashCode = StructuralHash.sha256Node(e) - private def hash(t: Type): HashCode = StructuralHash.sha256Node(t) - private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s) + private def hash(t: Type): HashCode = StructuralHash.sha256Node(t) + private def hash(s: Statement): HashCode = StructuralHash.sha256Node(s) private val highFirrtlCompiler = new firrtl.stage.transforms.Compiler( targets = firrtl.stage.Forms.HighForm ) @@ -24,18 +24,18 @@ class StructuralHashSpec extends AnyFlatSpec { highFirrtlCompiler.transform(firrtl.CircuitState(rawFirrtl, Seq())).circuit } - private val b0 = UIntLiteral(0,IntWidth(1)) - private val b1 = UIntLiteral(1,IntWidth(1)) + private val b0 = UIntLiteral(0, IntWidth(1)) + private val b1 = UIntLiteral(1, IntWidth(1)) private val add = DoPrim(Add, Seq(b0, b1), Seq(), UnknownType) it should "generate the same hash if the objects are structurally the same" in { - assert(hash(b0) == hash(UIntLiteral(0,IntWidth(1)))) - assert(hash(b0) != hash(UIntLiteral(1,IntWidth(1)))) - assert(hash(b0) != hash(UIntLiteral(1,IntWidth(2)))) + assert(hash(b0) == hash(UIntLiteral(0, IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1, IntWidth(1)))) + assert(hash(b0) != hash(UIntLiteral(1, IntWidth(2)))) - assert(hash(b1) == hash(UIntLiteral(1,IntWidth(1)))) - assert(hash(b1) != hash(UIntLiteral(0,IntWidth(1)))) - assert(hash(b1) != hash(UIntLiteral(1,IntWidth(2)))) + assert(hash(b1) == hash(UIntLiteral(1, IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(0, IntWidth(1)))) + assert(hash(b1) != hash(UIntLiteral(1, IntWidth(2)))) } it should "ignore expression types" in { @@ -84,16 +84,19 @@ class StructuralHashSpec extends AnyFlatSpec { |""".stripMargin assert(hash(parse(a)) != hash(parse(d)), "circuits with different names are always different") - assert(hash(parse(a).modules.head) == hash(parse(d).modules.head), - "modules with different names can be structurally different") + assert( + hash(parse(a).modules.head) == hash(parse(d).modules.head), + "modules with different names can be structurally different" + ) // for the Dedup pass we do need a way to take the port names into account - assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), - "renaming ports does affect the hash if we ask to") + assert( + StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "renaming ports does affect the hash if we ask to" + ) } - it should "not ignore port names if asked to" in { val e = """circuit a: @@ -119,14 +122,20 @@ class StructuralHashSpec extends AnyFlatSpec { | z <= x |""".stripMargin - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), - "renaming ports does affect the hash if we ask to") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == - StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), - "renaming internal wires should never affect the hash") - assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), - "renaming internal wires should never affect the hash") + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming ports does affect the hash if we ask to" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wires should never affect the hash" + ) + assert( + hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wires should never affect the hash" + ) } it should "not ignore port bundle names if asked to" in { @@ -154,19 +163,26 @@ class StructuralHashSpec extends AnyFlatSpec { | y.z <= x.x |""".stripMargin - assert(hash(parse(e).modules.head) == hash(parse(f).modules.head), - "renaming port bundles does normally not affect the hash") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), - "renaming port bundles does affect the hash if we ask to") - assert(StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == - StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), - "renaming internal wire bundles should never affect the hash") - assert(hash(parse(e).modules.head) == hash(parse(g).modules.head), - "renaming internal wire bundles should never affect the hash") + assert( + hash(parse(e).modules.head) == hash(parse(f).modules.head), + "renaming port bundles does normally not affect the hash" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(f).modules.head), + "renaming port bundles does affect the hash if we ask to" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(e).modules.head) == + StructuralHash.sha256WithSignificantPortNames(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash" + ) + assert( + hash(parse(e).modules.head) == hash(parse(g).modules.head), + "renaming internal wire bundles should never affect the hash" + ) } - it should "fail on Info" in { // it does not make sense to hash Info nodes assertThrows[RuntimeException] { @@ -178,9 +194,9 @@ class StructuralHashSpec extends AnyFlatSpec { def parse(str: String): BundleType = { val src = s"""circuit c: - | module c: - | input z: $str - |""".stripMargin + | module c: + | input z: $str + |""".stripMargin val c = firrtl.Parser.parse(src) val tpe = c.modules.head.ports.head.tpe tpe.asInstanceOf[BundleType] @@ -219,11 +235,15 @@ class StructuralHashSpec extends AnyFlatSpec { // Q: should extmodule portnames always be significant since they map to the verilog pins? // A: It would be a bug for two exmodules in the same circuit to have the same defname but different // port names. This should be detected by an earlier pass and thus we do not have to deal with that situation. - assert(hash(parse(a).modules.head) == hash(parse(b).modules.head), - "two ext modules with the same defname and the same type and number of ports") - assert(StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != - StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), - "two ext modules with significant port names") + assert( + hash(parse(a).modules.head) == hash(parse(b).modules.head), + "two ext modules with the same defname and the same type and number of ports" + ) + assert( + StructuralHash.sha256WithSignificantPortNames(parse(a).modules.head) != + StructuralHash.sha256WithSignificantPortNames(parse(b).modules.head), + "two ext modules with significant port names" + ) } "Blocks and empty statements" should "not affect structural equivalence" in { @@ -269,9 +289,9 @@ class StructuralHashSpec extends AnyFlatSpec { } private case object DebugHasher extends Hasher { - override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})") - override def update(i: Int): Unit = println(s"i(${i})") - override def update(l: Long): Unit = println(s"l(${l})") - override def update(s: String): Unit = println(s"s(${s})") + override def update(b: Byte): Unit = println(s"b(${b.toInt & 0xff})") + override def update(i: Int): Unit = println(s"i(${i})") + override def update(l: Long): Unit = println(s"l(${l})") + override def update(s: String): Unit = println(s"s(${s})") override def update(b: Array[Byte]): Unit = println(s"bytes(${b.map(x => x.toInt & 0xff).mkString(", ")})") -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala index 884e51b8..0575c5da 100644 --- a/src/test/scala/firrtl/passes/LowerTypesSpec.scala +++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala @@ -8,7 +8,6 @@ import firrtl.stage.TransformManager import firrtl.stage.TransformManager.TransformDependency import org.scalatest.flatspec.AnyFlatSpec - /** Unit test style tests for [[LowerTypes]]. * You can find additional integration style tests in [[firrtlTests.LowerTypesSpec]] */ @@ -31,11 +30,12 @@ class LowerTypesEndToEndSpec extends LowerTypesBaseSpec { | $n is invalid |""".stripMargin val c = CircuitState(firrtl.Parser.parse(src), Seq()) - val c2 = lowerTypesCompiler.execute(c) + val c2 = lowerTypesCompiler.execute(c) val ps = c2.circuit.modules.head.ports.filterNot(p => namespace.contains(p.name)) - ps.map{p => + ps.map { p => val orientation = Utils.to_flip(p.direction) - s"${orientation.serialize}${p.name} : ${p.tpe.serialize}"} + s"${orientation.serialize}${p.name} : ${p.tpe.serialize}" + } } override protected def lower(n: String, tpe: String, namespace: Set[String]): Seq[String] = @@ -50,8 +50,10 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec { assert(lower("a", "{ a : UInt<1>, b : UInt<1>}") == Seq("a_a : UInt<1>", "a_b : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}") == Seq("a_a : UInt<1>", "a_b_c : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}") == Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) // with conflicts assert(lower("a", "{ a : UInt<1>, b : UInt<1>}", Set("a_a")) == Seq("a__a : UInt<1>", "a__b : UInt<1>")) @@ -63,40 +65,71 @@ abstract class LowerTypesBaseSpec extends AnyFlatSpec { assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>")) assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b_c")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) == - Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>")) - - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) == - Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) == - Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) == - Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) == + Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>") + ) + + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) == + Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) == + Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) == + Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>") + ) // collisions inside the bundle - assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>")) - - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>")) - assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") == - Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>")) + assert( + lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>") + ) + + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>") + ) + assert( + lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") == + Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>") + ) } it should "correctly lower the orientation" in { assert(lower("a", "{ flip a : UInt<1>, b : UInt<1>}") == Seq("flip a_a : UInt<1>", "a_b : UInt<1>")) - assert(lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") == - Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>")) - assert(lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") == - Seq("flip a_a_0_c : UInt<1>", "a_a_0_d : UInt<1>", "flip a_a_1_c : UInt<1>", "a_a_1_d : UInt<1>", "a_b : UInt<1>") + assert( + lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") == + Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>") + ) + assert( + lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") == + Seq( + "flip a_a_0_c : UInt<1>", + "a_a_0_d : UInt<1>", + "flip a_a_1_c : UInt<1>", + "a_a_1_d : UInt<1>", + "a_b : UInt<1>" + ) ) } } @@ -121,43 +154,45 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { def one(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set(m.ref(prefix + "b"))) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b"))) } one(Set(), "a_") one(Set("a_a"), "a__") def two(namespace: Set[String], prefix: String): Unit = { - val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set(m.ref(prefix + "b_c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref(prefix + "b_c"))) + val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref(prefix + "b_c"))) } two(Set(), "a_") two(Set("a_a"), "a__") def three(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", namespace) - assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) - assert(get(r,a.field("a")) == Set(m.ref(prefix + "a"))) - assert(get(r,a.field("b")) == Set( m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) - assert(get(r,a.field("b").index(0)) == Set(m.ref(prefix + "b_0"))) - assert(get(r,a.field("b").index(1)) == Set(m.ref(prefix + "b_1"))) + assert(get(r, a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) + assert(get(r, a.field("a")) == Set(m.ref(prefix + "a"))) + assert(get(r, a.field("b")) == Set(m.ref(prefix + "b_0"), m.ref(prefix + "b_1"))) + assert(get(r, a.field("b").index(0)) == Set(m.ref(prefix + "b_0"))) + assert(get(r, a.field("b").index(1)) == Set(m.ref(prefix + "b_1"))) } three(Set(), "a_") three(Set("a_b_0"), "a__") def four(namespace: Set[String], prefix: String): Unit = { val r = lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", namespace) - assert(get(r,a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b"))) - assert(get(r,a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b"))) - assert(get(r,a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b"))) - assert(get(r,a.index(0).field("a")) == Set(m.ref(prefix + "0_a"))) - assert(get(r,a.index(0).field("b")) == Set(m.ref(prefix + "0_b"))) - assert(get(r,a.index(1).field("a")) == Set(m.ref(prefix + "1_a"))) - assert(get(r,a.index(1).field("b")) == Set(m.ref(prefix + "1_b"))) + assert( + get(r, a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b")) + ) + assert(get(r, a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b"))) + assert(get(r, a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b"))) + assert(get(r, a.index(0).field("a")) == Set(m.ref(prefix + "0_a"))) + assert(get(r, a.index(0).field("b")) == Set(m.ref(prefix + "0_b"))) + assert(get(r, a.index(1).field("a")) == Set(m.ref(prefix + "1_a"))) + assert(get(r, a.index(1).field("b")) == Set(m.ref(prefix + "1_b"))) } four(Set(), "a_") four(Set("a_0"), "a__") @@ -166,28 +201,28 @@ class LowerTypesRenamingSpec extends AnyFlatSpec { // collisions inside the bundle { val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b__c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b__c"))) - assert(get(r,a.field("b_c")) == Set(m.ref("a_b_c"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b__c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b__c"))) + assert(get(r, a.field("b_c")) == Set(m.ref("a_b_c"))) } { val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b_c"))) - assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b_c"))) - assert(get(r,a.field("b_b")) == Set(m.ref("a_b_b"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b_c"))) + assert(get(r, a.field("b").field("c")) == Set(m.ref("a_b_c"))) + assert(get(r, a.field("b_b")) == Set(m.ref("a_b_b"))) } { val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") - assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0"))) - assert(get(r,a.field("a")) == Set(m.ref("a_a"))) - assert(get(r,a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1"))) - assert(get(r,a.field("b").index(0)) == Set(m.ref("a_b__0"))) - assert(get(r,a.field("b").index(1)) == Set(m.ref("a_b__1"))) - assert(get(r,a.field("b_0")) == Set(m.ref("a_b_0"))) + assert(get(r, a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0"))) + assert(get(r, a.field("a")) == Set(m.ref("a_a"))) + assert(get(r, a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1"))) + assert(get(r, a.field("b").index(0)) == Set(m.ref("a_b__0"))) + assert(get(r, a.field("b").index(1)) == Set(m.ref("a_b__1"))) + assert(get(r, a.field("b_0")) == Set(m.ref("a_b_0"))) } } } @@ -199,8 +234,13 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { private val m = CircuitTarget("m").module("m") def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] = res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}") - private def lower(n: String, tpe: String, module: String, namespace: Set[String], renames: RenameMap = RenameMap()): - Lower = { + private def lower( + n: String, + tpe: String, + module: String, + namespace: Set[String], + renames: RenameMap = RenameMap() + ): Lower = { val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe)) val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set()) @@ -269,7 +309,7 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c"))) assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c"))) - } + } } } @@ -278,101 +318,139 @@ class LowerTypesOfInstancesSpec extends AnyFlatSpec { */ class LowerTypesOfMemorySpec extends AnyFlatSpec { import LowerTypesSpecUtils._ - private case class Lower(mems: Seq[firrtl.ir.DefMemory], refs: Seq[(String, firrtl.ir.SubField)], - renameMap: RenameMap) + private case class Lower( + mems: Seq[firrtl.ir.DefMemory], + refs: Seq[(String, firrtl.ir.SubField)], + renameMap: RenameMap) private val m = CircuitTarget("m").module("m") private val mem = m.ref("mem") - private def lower(name: String, tpe: String, namespace: Set[String], - r: Seq[String] = List("r"), w: Seq[String] = List("w"), rw: Seq[String] = List(), depth: Int = 2): Lower = { + private def lower( + name: String, + tpe: String, + namespace: Set[String], + r: Seq[String] = List("r"), + w: Seq[String] = List("w"), + rw: Seq[String] = List(), + depth: Int = 2 + ): Lower = { val dataType = parseType(tpe) - val mem = firrtl.ir.DefMemory(firrtl.ir.NoInfo, name, dataType, depth = depth, writeLatency = 1, readLatency = 1, - readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, readers = r, writers = w, readwriters = rw) + val mem = firrtl.ir.DefMemory( + firrtl.ir.NoInfo, + name, + dataType, + depth = depth, + writeLatency = 1, + readLatency = 1, + readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, + readers = r, + writers = w, + readwriters = rw + ) val renames = RenameMap() val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace - val(mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set()) + val (mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set()) Lower(mems, refs, renames) } private val UInt1 = firrtl.ir.UIntType(firrtl.ir.IntWidth(1)) it should "not rename anything for a ground type memory if there was no conflict" in { - val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")) + val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w")) assert(l.renameMap.underlying.isEmpty) } it should "still produce reference lookups, even for a ground type memory with no conflicts" in { - val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")).refs - .map{case (n,r) => n -> r.serialize}.toSet - - assert(nameToRef == Set( - "mem.r.clk" -> "mem.r.clk", - "mem.r.en" -> "mem.r.en", - "mem.r.addr" -> "mem.r.addr", - "mem.r.data" -> "mem.r.data", - "mem.w.clk" -> "mem.w.clk", - "mem.w.en" -> "mem.w.en", - "mem.w.addr" -> "mem.w.addr", - "mem.w.data" -> "mem.w.data", - "mem.w.mask" -> "mem.w.mask" - )) + val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("w")).refs.map { + case (n, r) => n -> r.serialize + }.toSet + + assert( + nameToRef == Set( + "mem.r.clk" -> "mem.r.clk", + "mem.r.en" -> "mem.r.en", + "mem.r.addr" -> "mem.r.addr", + "mem.r.data" -> "mem.r.data", + "mem.w.clk" -> "mem.w.clk", + "mem.w.en" -> "mem.w.en", + "mem.w.addr" -> "mem.w.addr", + "mem.w.data" -> "mem.w.data", + "mem.w.mask" -> "mem.w.mask" + ) + ) } it should "produce references of correct type" in { - val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w=Seq("w"), depth = 3).refs - .map{case (n,r) => n -> r.tpe.serialize}.toSet - - assert(nameToType == Set( - "mem.r.clk" -> "Clock", - "mem.r.en" -> "UInt<1>", - "mem.r.addr" -> "UInt<2>", // depth = 3 - "mem.r.data" -> "UInt<4>", - "mem.w.clk" -> "Clock", - "mem.w.en" -> "UInt<1>", - "mem.w.addr" -> "UInt<2>", - "mem.w.data" -> "UInt<4>", - "mem.w.mask" -> "UInt<1>" - )) + val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w = Seq("w"), depth = 3).refs.map { + case (n, r) => n -> r.tpe.serialize + }.toSet + + assert( + nameToType == Set( + "mem.r.clk" -> "Clock", + "mem.r.en" -> "UInt<1>", + "mem.r.addr" -> "UInt<2>", // depth = 3 + "mem.r.data" -> "UInt<4>", + "mem.w.clk" -> "Clock", + "mem.w.en" -> "UInt<1>", + "mem.w.addr" -> "UInt<2>", + "mem.w.data" -> "UInt<4>", + "mem.w.mask" -> "UInt<1>" + ) + ) } it should "not rename ground type memories even if there are conflicts on the ports" in { // There actually isn't such a thing as conflicting ports, because they do not get flattened by LowerTypes. - val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("r_data")).renameMap + val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w = Seq("r_data")).renameMap assert(r.underlying.isEmpty) } it should "rename references to lowered ports" in { - val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r=Seq("r", "r_data")).renameMap + val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r = Seq("r", "r_data")).renameMap // complete memory assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b"))) // read ports - assert(get(r, mem.field("r")) == - Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r"))) - assert(get(r, mem.field("r_data")) == - Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r")) + ) + assert( + get(r, mem.field("r_data")) == + Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data")) + ) // port fields - assert(get(r, mem.field("r").field("data")) == - Set(m.ref("mem__a").field("r").field("data"), - m.ref("mem__b").field("r").field("data"))) - assert(get(r, mem.field("r").field("addr")) == - Set(m.ref("mem__a").field("r").field("addr"), - m.ref("mem__b").field("r").field("addr"))) - assert(get(r, mem.field("r").field("en")) == - Set(m.ref("mem__a").field("r").field("en"), - m.ref("mem__b").field("r").field("en"))) - assert(get(r, mem.field("r").field("clk")) == - Set(m.ref("mem__a").field("r").field("clk"), - m.ref("mem__b").field("r").field("clk"))) - assert(get(r, mem.field("w").field("mask")) == - Set(m.ref("mem__a").field("w").field("mask"), - m.ref("mem__b").field("w").field("mask"))) + assert( + get(r, mem.field("r").field("data")) == + Set(m.ref("mem__a").field("r").field("data"), m.ref("mem__b").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("addr")) == + Set(m.ref("mem__a").field("r").field("addr"), m.ref("mem__b").field("r").field("addr")) + ) + assert( + get(r, mem.field("r").field("en")) == + Set(m.ref("mem__a").field("r").field("en"), m.ref("mem__b").field("r").field("en")) + ) + assert( + get(r, mem.field("r").field("clk")) == + Set(m.ref("mem__a").field("r").field("clk"), m.ref("mem__b").field("r").field("clk")) + ) + assert( + get(r, mem.field("w").field("mask")) == + Set(m.ref("mem__a").field("w").field("mask"), m.ref("mem__b").field("w").field("mask")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").field("a")) == - Set(m.ref("mem__a").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b")) == - Set(m.ref("mem__b").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").field("a")) == + Set(m.ref("mem__a").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b")) == + Set(m.ref("mem__b").field("r").field("data")) + ) // need to rename the following: // mem -> mem__a, mem__b @@ -395,24 +473,38 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b_c"))) // read port - assert(get(r, mem.field("r")) == - Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").field("a")) == - Set(m.ref("mem__a").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b")) == - Set(m.ref("mem__b_c").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").field("b").field("c")) == - Set(m.ref("mem__b_c").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").field("a")) == + Set(m.ref("mem__a").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b")) == + Set(m.ref("mem__b_c").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").field("b").field("c")) == + Set(m.ref("mem__b_c").field("r").field("data")) + ) // the mask field needs to be lowered just like the data field - assert(get(r, mem.field("w").field("mask").field("a")) == - Set(m.ref("mem__a").field("w").field("mask"))) - assert(get(r, mem.field("w").field("mask").field("b")) == - Set(m.ref("mem__b_c").field("w").field("mask"))) - assert(get(r, mem.field("w").field("mask").field("b").field("c")) == - Set(m.ref("mem__b_c").field("w").field("mask"))) + assert( + get(r, mem.field("w").field("mask").field("a")) == + Set(m.ref("mem__a").field("w").field("mask")) + ) + assert( + get(r, mem.field("w").field("mask").field("b")) == + Set(m.ref("mem__b_c").field("w").field("mask")) + ) + assert( + get(r, mem.field("w").field("mask").field("b").field("c")) == + Set(m.ref("mem__b_c").field("w").field("mask")) + ) val renameCount = r.underlying.map(_._2.size).sum assert(renameCount == 11, "it is enough to rename *to* 11 different signals") @@ -420,66 +512,89 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { } it should "return a name to RefLikeExpression map for a memory with a nested data type" in { - val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs - .map{case (n,r) => n -> r.serialize}.toSet - - assert(nameToRef == Set( - // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. - // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. - "mem.r.clk" -> "mem__a.r.clk", "mem.r.clk" -> "mem__b_c.r.clk", - "mem.r.en" -> "mem__a.r.en", "mem.r.en" -> "mem__b_c.r.en", - "mem.r.addr" -> "mem__a.r.addr", "mem.r.addr" -> "mem__b_c.r.addr", - "mem.w.clk" -> "mem__a.w.clk", "mem.w.clk" -> "mem__b_c.w.clk", - "mem.w.en" -> "mem__a.w.en", "mem.w.en" -> "mem__b_c.w.en", - "mem.w.addr" -> "mem__a.w.addr", "mem.w.addr" -> "mem__b_c.w.addr", - // Ground type references to the data or mask field are unique. - "mem.r.data.a" -> "mem__a.r.data", - "mem.w.data.a" -> "mem__a.w.data", - "mem.w.mask.a" -> "mem__a.w.mask", - "mem.r.data.b.c" -> "mem__b_c.r.data", - "mem.w.data.b.c" -> "mem__b_c.w.data", - "mem.w.mask.b.c" -> "mem__b_c.w.mask" - )) + val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs.map { + case (n, r) => n -> r.serialize + }.toSet + + assert( + nameToRef == Set( + // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. + // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. + "mem.r.clk" -> "mem__a.r.clk", + "mem.r.clk" -> "mem__b_c.r.clk", + "mem.r.en" -> "mem__a.r.en", + "mem.r.en" -> "mem__b_c.r.en", + "mem.r.addr" -> "mem__a.r.addr", + "mem.r.addr" -> "mem__b_c.r.addr", + "mem.w.clk" -> "mem__a.w.clk", + "mem.w.clk" -> "mem__b_c.w.clk", + "mem.w.en" -> "mem__a.w.en", + "mem.w.en" -> "mem__b_c.w.en", + "mem.w.addr" -> "mem__a.w.addr", + "mem.w.addr" -> "mem__b_c.w.addr", + // Ground type references to the data or mask field are unique. + "mem.r.data.a" -> "mem__a.r.data", + "mem.w.data.a" -> "mem__a.w.data", + "mem.w.mask.a" -> "mem__a.w.mask", + "mem.r.data.b.c" -> "mem__b_c.r.data", + "mem.w.data.b.c" -> "mem__b_c.w.data", + "mem.w.mask.b.c" -> "mem__b_c.w.mask" + ) + ) } it should "produce references of correct type for memories with a read/write port" in { - val refs = lower("mem", "{ a : UInt<3>, b : { c : UInt<4>} }", Set("mem_a"), - r=Seq(), w=Seq(), rw=Seq("rw"), depth = 3).refs - val nameToRef = refs.map{case (n,r) => n -> r.serialize}.toSet - val nameToType = refs.map{case (n,r) => n -> r.tpe.serialize}.toSet - - assert(nameToRef == Set( - // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. - // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. - "mem.rw.clk" -> "mem__a.rw.clk", "mem.rw.clk" -> "mem__b_c.rw.clk", - "mem.rw.en" -> "mem__a.rw.en", "mem.rw.en" -> "mem__b_c.rw.en", - "mem.rw.addr" -> "mem__a.rw.addr", "mem.rw.addr" -> "mem__b_c.rw.addr", - "mem.rw.wmode" -> "mem__a.rw.wmode", "mem.rw.wmode" -> "mem__b_c.rw.wmode", - // Ground type references to the data or mask field are unique. - "mem.rw.rdata.a" -> "mem__a.rw.rdata", - "mem.rw.wdata.a" -> "mem__a.rw.wdata", - "mem.rw.wmask.a" -> "mem__a.rw.wmask", - "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata", - "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata", - "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask" - )) - - assert(nameToType == Set( - // - "mem.rw.clk" -> "Clock", - "mem.rw.en" -> "UInt<1>", - "mem.rw.addr" -> "UInt<2>", - "mem.rw.wmode" -> "UInt<1>", - // Ground type references to the data or mask field are unique. - "mem.rw.rdata.a" -> "UInt<3>", - "mem.rw.wdata.a" -> "UInt<3>", - "mem.rw.wmask.a" -> "UInt<1>", - "mem.rw.rdata.b.c" -> "UInt<4>", - "mem.rw.wdata.b.c" -> "UInt<4>", - "mem.rw.wmask.b.c" -> "UInt<1>" - )) - } + val refs = lower( + "mem", + "{ a : UInt<3>, b : { c : UInt<4>} }", + Set("mem_a"), + r = Seq(), + w = Seq(), + rw = Seq("rw"), + depth = 3 + ).refs + val nameToRef = refs.map { case (n, r) => n -> r.serialize }.toSet + val nameToType = refs.map { case (n, r) => n -> r.tpe.serialize }.toSet + + assert( + nameToRef == Set( + // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated. + // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do. + "mem.rw.clk" -> "mem__a.rw.clk", + "mem.rw.clk" -> "mem__b_c.rw.clk", + "mem.rw.en" -> "mem__a.rw.en", + "mem.rw.en" -> "mem__b_c.rw.en", + "mem.rw.addr" -> "mem__a.rw.addr", + "mem.rw.addr" -> "mem__b_c.rw.addr", + "mem.rw.wmode" -> "mem__a.rw.wmode", + "mem.rw.wmode" -> "mem__b_c.rw.wmode", + // Ground type references to the data or mask field are unique. + "mem.rw.rdata.a" -> "mem__a.rw.rdata", + "mem.rw.wdata.a" -> "mem__a.rw.wdata", + "mem.rw.wmask.a" -> "mem__a.rw.wmask", + "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata", + "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata", + "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask" + ) + ) + assert( + nameToType == Set( + // + "mem.rw.clk" -> "Clock", + "mem.rw.en" -> "UInt<1>", + "mem.rw.addr" -> "UInt<2>", + "mem.rw.wmode" -> "UInt<1>", + // Ground type references to the data or mask field are unique. + "mem.rw.rdata.a" -> "UInt<3>", + "mem.rw.wdata.a" -> "UInt<3>", + "mem.rw.wmask.a" -> "UInt<1>", + "mem.rw.rdata.b.c" -> "UInt<4>", + "mem.rw.wdata.b.c" -> "UInt<4>", + "mem.rw.wmask.b.c" -> "UInt<1>" + ) + ) + } it should "rename references for vector type memories" in { val l = lower("mem", "UInt<1>[2]", Set("mem_0")) @@ -491,14 +606,20 @@ class LowerTypesOfMemorySpec extends AnyFlatSpec { assert(get(r, mem) == Set(m.ref("mem__0"), m.ref("mem__1"))) // read port - assert(get(r, mem.field("r")) == - Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r"))) + assert( + get(r, mem.field("r")) == + Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r")) + ) // port sub-fields - assert(get(r, mem.field("r").field("data").index(0)) == - Set(m.ref("mem__0").field("r").field("data"))) - assert(get(r, mem.field("r").field("data").index(1)) == - Set(m.ref("mem__1").field("r").field("data"))) + assert( + get(r, mem.field("r").field("data").index(0)) == + Set(m.ref("mem__0").field("r").field("data")) + ) + assert( + get(r, mem.field("r").field("data").index(1)) == + Set(m.ref("mem__1").field("r").field("data")) + ) val renameCount = r.underlying.map(_._2.size).sum assert(renameCount == 8, "it is enough to rename *to* 8 different signals") diff --git a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala index 007608ca..0b9b830c 100644 --- a/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala +++ b/src/test/scala/firrtl/stage/phases/tests/DriverCompatibilitySpec.scala @@ -8,7 +8,14 @@ import java.io.File import firrtl._ import firrtl.stage.phases.DriverCompatibility._ import firrtl.options.{InputAnnotationFileAnnotation, Phase, TargetDirAnnotation} -import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation} +import firrtl.stage.{ + CompilerAnnotation, + FirrtlCircuitAnnotation, + FirrtlFileAnnotation, + FirrtlSourceAnnotation, + OutputFileAnnotation, + RunFirrtlTransformAnnotation +} import firrtl.stage.phases.DriverCompatibility import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -20,7 +27,7 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth /* This method wraps some magic that lets you use the private method DriverCompatibility.topName */ def topName(annotations: AnnotationSeq): Option[String] = { val topName = PrivateMethod[Option[String]]('topName) - DriverCompatibility invokePrivate topName(annotations) + DriverCompatibility.invokePrivate(topName(annotations)) } def simpleCircuit(main: String): String = s"""|circuit $main: @@ -41,22 +48,22 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth (FirrtlFileAnnotation("src/test/resources/integration/GCDTester.pb"), "GCDTester") ) - behavior of s"${DriverCompatibility.getClass.getName}.topName (private method)" + behavior.of(s"${DriverCompatibility.getClass.getName}.topName (private method)") /* This iterates over the tails of annosWithTops. Using the ordering of annosWithTops, if this AnnotationSeq is fed to * DriverCompatibility.topName, the head annotation will be used to determine the top name. This test ensures that * topName behaves as expected. */ - for ( t <- annosWithTops.tails ) t match { + for (t <- annosWithTops.tails) t match { case Nil => it should "return None on an empty AnnotationSeq" in { - topName(Seq.empty) should be (None) + topName(Seq.empty) should be(None) } case x => val annotations = x.map(_._1) val top = x.head._2 it should s"determine a top name ('$top') from a ${annotations.head.getClass.getName}" in { - topName(annotations).get should be (top) + topName(annotations).get should be(top) } } @@ -66,152 +73,148 @@ class DriverCompatibilitySpec extends AnyFlatSpec with Matchers with PrivateMeth file.createNewFile() } - behavior of classOf[AddImplicitAnnotationFile].toString + behavior.of(classOf[AddImplicitAnnotationFile].toString) val testDir = "test_run_dir/DriverCompatibilitySpec" it should "not modify the annotations if an InputAnnotationFile already exists" in - new PhaseFixture(new AddImplicitAnnotationFile) { + new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/foo.anno") - val annotations = Seq( - InputAnnotationFileAnnotation("bar.anno"), - TargetDirAnnotation(testDir), - TopNameAnnotation("foo") ) + createFile(testDir + "/foo.anno") + val annotations = + Seq(InputAnnotationFileAnnotation("bar.anno"), TargetDirAnnotation(testDir), TopNameAnnotation("foo")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "add an InputAnnotationFile based on a derived topName" in - new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/bar.anno") - val annotations = Seq( - TargetDirAnnotation(testDir), - TopNameAnnotation("bar") ) + new PhaseFixture(new AddImplicitAnnotationFile) { + createFile(testDir + "/bar.anno") + val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("bar")) - val expected = annotations.toSet + - InputAnnotationFileAnnotation(testDir + "/bar.anno") + val expected = annotations.toSet + + InputAnnotationFileAnnotation(testDir + "/bar.anno") - phase.transform(annotations).toSet should be (expected) - } + phase.transform(annotations).toSet should be(expected) + } it should "not add an InputAnnotationFile for .anno.json annotations" in - new PhaseFixture(new AddImplicitAnnotationFile) { - createFile(testDir + "/baz.anno.json") - val annotations = Seq( - TargetDirAnnotation(testDir), - TopNameAnnotation("baz") ) + new PhaseFixture(new AddImplicitAnnotationFile) { + createFile(testDir + "/baz.anno.json") + val annotations = Seq(TargetDirAnnotation(testDir), TopNameAnnotation("baz")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "not add an InputAnnotationFile if it cannot determine the topName" in - new PhaseFixture(new AddImplicitAnnotationFile) { - val annotations = Seq( TargetDirAnnotation(testDir) ) + new PhaseFixture(new AddImplicitAnnotationFile) { + val annotations = Seq(TargetDirAnnotation(testDir)) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } - behavior of classOf[AddImplicitFirrtlFile].toString + behavior.of(classOf[AddImplicitFirrtlFile].toString) it should "not modify the annotations if a CircuitOption is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( FirrtlFileAnnotation("foo"), TopNameAnnotation("bar") ) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(FirrtlFileAnnotation("foo"), TopNameAnnotation("bar")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } it should "add an FirrtlFileAnnotation if a TopNameAnnotation is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( TopNameAnnotation("foo") ) - val expected = annotations.toSet + - FirrtlFileAnnotation(new File("foo.fir").getPath()) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(TopNameAnnotation("foo")) + val expected = annotations.toSet + + FirrtlFileAnnotation(new File("foo.fir").getPath()) - phase.transform(annotations).toSet should be (expected) - } + phase.transform(annotations).toSet should be(expected) + } it should "do nothing if no TopNameAnnotation is present" in - new PhaseFixture(new AddImplicitFirrtlFile) { - val annotations = Seq( TargetDirAnnotation("foo") ) + new PhaseFixture(new AddImplicitFirrtlFile) { + val annotations = Seq(TargetDirAnnotation("foo")) - phase.transform(annotations).toSeq should be (annotations) - } + phase.transform(annotations).toSeq should be(annotations) + } - behavior of classOf[AddImplicitEmitter].toString + behavior.of(classOf[AddImplicitEmitter].toString) - val (nc, hfc, mfc, lfc, vc, svc) = ( new NoneCompiler, - new HighFirrtlCompiler, - new MiddleFirrtlCompiler, - new LowFirrtlCompiler, - new VerilogCompiler, - new SystemVerilogCompiler ) + val (nc, hfc, mfc, lfc, vc, svc) = ( + new NoneCompiler, + new HighFirrtlCompiler, + new MiddleFirrtlCompiler, + new LowFirrtlCompiler, + new VerilogCompiler, + new SystemVerilogCompiler + ) it should "convert CompilerAnnotations into EmitCircuitAnnotations without EmitOneFilePerModuleAnnotation" in - new PhaseFixture(new AddImplicitEmitter) { - val annotations = Seq( - CompilerAnnotation(nc), - CompilerAnnotation(hfc), - CompilerAnnotation(mfc), - CompilerAnnotation(lfc), - CompilerAnnotation(vc), - CompilerAnnotation(svc) - ) - val expected = annotations - .flatMap( a => Seq(a, - RunFirrtlTransformAnnotation(a.compiler.emitter), - EmitCircuitAnnotation(a.compiler.emitter.getClass)) ) - - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitEmitter) { + val annotations = Seq( + CompilerAnnotation(nc), + CompilerAnnotation(hfc), + CompilerAnnotation(mfc), + CompilerAnnotation(lfc), + CompilerAnnotation(vc), + CompilerAnnotation(svc) + ) + val expected = annotations + .flatMap(a => + Seq(a, RunFirrtlTransformAnnotation(a.compiler.emitter), EmitCircuitAnnotation(a.compiler.emitter.getClass)) + ) + + phase.transform(annotations).toSeq should be(expected) + } it should "convert CompilerAnnotations into EmitAllodulesAnnotation with EmitOneFilePerModuleAnnotation" in - new PhaseFixture(new AddImplicitEmitter) { - val annotations = Seq( - EmitOneFilePerModuleAnnotation, - CompilerAnnotation(nc), - CompilerAnnotation(hfc), - CompilerAnnotation(mfc), - CompilerAnnotation(lfc), - CompilerAnnotation(vc), - CompilerAnnotation(svc) - ) - val expected = annotations - .flatMap{ - case a: CompilerAnnotation => Seq(a, - RunFirrtlTransformAnnotation(a.compiler.emitter), - EmitAllModulesAnnotation(a.compiler.emitter.getClass)) + new PhaseFixture(new AddImplicitEmitter) { + val annotations = Seq( + EmitOneFilePerModuleAnnotation, + CompilerAnnotation(nc), + CompilerAnnotation(hfc), + CompilerAnnotation(mfc), + CompilerAnnotation(lfc), + CompilerAnnotation(vc), + CompilerAnnotation(svc) + ) + val expected = annotations.flatMap { + case a: CompilerAnnotation => + Seq( + a, + RunFirrtlTransformAnnotation(a.compiler.emitter), + EmitAllModulesAnnotation(a.compiler.emitter.getClass) + ) case a => Seq(a) } - phase.transform(annotations).toSeq should be (expected) - } + phase.transform(annotations).toSeq should be(expected) + } - behavior of classOf[AddImplicitOutputFile].toString + behavior.of(classOf[AddImplicitOutputFile].toString) it should "add an OutputFileAnnotation derived from a TopNameAnnotation if no OutputFileAnnotation exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq( TopNameAnnotation("foo") ) - val expected = Seq( - OutputFileAnnotation("foo"), - TopNameAnnotation("foo") - ) - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq(TopNameAnnotation("foo")) + val expected = Seq( + OutputFileAnnotation("foo"), + TopNameAnnotation("foo") + ) + phase.transform(annotations).toSeq should be(expected) + } it should "do nothing if an OutputFileannotation already exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq( - TopNameAnnotation("foo"), - OutputFileAnnotation("bar") ) - val expected = annotations - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq(TopNameAnnotation("foo"), OutputFileAnnotation("bar")) + val expected = annotations + phase.transform(annotations).toSeq should be(expected) + } it should "do nothing if no TopNameAnnotation exists" in - new PhaseFixture(new AddImplicitOutputFile) { - val annotations = Seq.empty - val expected = annotations - phase.transform(annotations).toSeq should be (expected) - } + new PhaseFixture(new AddImplicitOutputFile) { + val annotations = Seq.empty + val expected = annotations + phase.transform(annotations).toSeq should be(expected) + } } diff --git a/src/test/scala/firrtl/testutils/FirrtlSpec.scala b/src/test/scala/firrtl/testutils/FirrtlSpec.scala index dfc20352..a0c41085 100644 --- a/src/test/scala/firrtl/testutils/FirrtlSpec.scala +++ b/src/test/scala/firrtl/testutils/FirrtlSpec.scala @@ -46,11 +46,13 @@ object RenameTop extends Transform { val c = state.circuit val ns = Namespace(c) - val newTopName = state.annotations.collectFirst({ - case RenameTopAnnotation(name) => - require(ns.tryName(name)) - name - }).getOrElse(c.main) + val newTopName = state.annotations + .collectFirst({ + case RenameTopAnnotation(name) => + require(ns.tryName(name)) + name + }) + .getOrElse(c.main) state.annotations.collect { case ModuleNamespaceAnnotation(mustNotCollideNS) => require(mustNotCollideNS.tryName(newTopName)) @@ -70,6 +72,7 @@ object RenameTop extends Transform { trait FirrtlRunners extends BackendCompilationUtilities { val cppHarnessResourceName: String = "/firrtl/testTop.cpp" + /** Extra transforms to run by default */ val extraCheckTransforms = Seq(new CheckLowForm) @@ -80,10 +83,12 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param customAnnotations Optional Firrtl annotations * @param timesteps the maximum number of timesteps to consider */ - def firrtlEquivalenceTest(input: String, - customTransforms: Seq[Transform] = Seq.empty, - customAnnotations: AnnotationSeq = Seq.empty, - timesteps: Int = 1): Unit = { + def firrtlEquivalenceTest( + input: String, + customTransforms: Seq[Transform] = Seq.empty, + customAnnotations: AnnotationSeq = Seq.empty, + timesteps: Int = 1 + ): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) val prefix = circuit.main val testDir = createTestDirectory(prefix + "_equivalence_test") @@ -93,12 +98,12 @@ trait FirrtlRunners extends BackendCompilationUtilities { def getBaseAnnos(topName: String) = { val baseTransforms = RenameTop +: extraCheckTransforms TargetDirAnnotation(testDir.toString) +: - InfoModeAnnotation("ignore") +: - RenameTopAnnotation(topName) +: - stage.FirrtlCircuitAnnotation(circuit) +: - stage.CompilerAnnotation("mverilog") +: - stage.OutputFileAnnotation(topName) +: - toAnnos(baseTransforms) + InfoModeAnnotation("ignore") +: + RenameTopAnnotation(topName) +: + stage.FirrtlCircuitAnnotation(circuit) +: + stage.CompilerAnnotation("mverilog") +: + stage.OutputFileAnnotation(topName) +: + toAnnos(baseTransforms) } val customName = s"${prefix}_custom" @@ -111,7 +116,8 @@ trait FirrtlRunners extends BackendCompilationUtilities { val refAnnos = getBaseAnnos(refSuggestedName) ++: Seq(RunFirrtlTransformAnnotation(new RenameModules), nsAnno) val refResult = (new firrtl.stage.FirrtlStage).execute(Array.empty, refAnnos) - val refName = refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName) + val refName = + refResult.collectFirst({ case stage.FirrtlCircuitAnnotation(c) => c.main }).getOrElse(refSuggestedName) assert(BackendCompilationUtilities.yosysExpectSuccess(customName, refName, testDir, timesteps)) } @@ -123,6 +129,7 @@ trait FirrtlRunners extends BackendCompilationUtilities { val res = compiler.compileAndEmit(CircuitState(circuit, HighForm, annotations), extraCheckTransforms) res.getEmittedCircuit.value } + /** Compile a Firrtl file * * @param prefix is the name of the Firrtl file without path or file extension @@ -130,25 +137,27 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param annotations Optional Firrtl annotations */ def compileFirrtlTest( - prefix: String, - srcDir: String, - customTransforms: Seq[Transform] = Seq.empty, - annotations: AnnotationSeq = Seq.empty): File = { + prefix: String, + srcDir: String, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty + ): File = { val testDir = createTestDirectory(prefix) val inputFile = new File(testDir, s"${prefix}.fir") copyResourceToFile(s"${srcDir}/${prefix}.fir", inputFile) val annos = FirrtlFileAnnotation(inputFile.toString) +: - TargetDirAnnotation(testDir.toString) +: - InfoModeAnnotation("ignore") +: - annotations ++: - (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_)) + TargetDirAnnotation(testDir.toString) +: + InfoModeAnnotation("ignore") +: + annotations ++: + (customTransforms ++ extraCheckTransforms).map(RunFirrtlTransformAnnotation(_)) (new firrtl.stage.FirrtlStage).execute(Array.empty, annos) testDir } + /** Execute a Firrtl Test * * @param prefix is the name of the Firrtl file without path or file extension @@ -157,25 +166,26 @@ trait FirrtlRunners extends BackendCompilationUtilities { * @param annotations Optional Firrtl annotations */ def runFirrtlTest( - prefix: String, - srcDir: String, - verilogPrefixes: Seq[String] = Seq.empty, - customTransforms: Seq[Transform] = Seq.empty, - annotations: AnnotationSeq = Seq.empty) = { + prefix: String, + srcDir: String, + verilogPrefixes: Seq[String] = Seq.empty, + customTransforms: Seq[Transform] = Seq.empty, + annotations: AnnotationSeq = Seq.empty + ) = { val testDir = compileFirrtlTest(prefix, srcDir, customTransforms, annotations) val harness = new File(testDir, s"top.cpp") copyResourceToFile(cppHarnessResourceName, harness) // Note file copying side effect - val verilogFiles = verilogPrefixes map { vprefix => + val verilogFiles = verilogPrefixes.map { vprefix => val file = new File(testDir, s"$vprefix.v") copyResourceToFile(s"$srcDir/$vprefix.v", file) file } verilogToCpp(prefix, testDir, verilogFiles, harness) #&& - cppToExe(prefix, testDir) ! - loggingProcessLogger + cppToExe(prefix, testDir) ! + loggingProcessLogger assert(executeExpectingSuccess(prefix, testDir)) } } @@ -201,6 +211,7 @@ trait FirrtlMatchers extends Matchers { require(!s.contains("\n")) s.replaceAll("\\s+", " ").trim } + /** Helper to make circuits that are the same appear the same */ def canonicalize(circuit: Circuit): Circuit = { import firrtl.Mappers._ @@ -208,19 +219,21 @@ trait FirrtlMatchers extends Matchers { circuit.map(onModule) } def parse(str: String) = Parser.parse(str.split("\n").toIterator, UseInfo) + /** Helper for executing tests * compiler will be run on input then emitted result will each be split into * lines and normalized. */ def executeTest( - input: String, - expected: Seq[String], - compiler: Compiler, - annotations: Seq[Annotation] = Seq.empty) = { + input: String, + expected: Seq[String], + compiler: Compiler, + annotations: Seq[Annotation] = Seq.empty + ) = { val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val lines = finalState.getEmittedCircuit.value split "\n" map normalized + val lines = finalState.getEmittedCircuit.value.split("\n").map(normalized) for (e <- expected) { - lines should contain (e) + lines should contain(e) } } } @@ -239,10 +252,12 @@ object FirrtlCheckers extends FirrtlMatchers { case Some(res) => res // Otherwise keep digging case None => - require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], - "Error! Unexpected FirrtlNode that does not implement Product!") + require( + node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode], + "Error! Unexpected FirrtlNode that does not implement Product!" + ) val iter = node match { - case p: Product => p.productIterator + case p: Product => p.productIterator case i: Iterable[Any] => i.iterator case _ => Iterator.empty } @@ -296,57 +311,63 @@ class TestFirrtlFlatSpec extends FirrtlFlatSpec { import FirrtlCheckers._ val c = parse(""" - |circuit Test: - | module Test : - | input in : UInt<8> - | output out : UInt<8> - | out <= in - |""".stripMargin) + |circuit Test: + | module Test : + | input in : UInt<8> + | output out : UInt<8> + | out <= in + |""".stripMargin) val state = CircuitState(c, ChirrtlForm) val compiled = (new LowFirrtlCompiler).compileAndEmit(state, List.empty) // While useful, ScalaTest helpers should be used over search - behavior of "Search" + behavior.of("Search") it should "be supported on Circuit" in { - assert(c search { - case Connect(_, Reference("out",_, _, _), Reference("in", _, _, _)) => true + assert(c.search { + case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true }) } it should "be supported on CircuitStates" in { - assert(state search { - case Connect(_, Reference("out", _, _, _), Reference("in",_, _, _)) => true + assert(state.search { + case Connect(_, Reference("out", _, _, _), Reference("in", _, _, _)) => true }) } it should "be supported on the results of compilers" in { - assert(compiled search { - case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true + assert(compiled.search { + case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true }) } // Use these!!! - behavior of "ScalaTest helpers" + behavior.of("ScalaTest helpers") they should "work for lines of emitted text" in { - compiled should containLine (s"input in : UInt<8>") - compiled should containLine (s"output out : UInt<8>") - compiled should containLine (s"out <= in") + compiled should containLine(s"input in : UInt<8>") + compiled should containLine(s"output out : UInt<8>") + compiled should containLine(s"out <= in") } they should "work for partial functions matching on subtrees" in { val UInt8 = UIntType(IntWidth(8)) // BigInt unapply is weird compiled should containTree { case Port(_, "in", Input, UInt8) => true } compiled should containTree { case Port(_, "out", Output, UInt8) => true } - compiled should containTree { case Connect(_, WRef("out",_,_,_), WRef("in",_,_,_)) => true } + compiled should containTree { case Connect(_, WRef("out", _, _, _), WRef("in", _, _, _)) => true } } } /** Super class for execution driven Firrtl tests */ -abstract class ExecutionTest(name: String, dir: String, vFiles: Seq[String] = Seq.empty, annotations: AnnotationSeq = Seq.empty) extends FirrtlPropSpec { +abstract class ExecutionTest( + name: String, + dir: String, + vFiles: Seq[String] = Seq.empty, + annotations: AnnotationSeq = Seq.empty) + extends FirrtlPropSpec { property(s"$name should execute correctly") { runFirrtlTest(name, dir, vFiles, annotations = annotations) } } + /** Super class for compilation driven Firrtl tests */ abstract class CompilationTest(name: String, dir: String) extends FirrtlPropSpec { property(s"$name should compile correctly") { @@ -444,7 +465,9 @@ abstract class EquivalenceTest(transforms: Seq[Transform], name: String, dir: St throw new FileNotFoundException(s"Resource '$fileName'") } val source = scala.io.Source.fromInputStream(in) - val input = try source.mkString finally source.close() + val input = + try source.mkString + finally source.close() s"$name with ${transforms.map(_.name).mkString(", ")}" should s"be equivalent to $name without ${transforms.map(_.name).mkString(", ")}" in { diff --git a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala index c1f0943a..4ae6a7be 100644 --- a/src/test/scala/firrtl/testutils/LeanTransformSpec.scala +++ b/src/test/scala/firrtl/testutils/LeanTransformSpec.scala @@ -1,6 +1,6 @@ package firrtl.testutils -import firrtl.{AnnotationSeq, CircuitState, EmitCircuitAnnotation, ir} +import firrtl.{ir, AnnotationSeq, CircuitState, EmitCircuitAnnotation} import firrtl.options.Dependency import firrtl.passes.RemoveEmpty import firrtl.stage.TransformManager.TransformDependency @@ -11,30 +11,33 @@ class VerilogTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.Veril class LowFirrtlTransformSpec extends LeanTransformSpec(Seq(Dependency[firrtl.LowFirrtlEmitter])) /** The new cool kid on the block, creates a custom compiler for your transform. */ -class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) extends AnyFlatSpec with FirrtlMatchers with LazyLogging { +class LeanTransformSpec(protected val transforms: Seq[TransformDependency]) + extends AnyFlatSpec + with FirrtlMatchers + with LazyLogging { private val compiler = new firrtl.stage.transforms.Compiler(transforms) private val emitterAnnos = LeanTransformSpec.deriveEmitCircuitAnnotations(transforms) protected def compile(src: String): CircuitState = compile(src, Seq()) protected def compile(src: String, annos: AnnotationSeq): CircuitState = compile(firrtl.Parser.parse(src), annos) - protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq()) - protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState = + protected def compile(c: ir.Circuit): CircuitState = compile(c, Seq()) + protected def compile(c: ir.Circuit, annos: AnnotationSeq): CircuitState = compiler.transform(CircuitState(c, emitterAnnos ++ annos)) - protected def execute(input: String, check: String): CircuitState = execute(input, check ,Seq()) + protected def execute(input: String, check: String): CircuitState = execute(input, check, Seq()) protected def execute(input: String, check: String, inAnnos: AnnotationSeq): CircuitState = { val finalState = compiler.transform(CircuitState(parse(input), inAnnos)) val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize val expected = parse(check).serialize logger.debug(actual) logger.debug(expected) - actual should be (expected) + actual should be(expected) finalState } } private object LeanTransformSpec { private def deriveEmitCircuitAnnotations(transforms: Iterable[TransformDependency]): AnnotationSeq = { - val emitters = transforms.map(_.getObject()).collect{ case e: firrtl.Emitter => e } + val emitters = transforms.map(_.getObject()).collect { case e: firrtl.Emitter => e } emitters.map(e => EmitCircuitAnnotation(e.getClass)).toSeq } } @@ -47,4 +50,4 @@ trait MakeCompiler { new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.MinimumVerilogEmitter]) ++ transforms) protected def makeLowFirrtlCompiler(transforms: Seq[TransformDependency] = Seq()) = new firrtl.stage.transforms.Compiler(Seq(Dependency[firrtl.LowFirrtlEmitter]) ++ transforms) -}
\ No newline at end of file +} diff --git a/src/test/scala/firrtl/testutils/PassTests.scala b/src/test/scala/firrtl/testutils/PassTests.scala index 49dea199..7a5dc306 100644 --- a/src/test/scala/firrtl/testutils/PassTests.scala +++ b/src/test/scala/firrtl/testutils/PassTests.scala @@ -15,49 +15,53 @@ import org.scalatest.flatspec.AnyFlatSpec // An example methodology for testing Firrtl Passes // Spec class should extend this class abstract class SimpleTransformSpec extends AnyFlatSpec with FirrtlMatchers with Compiler with LazyLogging { - // Utility function - def squash(c: Circuit): Circuit = RemoveEmpty.run(c) - - // Executes the test. Call in tests. - // annotations cannot have default value because scalatest trait Suite has a default value - def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { - val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize - val expected = parse(check).serialize - logger.debug(actual) - logger.debug(expected) - (actual) should be (expected) - finalState - } - - def executeWithAnnos(input: String, check: String, annotations: Seq[Annotation], - checkAnnotations: Seq[Annotation]): CircuitState = { - val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) - val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize - val expected = parse(check).serialize - logger.debug(actual) - logger.debug(expected) - (actual) should be (expected) - - annotations.foreach { anno => - logger.debug(anno.serialize) - } - - finalState.annotations.toSeq.foreach { anno => - logger.debug(anno.serialize) - } - checkAnnotations.foreach { check => - (finalState.annotations.toSeq) should contain (check) - } - finalState - } - // Executes the test, should throw an error - // No default to be consistent with execute - def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { - intercept[PassExceptions] { - compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty) - } - } + // Utility function + def squash(c: Circuit): Circuit = RemoveEmpty.run(c) + + // Executes the test. Call in tests. + // annotations cannot have default value because scalatest trait Suite has a default value + def execute(input: String, check: String, annotations: Seq[Annotation]): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be(expected) + finalState + } + + def executeWithAnnos( + input: String, + check: String, + annotations: Seq[Annotation], + checkAnnotations: Seq[Annotation] + ): CircuitState = { + val finalState = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annotations)) + val actual = RemoveEmpty.run(parse(finalState.getEmittedCircuit.value)).serialize + val expected = parse(check).serialize + logger.debug(actual) + logger.debug(expected) + (actual) should be(expected) + + annotations.foreach { anno => + logger.debug(anno.serialize) + } + + finalState.annotations.toSeq.foreach { anno => + logger.debug(anno.serialize) + } + checkAnnotations.foreach { check => + (finalState.annotations.toSeq) should contain(check) + } + finalState + } + // Executes the test, should throw an error + // No default to be consistent with execute + def failingexecute(input: String, annotations: Seq[Annotation]): Exception = { + intercept[PassExceptions] { + compile(CircuitState(parse(input), ChirrtlForm, annotations), Seq.empty) + } + } } @deprecated( @@ -86,19 +90,19 @@ object ReRunResolveAndCheck extends Transform with DependencyAPIMigration with I } trait LowTransformSpec extends SimpleTransformSpec { - def emitter = new LowFirrtlEmitter - def transform: Transform - def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject) + def emitter = new LowFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.LowForm.map(_.getObject) } trait MiddleTransformSpec extends SimpleTransformSpec { - def emitter = new MiddleFirrtlEmitter - def transform: Transform - def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject) + def emitter = new MiddleFirrtlEmitter + def transform: Transform + def transforms: Seq[Transform] = transform +: ReRunResolveAndCheck +: Forms.MidForm.map(_.getObject) } trait HighTransformSpec extends SimpleTransformSpec { - def emitter = new HighFirrtlEmitter - def transform: Transform - def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject) + def emitter = new HighFirrtlEmitter + def transform: Transform + def transforms = transform +: ReRunResolveAndCheck +: Forms.HighForm.map(_.getObject) } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 4017503e..6f8dd574 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -15,7 +15,6 @@ import firrtl.util.BackendCompilationUtilities import firrtl.testutils._ import org.scalatest.matchers.should.Matchers - object AnnotationTests { class DeletingTransform extends Transform { @@ -31,26 +30,26 @@ object AnnotationTests { abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with MakeCompiler { import AnnotationTests._ - def anno(s: String, value: String ="this is a value", mod: String = "Top"): Annotation + def anno(s: String, value: String = "this is a value", mod: String = "Top"): Annotation def manno(mod: String): Annotation "Annotation on a node" should "pass through" in { val input: String = """circuit Top : - | module Top : - | input a : UInt<1>[2] - | input b : UInt<1> - | node c = b""".stripMargin + | module Top : + | input a : UInt<1>[2] + | input b : UInt<1> + | node c = b""".stripMargin val ta = anno("c", "") val r = compile(input, Seq(ta)) - r.annotations.toSeq should contain (ta) + r.annotations.toSeq should contain(ta) } "Deleting annotations" should "create a DeletedAnnotation" in { val transform = Dependency[DeletingTransform] val compiler = makeVerilogCompiler(Seq(transform)) val input = - """circuit Top : + """circuit Top : | module Top : | input in: UInt<3> |""".stripMargin @@ -65,7 +64,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with result.getEmittedCircuit }) val deleted = result.deletedAnnotations - exception.getMessage should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted") + exception.getMessage should be(s"No EmittedCircuit found! Did you delete any annotations?\n$deleted") } "Renaming" should "propagate in Lowering of memories" in { @@ -73,7 +72,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with // Uncomment to help debugging failing tests // Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug)) val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input in: UInt<3> @@ -87,25 +86,24 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | m.r.en <= UInt(1) | m.r.addr <= in |""".stripMargin - val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"), - dontTouch("Top.m")) + val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"), dontTouch("Top.m")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("m_a", "mem")) - resultAnno should contain (anno("m_b_0", "mem")) - resultAnno should contain (anno("m_b_1", "mem")) - resultAnno should contain (anno("m_a.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "all")) - resultAnno should contain (anno("m_b_1.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "sub")) - resultAnno should contain (anno("m_b_1.r.data", "sub")) + resultAnno should contain(anno("m_a", "mem")) + resultAnno should contain(anno("m_b_0", "mem")) + resultAnno should contain(anno("m_b_1", "mem")) + resultAnno should contain(anno("m_a.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "all")) + resultAnno should contain(anno("m_b_1.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "sub")) + resultAnno should contain(anno("m_b_1.r.data", "sub")) resultAnno should not contain (anno("m")) resultAnno should not contain (anno("r")) } "Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input in: UInt<3> @@ -115,14 +113,14 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"), dontTouch("Top.m")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("m_a", "mem")) - resultAnno should contain (anno("m_b_0", "mem")) - resultAnno should contain (anno("m_b_1", "mem")) - resultAnno should contain (anno("m_a.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "all")) - resultAnno should contain (anno("m_b_1.r.data", "all")) - resultAnno should contain (anno("m_b_0.r.data", "sub")) - resultAnno should contain (anno("m_b_1.r.data", "sub")) + resultAnno should contain(anno("m_a", "mem")) + resultAnno should contain(anno("m_b_0", "mem")) + resultAnno should contain(anno("m_b_1", "mem")) + resultAnno should contain(anno("m_a.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "all")) + resultAnno should contain(anno("m_b_1.r.data", "all")) + resultAnno should contain(anno("m_b_0.r.data", "sub")) + resultAnno should contain(anno("m_b_1.r.data", "sub")) resultAnno should not contain (anno("m")) resultAnno should not contain (anno("r")) } @@ -130,7 +128,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Renaming" should "propagate in ZeroWidth" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input zero: UInt<0> | wire x: {a: UInt<3>, b: UInt<0>} @@ -141,11 +139,11 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | x.a <= zero | x.b <= zero |""".stripMargin - val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), - anno("y[2]"), dontTouch("Top.x")) + val annos = + Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]"), dontTouch("Top.x")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("x_a")) + resultAnno should contain(anno("x_a")) resultAnno should not contain (anno("zero")) resultAnno should not contain (anno("x.a")) resultAnno should not contain (anno("x.b")) @@ -161,7 +159,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Renaming subcomponents" should "propagate in Lowering" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -176,12 +174,24 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | write <= in |""".stripMargin val annos = Seq( - anno("in.a"), anno("in.b[0]"), anno("in.b[1]"), - anno("out.a"), anno("out.b[0]"), anno("out.b[1]"), - anno("w.a"), anno("w.b[0]"), anno("w.b[1]"), - anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), - anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), - dontTouch("Top.r"), dontTouch("Top.w"), dontTouch("Top.mem") + anno("in.a"), + anno("in.b[0]"), + anno("in.b[1]"), + anno("out.a"), + anno("out.b[0]"), + anno("out.b[1]"), + anno("w.a"), + anno("w.b[0]"), + anno("w.b[1]"), + anno("r.a"), + anno("r.b[0]"), + anno("r.b[1]"), + anno("write.a"), + anno("write.b[0]"), + anno("write.b[1]"), + dontTouch("Top.r"), + dontTouch("Top.w"), + dontTouch("Top.mem") ) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq @@ -200,27 +210,27 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with resultAnno should not contain (anno("r.a")) resultAnno should not contain (anno("r.b[0]")) resultAnno should not contain (anno("r.b[1]")) - resultAnno should contain (anno("in_a")) - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_a")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) - resultAnno should contain (anno("w_a")) - resultAnno should contain (anno("w_b_0")) - resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("r_a")) - resultAnno should contain (anno("r_b_0")) - resultAnno should contain (anno("r_b_1")) - resultAnno should contain (anno("mem_a.write.data")) - resultAnno should contain (anno("mem_b_0.write.data")) - resultAnno should contain (anno("mem_b_1.write.data")) + resultAnno should contain(anno("in_a")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_a")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) + resultAnno should contain(anno("w_a")) + resultAnno should contain(anno("w_b_0")) + resultAnno should contain(anno("w_b_1")) + resultAnno should contain(anno("r_a")) + resultAnno should contain(anno("r_b_0")) + resultAnno should contain(anno("r_b_1")) + resultAnno should contain(anno("mem_a.write.data")) + resultAnno should contain(anno("mem_b_0.write.data")) + resultAnno should contain(anno("mem_b_1.write.data")) } "Renaming components" should "expand in Lowering" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -231,28 +241,27 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= mux(pred, in, w) | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("r"), dontTouch("Top.r"), - dontTouch("Top.w")) + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("r"), dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("in_a")) - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_a")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) - resultAnno should contain (anno("w_a")) - resultAnno should contain (anno("w_b_0")) - resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("r_a")) - resultAnno should contain (anno("r_b_0")) - resultAnno should contain (anno("r_b_1")) + resultAnno should contain(anno("in_a")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_a")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) + resultAnno should contain(anno("w_a")) + resultAnno should contain(anno("w_b_0")) + resultAnno should contain(anno("w_b_1")) + resultAnno should contain(anno("r_a")) + resultAnno should contain(anno("r_b_0")) + resultAnno should contain(anno("r_b_1")) } "Renaming subcomponents that aren't leaves" should "expand in Lowering" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -264,24 +273,23 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("r.b"), - dontTouch("Top.r"), dontTouch("Top.w")) + val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("r.b"), dontTouch("Top.r"), dontTouch("Top.w")) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) - resultAnno should contain (anno("w_b_0")) - resultAnno should contain (anno("w_b_1")) - resultAnno should contain (anno("r_b_0")) - resultAnno should contain (anno("r_b_1")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) + resultAnno should contain(anno("w_b_0")) + resultAnno should contain(anno("w_b_1")) + resultAnno should contain(anno("r_b_0")) + resultAnno should contain(anno("r_b_1")) } "Renaming" should "track constprop + dce" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input pred: UInt<1> @@ -291,9 +299,15 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= n |""".stripMargin val annos = Seq( - anno("in.a"), anno("in.b[0]"), anno("in.b[1]"), - anno("out.a"), anno("out.b[0]"), anno("out.b[1]"), - anno("n.a"), anno("n.b[0]"), anno("n.b[1]") + anno("in.a"), + anno("in.b[0]"), + anno("in.b[1]"), + anno("out.a"), + anno("out.b[0]"), + anno("out.b[1]"), + anno("n.a"), + anno("n.b[0]"), + anno("n.b[1]") ) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq @@ -309,18 +323,18 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with resultAnno should not contain (anno("n_a")) resultAnno should not contain (anno("n_b_0")) resultAnno should not contain (anno("n_b_1")) - resultAnno should contain (anno("in_a")) - resultAnno should contain (anno("in_b_0")) - resultAnno should contain (anno("in_b_1")) - resultAnno should contain (anno("out_a")) - resultAnno should contain (anno("out_b_0")) - resultAnno should contain (anno("out_b_1")) + resultAnno should contain(anno("in_a")) + resultAnno should contain(anno("in_b_0")) + resultAnno should contain(anno("in_b_1")) + resultAnno should contain(anno("out_a")) + resultAnno should contain(anno("out_b_0")) + resultAnno should contain(anno("out_b_1")) } "Renaming" should "track deleted modules AND instances in dce" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Dead : | input foo : UInt<8> | output bar : UInt<8> @@ -339,11 +353,17 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with |""".stripMargin val annos = Seq( OptimizableExtModuleAnnotation(ModuleName("DeadExt", CircuitName("Top"))), - manno("Dead"), manno("DeadExt"), manno("Top"), - anno("d"), anno("d2"), - anno("foo", mod = "Top"), anno("bar", mod = "Top"), - anno("foo", mod = "Dead"), anno("bar", mod = "Dead"), - anno("foo", mod = "DeadExt"), anno("bar", mod = "DeadExt") + manno("Dead"), + manno("DeadExt"), + manno("Top"), + anno("d"), + anno("d2"), + anno("foo", mod = "Top"), + anno("bar", mod = "Top"), + anno("foo", mod = "Dead"), + anno("bar", mod = "Dead"), + anno("foo", mod = "DeadExt"), + anno("bar", mod = "DeadExt") ) val result = compiler.transform(CircuitState(parse(input), annos)) /* Uncomment to help debug @@ -354,12 +374,12 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with case Annotation(target, _, _) => println(s"not deleted: $target") } } - */ + */ val resultAnno = result.annotations.toSeq - resultAnno should contain (manno("Top")) - resultAnno should contain (anno("foo", mod = "Top")) - resultAnno should contain (anno("bar", mod = "Top")) + resultAnno should contain(manno("Top")) + resultAnno should contain(anno("foo", mod = "Top")) + resultAnno should contain(anno("bar", mod = "Top")) resultAnno should not contain (manno("Dead")) resultAnno should not contain (manno("DeadExt")) @@ -373,7 +393,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Renaming" should "track deduplication" in { val input = - """circuit Top : + """circuit Top : | module Child : | input x : UInt<32> | output y : UInt<32> @@ -392,13 +412,16 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | out <= tail(add(a.y, b.y), 1) |""".stripMargin val annos = Seq( - anno("x", mod = "Child"), anno("y", mod = "Child_1"), manno("Child"), manno("Child_1") + anno("x", mod = "Child"), + anno("y", mod = "Child_1"), + manno("Child"), + manno("Child_1") ) val result = compile(input, annos) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("x", mod = "Child")) - resultAnno should contain (anno("y", mod = "Child")) - resultAnno should contain (manno("Child")) + resultAnno should contain(anno("x", mod = "Child")) + resultAnno should contain(anno("y", mod = "Child")) + resultAnno should contain(manno("Child")) resultAnno should not contain (anno("y", mod = "Child_1")) resultAnno should not contain (manno("Child_1")) } @@ -412,7 +435,7 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with "Annotations on empty aggregates" should "be deleted" in { val compiler = makeVerilogCompiler() val input = - """circuit Top : + """circuit Top : | module Top : | input x : { foo : UInt<8>, bar : {}, fizz : UInt<8>[0], buzz : UInt<0> } | output y : { foo : UInt<8>, bar : {}, fizz : UInt<8>[0], buzz : UInt<0> } @@ -423,12 +446,19 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with | y <= x |""".stripMargin val annos = Seq( - anno("x"), anno("y.bar"), anno("y.fizz"), anno("y.buzz"), anno("a"), anno("b"), anno("c"), - anno("c[0].d"), anno("c[1].d") + anno("x"), + anno("y.bar"), + anno("y.fizz"), + anno("y.buzz"), + anno("a"), + anno("b"), + anno("c"), + anno("c[0].d"), + anno("c[1].d") ) val result = compiler.transform(CircuitState(parse(input), annos)) val resultAnno = result.annotations.toSeq - resultAnno should contain (anno("x_foo")) + resultAnno should contain(anno("x_foo")) resultAnno should not contain (anno("a")) resultAnno should not contain (anno("b")) // Check both with and without dots because both are wrong @@ -445,8 +475,8 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with resultAnno should not contain (anno("x_fizz")) resultAnno should not contain (anno("x_buzz")) resultAnno should not contain (anno("c")) - resultAnno should contain (anno("c_0_e")) - resultAnno should contain (anno("c_1_e")) + resultAnno should contain(anno("c_0_e")) + resultAnno should contain(anno("c_1_e")) resultAnno should not contain (anno("c[0].d")) resultAnno should not contain (anno("c[1].d")) resultAnno should not contain (anno("c_0_d")) @@ -456,15 +486,14 @@ abstract class AnnotationTests extends LowFirrtlTransformSpec with Matchers with class JsonAnnotationTests extends AnnotationTests { // Helper annotations - case class SimpleAnno(target: ComponentName, value: String) extends - SingleTargetAnnotation[ComponentName] { + case class SimpleAnno(target: ComponentName, value: String) extends SingleTargetAnnotation[ComponentName] { def duplicate(n: ComponentName) = this.copy(target = n) } case class ModuleAnno(target: ModuleName) extends SingleTargetAnnotation[ModuleName] { def duplicate(n: ModuleName) = this.copy(target = n) } - def anno(s: String, value: String ="this is a value", mod: String = "Top"): SimpleAnno = + def anno(s: String, value: String = "this is a value", mod: String = "Top"): SimpleAnno = SimpleAnno(ComponentName(s, ModuleName(mod, CircuitName("Top"))), value) def manno(mod: String): Annotation = ModuleAnno(ModuleName(mod, CircuitName("Top"))) @@ -487,17 +516,17 @@ class JsonAnnotationTests extends AnnotationTests { val readAnnos = JsonProtocol.deserializeTry(text).get - annos should be (readAnnos) + annos should be(readAnnos) } private def setupManager(annoFileText: Option[String]) = { val source = """ - |circuit test : - | module test : - | input x : UInt<1> - | output z : UInt<1> - | z <= x - | node y = x""".stripMargin + |circuit test : + | module test : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | node y = x""".stripMargin val testDir = BackendCompilationUtilities.createTestDirectory(this.getClass.getSimpleName) val annoFile = new File(testDir, "anno.json") @@ -519,58 +548,57 @@ class JsonAnnotationTests extends AnnotationTests { "Annotation file not found" should "give a reasonable error message" in { val manager = setupManager(None) - an [AnnotationFileNotFoundException] shouldBe thrownBy { + an[AnnotationFileNotFoundException] shouldBe thrownBy { Driver.execute(manager) } } "Annotation class not found" should "give a reasonable error message" in { val anno = """ - |[ - | { - | "class":"ThisClassDoesNotExist", - | "target":"test.test.y" - | } - |] """.stripMargin + |[ + | { + | "class":"ThisClassDoesNotExist", + | "target":"test.test.y" + | } + |] """.stripMargin val manager = setupManager(Some(anno)) - the [Exception] thrownBy Driver.execute(manager) should matchPattern { + the[Exception] thrownBy Driver.execute(manager) should matchPattern { case InvalidAnnotationFileException(_, _: AnnotationClassNotFoundException) => } } "Malformed annotation file" should "give a reasonable error message" in { val anno = """ - |[ - | { - | "class": - | "target":"test.test.y" - | } - |] """.stripMargin + |[ + | { + | "class": + | "target":"test.test.y" + | } + |] """.stripMargin val manager = setupManager(Some(anno)) - the [Exception] thrownBy Driver.execute(manager) should matchPattern { + the[Exception] thrownBy Driver.execute(manager) should matchPattern { case InvalidAnnotationFileException(_, _: InvalidAnnotationJSONException) => } } "Non-array annotation file" should "give a reasonable error message" in { val anno = """ - |{ - | "class":"firrtl.transforms.DontTouchAnnotation", - | "target":"test.test.y" - |} - |""".stripMargin + |{ + | "class":"firrtl.transforms.DontTouchAnnotation", + | "target":"test.test.y" + |} + |""".stripMargin val manager = setupManager(Some(anno)) - the [Exception] thrownBy Driver.execute(manager) should matchPattern { - case InvalidAnnotationFileException(_, InvalidAnnotationJSONException(msg)) - if msg.contains("JObject") => + the[Exception] thrownBy Driver.execute(manager) should matchPattern { + case InvalidAnnotationFileException(_, InvalidAnnotationJSONException(msg)) if msg.contains("JObject") => } } object DoNothingTransform extends Transform { - override def inputForm: CircuitForm = UnknownForm + override def inputForm: CircuitForm = UnknownForm override def outputForm: CircuitForm = UnknownForm def execute(state: CircuitState): CircuitState = state @@ -580,9 +608,9 @@ class JsonAnnotationTests extends AnnotationTests { val annos = Seq(anno("a"), anno("b"), anno("c"), anno("d"), anno("e")) val input: String = """circuit Top : - | module Top : - | input a : UInt<1> - | node b = c""".stripMargin + | module Top : + | input a : UInt<1> + | node b = c""".stripMargin val cr = DoNothingTransform.runTransform(CircuitState(parse(input), ChirrtlForm, annos)) cr.annotations.toSeq shouldEqual annos } diff --git a/src/test/scala/firrtlTests/AsyncResetSpec.scala b/src/test/scala/firrtlTests/AsyncResetSpec.scala index 70b28585..04b558e9 100644 --- a/src/test/scala/firrtlTests/AsyncResetSpec.scala +++ b/src/test/scala/firrtlTests/AsyncResetSpec.scala @@ -9,330 +9,313 @@ import FirrtlCheckers._ class AsyncResetSpec extends VerilogTransformSpec { def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } "AsyncReset" should "generate async-reset always blocks" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<8> - |output z : UInt<8> - |reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<8> + |output z : UInt<8> + |reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } it should "work in nested and flipped aggregates with regular and partial connect" in { val result = compileBody(s""" - |output fizz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } - |output buzz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } - |fizz.bar <= fizz.foo - |buzz.bar <- buzz.foo - |""".stripMargin - ) + |output fizz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } + |output buzz : { flip foo : { a : AsyncReset, flip b: AsyncReset }[2], bar : { a : AsyncReset, flip b: AsyncReset }[2] } + |fizz.bar <= fizz.foo + |buzz.bar <- buzz.foo + |""".stripMargin) - result should containLine ("assign fizz_foo_0_b = fizz_bar_0_b;") - result should containLine ("assign fizz_foo_1_b = fizz_bar_1_b;") - result should containLine ("assign fizz_bar_0_a = fizz_foo_0_a;") - result should containLine ("assign fizz_bar_1_a = fizz_foo_1_a;") - result should containLine ("assign buzz_foo_0_b = buzz_bar_0_b;") - result should containLine ("assign buzz_foo_1_b = buzz_bar_1_b;") - result should containLine ("assign buzz_bar_0_a = buzz_foo_0_a;") - result should containLine ("assign buzz_bar_1_a = buzz_foo_1_a;") + result should containLine("assign fizz_foo_0_b = fizz_bar_0_b;") + result should containLine("assign fizz_foo_1_b = fizz_bar_1_b;") + result should containLine("assign fizz_bar_0_a = fizz_foo_0_a;") + result should containLine("assign fizz_bar_1_a = fizz_foo_1_a;") + result should containLine("assign buzz_foo_0_b = buzz_bar_0_b;") + result should containLine("assign buzz_foo_1_b = buzz_bar_1_b;") + result should containLine("assign buzz_bar_0_a = buzz_foo_0_a;") + result should containLine("assign buzz_bar_1_a = buzz_foo_1_a;") } it should "support casting to other types" in { val result = compileBody(s""" - |input a : AsyncReset - |output u : Interval[0, 1].0 - |output v : UInt<1> - |output w : SInt<1> - |output x : Clock - |output y : Fixed<1><<0>> - |output z : AsyncReset - |u <= asInterval(a, 0, 1, 0) - |v <= asUInt(a) - |w <= asSInt(a) - |x <= asClock(a) - |y <= asFixedPoint(a, 0) - |z <= asAsyncReset(a) - |""".stripMargin - ) - result should containLine ("assign v = a;") - result should containLine ("assign w = a;") - result should containLine ("assign x = a;") - result should containLine ("assign y = a;") - result should containLine ("assign z = a;") + |input a : AsyncReset + |output u : Interval[0, 1].0 + |output v : UInt<1> + |output w : SInt<1> + |output x : Clock + |output y : Fixed<1><<0>> + |output z : AsyncReset + |u <= asInterval(a, 0, 1, 0) + |v <= asUInt(a) + |w <= asSInt(a) + |x <= asClock(a) + |y <= asFixedPoint(a, 0) + |z <= asAsyncReset(a) + |""".stripMargin) + result should containLine("assign v = a;") + result should containLine("assign w = a;") + result should containLine("assign x = a;") + result should containLine("assign y = a;") + result should containLine("assign z = a;") } "Other types" should "support casting to AsyncReset" in { val result = compileBody(s""" - |input a : UInt<1> - |input b : SInt<1> - |input c : Clock - |input d : Fixed<1><<0>> - |input e : AsyncReset - |input f : Interval[0, 0].0 - |output u : AsyncReset - |output v : AsyncReset - |output w : AsyncReset - |output x : AsyncReset - |output y : AsyncReset - |output z : AsyncReset - |u <= asAsyncReset(a) - |v <= asAsyncReset(b) - |w <= asAsyncReset(c) - |x <= asAsyncReset(d) - |y <= asAsyncReset(e) - |z <= asAsyncReset(f)""".stripMargin - ) - result should containLine ("assign u = a;") - result should containLine ("assign v = b;") - result should containLine ("assign w = c;") - result should containLine ("assign x = d;") - result should containLine ("assign y = e;") - result should containLine ("assign z = f;") + |input a : UInt<1> + |input b : SInt<1> + |input c : Clock + |input d : Fixed<1><<0>> + |input e : AsyncReset + |input f : Interval[0, 0].0 + |output u : AsyncReset + |output v : AsyncReset + |output w : AsyncReset + |output x : AsyncReset + |output y : AsyncReset + |output z : AsyncReset + |u <= asAsyncReset(a) + |v <= asAsyncReset(b) + |w <= asAsyncReset(c) + |x <= asAsyncReset(d) + |y <= asAsyncReset(e) + |z <= asAsyncReset(f)""".stripMargin) + result should containLine("assign u = a;") + result should containLine("assign v = b;") + result should containLine("assign w = c;") + result should containLine("assign x = d;") + result should containLine("assign y = e;") + result should containLine("assign z = f;") } "Non-literals" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |reg r : UInt<8>, clock with : (reset => (reset, y)) - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |reg r : UInt<8>, clock with : (reset => (reset, y)) + |r <= x + |z <= r""".stripMargin) } } "Self-inits" should "NOT cause infinite loops in CheckResets" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input in : UInt<12> - |output out : UInt<10> - | - |reg a : UInt<10>, clock with : - | reset => (reset, a) - |out <= UInt<5>("h15")""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input in : UInt<12> + |output out : UInt<10> + | + |reg a : UInt<10>, clock with : + | reset => (reset, a) + |out <= UInt<5>("h15")""".stripMargin) result should containLine("assign out = 10'h15;") } "Late non-literals connections" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |wire a : UInt<8> - |reg r : UInt<8>, clock with : (reset => (reset, a)) - |a <= y - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |wire a : UInt<8> + |reg r : UInt<8>, clock with : (reset => (reset, a)) + |a <= y + |r <= x + |z <= r""".stripMargin) } } "Hidden Non-literals" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |input y : UInt<1> - |output z : UInt<1>[4] - |wire literal : UInt<1>[4] - |literal[0] <= UInt<1>("h00") - |literal[1] <= y - |literal[2] <= UInt<1>("h00") - |literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |input y : UInt<1> + |output z : UInt<1>[4] + |wire literal : UInt<1>[4] + |literal[0] <= UInt<1>("h00") + |literal[1] <= y + |literal[2] <= UInt<1>("h00") + |literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) + |r <= x + |z <= r""".stripMargin) } } "Wire connected to non-literal" should "NOT be allowed as reset values for AsyncReset" in { - an [checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { + an[checks.CheckResets.NonLiteralAsyncResetValueException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |input y : UInt<1> - |input cond : UInt<1> - |output z : UInt<1> - |wire w : UInt<1> - |w <= UInt(1) - |when cond : - | w <= y - |reg r : UInt<1>, clock with : (reset => (reset, w)) - |r <= x - |z <= r""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |input y : UInt<1> + |input cond : UInt<1> + |output z : UInt<1> + |wire w : UInt<1> + |w <= UInt(1) + |when cond : + | w <= y + |reg r : UInt<1>, clock with : (reset => (reset, w)) + |r <= x + |z <= r""".stripMargin) } } "Complex literals" should "be allowed as reset values for AsyncReset" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |output z : UInt<1>[4] - |wire literal : UInt<1>[4] - |literal[0] <= UInt<1>("h00") - |literal[1] <= UInt<1>("h00") - |literal[2] <= UInt<1>("h00") - |literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |output z : UInt<1>[4] + |wire literal : UInt<1>[4] + |literal[0] <= UInt<1>("h00") + |literal[1] <= UInt<1>("h00") + |literal[2] <= UInt<1>("h00") + |literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, literal)) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } "Complex literals of complex literals" should "be allowed as reset values for AsyncReset" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |output z : UInt<1>[4] - |wire literal : UInt<1>[2] - |literal[0] <= UInt<1>("h01") - |literal[1] <= UInt<1>("h01") - |wire complex_literal : UInt<1>[4] - |complex_literal[0] <= literal[0] - |complex_literal[1] <= literal[1] - |complex_literal[2] <= UInt<1>("h00") - |complex_literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |output z : UInt<1>[4] + |wire literal : UInt<1>[2] + |literal[0] <= UInt<1>("h01") + |literal[1] <= UInt<1>("h01") + |wire complex_literal : UInt<1>[4] + |complex_literal[0] <= literal[0] + |complex_literal[1] <= literal[1] + |complex_literal[2] <= UInt<1>("h00") + |complex_literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } "Literals of bundle literals" should "be allowed as reset values for AsyncReset" in { val result = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1>[4] - |output z : UInt<1>[4] - |wire bundle : {a: UInt<1>, b: UInt<1>} - |bundle.a <= UInt<1>("h01") - |bundle.b <= UInt<1>("h01") - |wire complex_literal : UInt<1>[4] - |complex_literal[0] <= bundle.a - |complex_literal[1] <= bundle.b - |complex_literal[2] <= UInt<1>("h00") - |complex_literal[3] <= UInt<1>("h00") - |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) - |r <= x - |z <= r""".stripMargin - ) - result should containLine ("always @(posedge clock or posedge reset) begin") + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1>[4] + |output z : UInt<1>[4] + |wire bundle : {a: UInt<1>, b: UInt<1>} + |bundle.a <= UInt<1>("h01") + |bundle.b <= UInt<1>("h01") + |wire complex_literal : UInt<1>[4] + |complex_literal[0] <= bundle.a + |complex_literal[1] <= bundle.b + |complex_literal[2] <= UInt<1>("h00") + |complex_literal[3] <= UInt<1>("h00") + |reg r : UInt<1>[4], clock with : (reset => (reset, complex_literal)) + |r <= x + |z <= r""".stripMargin) + result should containLine("always @(posedge clock or posedge reset) begin") } "Cast literals" should "be allowed as reset values for AsyncReset" in { // This also checks that casts can be across wires and nodes val sintResult = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : SInt<4> - |output y : SInt<4> - |output z : SInt<4> - |reg r : SInt<4>, clock with : (reset => (reset, asSInt(UInt(0)))) - |r <= x - |wire w : SInt<4> - |reg r2 : SInt<4>, clock with : (reset => (reset, w)) - |r2 <= x - |node n = UInt("hf") - |w <= asSInt(n) - |y <= r2 - |z <= r""".stripMargin - ) - sintResult should containLine ("always @(posedge clock or posedge reset) begin") - sintResult should containLine ("r <= 4'sh0;") - sintResult should containLine ("r2 <= -4'sh1;") + |input clock : Clock + |input reset : AsyncReset + |input x : SInt<4> + |output y : SInt<4> + |output z : SInt<4> + |reg r : SInt<4>, clock with : (reset => (reset, asSInt(UInt(0)))) + |r <= x + |wire w : SInt<4> + |reg r2 : SInt<4>, clock with : (reset => (reset, w)) + |r2 <= x + |node n = UInt("hf") + |w <= asSInt(n) + |y <= r2 + |z <= r""".stripMargin) + sintResult should containLine("always @(posedge clock or posedge reset) begin") + sintResult should containLine("r <= 4'sh0;") + sintResult should containLine("r2 <= -4'sh1;") val fixedResult = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : Fixed<2><<0>> - |output z : Fixed<2><<0>> - |reg r : Fixed<2><<0>>, clock with : (reset => (reset, asFixedPoint(UInt(2), 0))) - |r <= x - |z <= r""".stripMargin - ) - fixedResult should containLine ("always @(posedge clock or posedge reset) begin") - fixedResult should containLine ("r <= 2'sh2;") + |input clock : Clock + |input reset : AsyncReset + |input x : Fixed<2><<0>> + |output z : Fixed<2><<0>> + |reg r : Fixed<2><<0>>, clock with : (reset => (reset, asFixedPoint(UInt(2), 0))) + |r <= x + |z <= r""".stripMargin) + fixedResult should containLine("always @(posedge clock or posedge reset) begin") + fixedResult should containLine("r <= 2'sh2;") val intervalResult = compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |input x : Interval[0, 4].0 - |output z : Interval[0, 4].0 - |reg r : Interval[0, 4].0, clock with : (reset => (reset, asInterval(UInt(0), 0, 0, 0))) - |r <= x - |z <= r""".stripMargin - ) - intervalResult should containLine ("always @(posedge clock or posedge reset) begin") - intervalResult should containLine ("r <= 4'sh0;") + |input clock : Clock + |input reset : AsyncReset + |input x : Interval[0, 4].0 + |output z : Interval[0, 4].0 + |reg r : Interval[0, 4].0, clock with : (reset => (reset, asInterval(UInt(0), 0, 0, 0))) + |r <= x + |z <= r""".stripMargin) + intervalResult should containLine("always @(posedge clock or posedge reset) begin") + intervalResult should containLine("r <= 4'sh0;") } "CheckResets" should "NOT raise StackOverflow Exception on Combinational Loops (should be caught by firrtl.transforms.CheckCombLoops)" in { - an [firrtl.transforms.CheckCombLoops.CombLoopException] shouldBe thrownBy { + an[firrtl.transforms.CheckCombLoops.CombLoopException] shouldBe thrownBy { compileBody(s""" - |input clock : Clock - |input reset : AsyncReset - |wire x : UInt<1> - |wire y : UInt<2> - |x <= UInt<1>("h01") - |node ad = add(x, y) - |node adt = tail(ad, 1) - |y <= adt - |reg r : UInt, clock with : (reset => (reset, y)) - |""".stripMargin - ) + |input clock : Clock + |input reset : AsyncReset + |wire x : UInt<1> + |wire y : UInt<2> + |x <= UInt<1>("h01") + |node ad = add(x, y) + |node adt = tail(ad, 1) + |y <= adt + |reg r : UInt, clock with : (reset => (reset, y)) + |""".stripMargin) } } "Every async reset reg" should "generate its own always block" in { val result = compileBody(s""" - |input clock0 : Clock - |input clock1 : Clock - |input syncReset : UInt<1> - |input asyncReset : AsyncReset - |input x : UInt<8>[5] - |output z : UInt<8>[5] - |reg r0 : UInt<8>, clock0 with : (reset => (syncReset, UInt(123))) - |reg r1 : UInt<8>, clock1 with : (reset => (syncReset, UInt(123))) - |reg r2 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) - |reg r3 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) - |reg r4 : UInt<8>, clock1 with : (reset => (asyncReset, UInt(123))) - |r0 <= x[0] - |r1 <= x[1] - |r2 <= x[2] - |r3 <= x[3] - |r4 <= x[4] - |z[0] <= r0 - |z[1] <= r1 - |z[2] <= r2 - |z[3] <= r3 - |z[4] <= r4""".stripMargin - ) - result should containLines ( + |input clock0 : Clock + |input clock1 : Clock + |input syncReset : UInt<1> + |input asyncReset : AsyncReset + |input x : UInt<8>[5] + |output z : UInt<8>[5] + |reg r0 : UInt<8>, clock0 with : (reset => (syncReset, UInt(123))) + |reg r1 : UInt<8>, clock1 with : (reset => (syncReset, UInt(123))) + |reg r2 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) + |reg r3 : UInt<8>, clock0 with : (reset => (asyncReset, UInt(123))) + |reg r4 : UInt<8>, clock1 with : (reset => (asyncReset, UInt(123))) + |r0 <= x[0] + |r1 <= x[1] + |r2 <= x[2] + |r3 <= x[3] + |r4 <= x[4] + |z[0] <= r0 + |z[1] <= r1 + |z[2] <= r2 + |z[3] <= r3 + |z[4] <= r4""".stripMargin) + result should containLines( "always @(posedge clock0) begin", "if (syncReset) begin", "r0 <= 8'h7b;", @@ -341,7 +324,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock1) begin", "if (syncReset) begin", "r1 <= 8'h7b;", @@ -350,7 +333,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock0 or posedge asyncReset) begin", "if (asyncReset) begin", "r2 <= 8'h7b;", @@ -359,7 +342,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock0 or posedge asyncReset) begin", "if (asyncReset) begin", "r3 <= 8'h7b;", @@ -368,7 +351,7 @@ class AsyncResetSpec extends VerilogTransformSpec { "end", "end" ) - result should containLines ( + result should containLines( "always @(posedge clock1 or posedge asyncReset) begin", "if (asyncReset) begin", "r4 <= 8'h7b;", @@ -427,27 +410,26 @@ class AsyncResetSpec extends VerilogTransformSpec { "AsyncReset registers" should "emit 'else' case for reset even for trivial valued registers" in { val withDontTouch = s""" - |circuit m : - | module m : - | input clock : Clock - | input reset : AsyncReset - | input x : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) - |""".stripMargin + |circuit m : + | module m : + | input clock : Clock + | input reset : AsyncReset + | input x : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt(123))) + |""".stripMargin val annos = Seq(dontTouch("m.r")) // dontTouch prevents ConstantPropagation from fixing this problem val result = (new VerilogCompiler).compileAndEmit(CircuitState(parse(withDontTouch), ChirrtlForm, annos)) - result should containLines ( - "always @(posedge clock or posedge reset) begin", - "if (reset) begin", - "r <= 8'h7b;", - "end else begin", - "r <= 8'h7b;", - "end", - "end" - ) + result should containLines( + "always @(posedge clock or posedge reset) begin", + "if (reset) begin", + "r <= 8'h7b;", + "end else begin", + "r <= 8'h7b;", + "end", + "end" + ) } } class AsyncResetExecutionTest extends ExecutionTest("AsyncResetTester", "/features") - diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index e4acc735..709e3692 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -9,12 +9,12 @@ import firrtl.testutils._ class InoutVerilogSpec extends FirrtlFlatSpec { - behavior of "Analog" + behavior.of("Analog") it should "attach a module input source directly" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | input an: Analog<3> | inst a of A @@ -25,32 +25,32 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module B: | input an2: Analog<3> """.stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - | A a ( - | .an1(an) - | ); - | B b ( - | .an2(an) - | ); - |endmodule - |module A( - | inout [2:0] an1 - |); - |endmodule - |module B( - | inout [2:0] an2 - |); - |endmodule - |""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + | A a ( + | .an1(an) + | ); + | B b ( + | .an2(an) + | ); + |endmodule + |module A( + | inout [2:0] an1 + |); + |endmodule + |module B( + | inout [2:0] an2 + |); + |endmodule + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontDedup("A"), dontDedup("B"))) } it should "attach two instances" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | inst a of A | inst b of B @@ -60,24 +60,24 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module B: | input an: Analog<3>""".stripMargin val check = - """module Attaching( - |); - | wire [2:0] _GEN_0; - | A a ( - | .an(_GEN_0) - | ); - | B b ( - | .an(_GEN_0) - | ); - |endmodule - |module A( - | inout [2:0] an - |); - |module B( - | inout [2:0] an - |); - |endmodule - |""".stripMargin.split("\n") map normalized + """module Attaching( + |); + | wire [2:0] _GEN_0; + | A a ( + | .an(_GEN_0) + | ); + | B b ( + | .an(_GEN_0) + | ); + |endmodule + |module A( + | inout [2:0] an + |); + |module B( + | inout [2:0] an + |); + |endmodule + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontTouch("A.an"), dontDedup("A"))) } @@ -85,12 +85,12 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | wire x: Analog - | inst a of A - | attach (x, a.an) - | module A: - | input an: Analog<3> """.stripMargin + | module Attaching : + | wire x: Analog + | inst a of A + | attach (x, a.an) + | module A: + | input an: Analog<3> """.stripMargin val check = """module Attaching( |); @@ -99,7 +99,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | .an(x) | ); |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontTouch("Attaching.x"))) } @@ -107,14 +107,14 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input an: Analog<3> - | wire x: Analog - | inst a of A - | attach (x, a.an) - | attach (x, an) - | module A: - | input an: Analog<3> """.stripMargin + | module Attaching : + | input an: Analog<3> + | wire x: Analog + | inst a of A + | attach (x, a.an) + | attach (x, an) + | module A: + | input an: Analog<3> """.stripMargin val check = """module Attaching( | inout [2:0] an @@ -123,20 +123,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | .an(an) | ); |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler, Seq(dontTouch("Attaching.x"))) } - it should "attach multiple sources" in { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input a1 : Analog<3> - | input a2 : Analog<3> - | wire x: Analog<3> - | attach (x, a1, a2)""".stripMargin + | module Attaching : + | input a1 : Analog<3> + | input a2 : Analog<3> + | wire x: Analog<3> + | attach (x, a1, a2)""".stripMargin val check = """module Attaching( | inout [2:0] a1, @@ -151,7 +150,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | alias a1 = a2; | `endif |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } @@ -159,10 +158,10 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input foo : { b : UInt<3>, a : Analog<3> } - | output bar : { b : UInt<3>, a : Analog<3> } - | bar <- foo""".stripMargin + | module Attaching : + | input foo : { b : UInt<3>, a : Analog<3> } + | output bar : { b : UInt<3>, a : Analog<3> } + | bar <- foo""".stripMargin // Omitting `ifdef SYNTHESIS and `elsif verilator since it's tested above val check = """module Attaching( @@ -174,7 +173,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | assign bar_b = foo_b; | alias bar_a = foo_a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } @@ -182,14 +181,14 @@ class InoutVerilogSpec extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Attaching : - | module Attaching : - | input a : Analog<32> - | input b : Analog<32> - | input c : Analog<32> - | input d : Analog<32> - | attach (a, b) - | attach (c, b) - | attach (a, d)""".stripMargin + | module Attaching : + | input a : Analog<32> + | input b : Analog<32> + | input c : Analog<32> + | input d : Analog<32> + | attach (a, b) + | attach (c, b) + | attach (a, d)""".stripMargin val check = """module Attaching( | inout [31:0] a, @@ -199,19 +198,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec { |); | alias a = b = c = d; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) val input2 = """circuit Attaching : - | module Attaching : - | input a : Analog<32> - | input b : Analog<32> - | input c : Analog<32> - | input d : Analog<32> - | attach (a, b) - | attach (c, d) - | attach (d, a)""".stripMargin + | module Attaching : + | input a : Analog<32> + | input b : Analog<32> + | input c : Analog<32> + | input d : Analog<32> + | attach (a, b) + | attach (c, d) + | attach (d, a)""".stripMargin val check2 = """module Attaching( | inout [31:0] a, @@ -221,14 +220,14 @@ class InoutVerilogSpec extends FirrtlFlatSpec { |); | alias a = b = c = d; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input2, check2, compiler) } it should "infer widths" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | input an: Analog | inst a of A @@ -236,70 +235,65 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module A: | input an1: Analog<3>""".stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - | A a ( - | .an1(an) - | ); - |endmodule - |module A( - | inout [2:0] an1 - |); - |endmodule""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + | A a ( + | .an1(an) + | ); + |endmodule + |module A( + | inout [2:0] an1 + |); + |endmodule""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } it should "not error if not isinvalid" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | output an: Analog<3> |""".stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - |endmodule""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + |endmodule""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } it should "not error if isinvalid" in { val compiler = new VerilogCompiler val input = - """circuit Attaching : + """circuit Attaching : | module Attaching : | output an: Analog<3> | an is invalid |""".stripMargin val check = - """module Attaching( - | inout [2:0] an - |); - |endmodule""".stripMargin.split("\n") map normalized + """module Attaching( + | inout [2:0] an + |); + |endmodule""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } } class AttachAnalogSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } "Connecting analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -307,38 +301,28 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | output x: Analog<1> | x <= y""".stripMargin intercept[CheckTypes.InvalidConnect] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Declaring register with analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : | input clk: Clock | reg r: Analog<2>, clk""".stripMargin intercept[CheckTypes.IllegalAnalogDeclaration] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Declaring memory with analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -350,38 +334,28 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | write-latency => 1 | read-under-write => undefined""".stripMargin intercept[CheckTypes.IllegalAnalogDeclaration] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Declaring node with analog types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : | input in: Analog<2> | node n = in """.stripMargin intercept[CheckTypes.IllegalAnalogDeclaration] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Attaching a non-analog expression" should "not be ok" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -394,21 +368,14 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | extmodule B: | input o: Analog<2>""".stripMargin intercept[CheckTypes.OpNotAnalog] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Inequal attach widths" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -418,8 +385,8 @@ class AttachAnalogSpec extends FirrtlFlatSpec { | extmodule A : | output o: Analog<2> """.stripMargin intercept[CheckWidths.AttachWidthsNotEqual] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } diff --git a/src/test/scala/firrtlTests/CInferMDirSpec.scala b/src/test/scala/firrtlTests/CInferMDirSpec.scala index 6c9d4047..ce0c0a74 100644 --- a/src/test/scala/firrtlTests/CInferMDirSpec.scala +++ b/src/test/scala/firrtlTests/CInferMDirSpec.scala @@ -14,22 +14,21 @@ class CInferMDirSpec extends LowTransformSpec { def checkStmt(s: Statement): Boolean = s match { case s: DefMemory if s.name == "indices" => (s.readers contains "index") && - (s.writers contains "bar") && - s.readwriters.isEmpty + (s.writers contains "bar") && + s.readwriters.isEmpty case s: Block => - s.stmts exists checkStmt + s.stmts.exists(checkStmt) case _ => false } - def run (c: Circuit) = { + def run(c: Circuit) = { val errors = new Errors - val check = c.modules exists { - case m: Module => checkStmt(m.body) + val check = c.modules.exists { + case m: Module => checkStmt(m.body) case m: ExtModule => false } if (!check) { - errors append new PassException( - "Memory has incorrect port directions!") + errors.append(new PassException("Memory has incorrect port directions!")) errors.trigger } c diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala index 2016e160..d8151142 100644 --- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala +++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala @@ -12,46 +12,46 @@ import java.nio.file.Paths import firrtl.options.Dependency import firrtl.stage.FirrtlStage -class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops]) ){ +class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops])) { "Loop-free circuit" should "not throw an exception" in { val input = """circuit hasnoloops : - | module thru : - | input in1 : UInt<1> - | input in2 : UInt<1> - | output out1 : UInt<1> - | output out2 : UInt<1> - | out1 <= in1 - | out2 <= in2 - | module hasnoloops : - | input clk : Clock - | input a : UInt<1> - | output b : UInt<1> - | wire x : UInt<1> - | inst inner of thru - | inner.in1 <= a - | x <= inner.out1 - | inner.in2 <= x - | b <= inner.out2 - |""".stripMargin + | module thru : + | input in1 : UInt<1> + | input in2 : UInt<1> + | output out1 : UInt<1> + | output out2 : UInt<1> + | out1 <= in1 + | out2 <= in2 + | module hasnoloops : + | input clk : Clock + | input a : UInt<1> + | output b : UInt<1> + | wire x : UInt<1> + | inst inner of thru + | inner.in1 <= a + | x <= inner.out1 + | inner.in2 <= x + | b <= inner.out2 + |""".stripMargin compile(parse(input)) } "Simple combinational loop" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | z <= y - | y <= z - | d <= z - |""".stripMargin + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | z <= y + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -60,12 +60,12 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Single-element combinational loop" should "throw an exception" in { val input = """circuit loop : - | module loop : - | output y : UInt<8> - | wire w : UInt<8> - | w <= w - | y <= w - |""".stripMargin + | module loop : + | output y : UInt<8> + | wire w : UInt<8> + | w <= w + | y <= w + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -74,18 +74,18 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Node combinational loop" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | c <= b - | node z = and(c,y) - | y <= z - | d <= z - |""".stripMargin + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | c <= b + | node z = and(c,y) + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -94,29 +94,29 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Combinational loop through a combinational memory read port" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | mem m : - | data-type => UInt<1> - | depth => 2 - | read-latency => 0 - | write-latency => 1 - | reader => r - | read-under-write => undefined - | m.r.clk <= clk - | m.r.addr <= y - | m.r.en <= UInt(1) - | z <= m.r.data - | y <= z - | d <= z - |""".stripMargin + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | mem m : + | data-type => UInt<1> + | depth => 2 + | read-latency => 0 + | write-latency => 1 + | reader => r + | read-under-write => undefined + | m.r.clk <= clk + | m.r.addr <= y + | m.r.en <= UInt(1) + | z <= m.r.data + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -125,25 +125,25 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Combination loop through an instance" should "throw an exception" in { val input = """circuit hasloops : - | module thru : - | input in : UInt<1> - | output out : UInt<1> - | out <= in - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | inst inner of thru - | inner.in <= y - | z <= inner.out - | y <= z - | d <= z - |""".stripMargin + | module thru : + | input in : UInt<1> + | output out : UInt<1> + | out <= in + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of thru + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -152,24 +152,24 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Combinational loop through an annotated ExtModule" should "throw an exception" in { val input = """circuit hasloops : - | extmodule blackbox : - | input in : UInt<1> - | output out : UInt<1> - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | inst inner of blackbox - | inner.in <= y - | z <= inner.out - | y <= z - | d <= z - |""".stripMargin + | extmodule blackbox : + | input in : UInt<1> + | output out : UInt<1> + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of blackbox + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin val mt = ModuleTarget("hasloops", "blackbox") val annos = AnnotationSeq(Seq(ExtModulePathAnnotation(mt.ref("in"), mt.ref("out")))) @@ -180,53 +180,56 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Loop-free circuit with ExtModulePathAnnotations" should "not throw an exception" in { val input = """circuit hasnoloops : - | extmodule blackbox : - | input in1 : UInt<1> - | input in2 : UInt<1> - | output out1 : UInt<1> - | output out2 : UInt<1> - | module hasnoloops : - | input clk : Clock - | input a : UInt<1> - | output b : UInt<1> - | wire x : UInt<1> - | inst inner of blackbox - | inner.in1 <= a - | x <= inner.out1 - | inner.in2 <= x - | b <= inner.out2 - |""".stripMargin + | extmodule blackbox : + | input in1 : UInt<1> + | input in2 : UInt<1> + | output out1 : UInt<1> + | output out2 : UInt<1> + | module hasnoloops : + | input clk : Clock + | input a : UInt<1> + | output b : UInt<1> + | wire x : UInt<1> + | inst inner of blackbox + | inner.in1 <= a + | x <= inner.out1 + | inner.in2 <= x + | b <= inner.out2 + |""".stripMargin val mt = ModuleTarget("hasnoloops", "blackbox") - val annos = AnnotationSeq(Seq( - ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")), - ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2")))) + val annos = AnnotationSeq( + Seq( + ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")), + ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2")) + ) + ) compile(parse(input), annos) } "Combinational loop through an output RHS reference" should "throw an exception" in { val input = """circuit hasloops : - | module thru : - | input in : UInt<1> - | output tmp : UInt<1> - | output out : UInt<1> - | tmp <= in - | out <= tmp - | module hasloops : - | input clk : Clock - | input a : UInt<1> - | input b : UInt<1> - | output c : UInt<1> - | output d : UInt<1> - | wire y : UInt<1> - | wire z : UInt<1> - | c <= b - | inst inner of thru - | inner.in <= y - | z <= inner.out - | y <= z - | d <= z - |""".stripMargin + | module thru : + | input in : UInt<1> + | output tmp : UInt<1> + | output out : UInt<1> + | tmp <= in + | out <= tmp + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of thru + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -235,21 +238,21 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops "Multiple simple loops in one SCC" should "throw an exception" in { val input = """circuit hasloops : - | module hasloops : - | input i : UInt<1> - | output o : UInt<1> - | wire a : UInt<1> - | wire b : UInt<1> - | wire c : UInt<1> - | wire d : UInt<1> - | wire e : UInt<1> - | a <= and(c,i) - | b <= and(a,d) - | c <= b - | d <= and(c,e) - | e <= b - | o <= e - |""".stripMargin + | module hasloops : + | input i : UInt<1> + | output o : UInt<1> + | wire a : UInt<1> + | wire b : UInt<1> + | wire c : UInt<1> + | wire d : UInt<1> + | wire e : UInt<1> + | a <= and(c,i) + | b <= and(a,d) + | c <= b + | d <= and(c,e) + | e <= b + | o <= e + |""".stripMargin intercept[CheckCombLoops.CombLoopException] { compile(parse(input)) @@ -280,7 +283,7 @@ class CheckCombLoopsSpec extends LeanTransformSpec(Seq(Dependency[CheckCombLoops val cs = compile(parse(input)) val mt = ModuleTarget("hasnoloops", "hasnoloops") val anno = CombinationalPath(mt.ref("b"), Seq(mt.ref("a"))) - cs.annotations.contains(anno) should be (true) + cs.annotations.contains(anno) should be(true) } } @@ -292,7 +295,7 @@ class CheckCombLoopsCommandLineSpec extends FirrtlFlatSpec { val args = Array("-i", inputFile.getAbsolutePath, "-o", outFile.getAbsolutePath, "-X", "verilog") "Combinational loops detection" should "run by default" in { - a [CheckCombLoops.CombLoopException] should be thrownBy { + a[CheckCombLoops.CombLoopException] should be thrownBy { (new FirrtlStage).execute(args, Seq()) } } diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala index 34e0da03..5fd9543e 100644 --- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala +++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala @@ -2,27 +2,27 @@ package firrtlTests -import firrtl.{CircuitState, UnknownForm, Transform} +import firrtl.{CircuitState, Transform, UnknownForm} import firrtl.passes._ import firrtl.testutils._ class CheckInitializationSpec extends FirrtlFlatSpec { private val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - CheckFlows, - new InferWidths, - CheckWidths, - PullMuxes, - ExpandConnects, - RemoveAccesses, - ExpandWhens, - CheckInitialization, - InferTypes + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveFlows, + CheckFlows, + new InferWidths, + CheckWidths, + PullMuxes, + ExpandConnects, + RemoveAccesses, + ExpandWhens, + CheckInitialization, + InferTypes ) "Missing assignment in consequence branch" should "trigger a PassException" in { val input = @@ -33,8 +33,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | when p : | x <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } @@ -48,8 +48,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | else : | x <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } @@ -64,8 +64,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | x <= UInt(1) | x <= UInt(1) | """.stripMargin - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } @@ -84,8 +84,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | x <= UInt(2) | x <= UInt(1) | """.stripMargin - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } @@ -100,8 +100,8 @@ class CheckInitializationSpec extends FirrtlFlatSpec { | when p : | c.in <= UInt(1)""".stripMargin intercept[CheckInitialization.RefNotInitializedException] { - passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + passes.foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index 5c38bf30..20a5f969 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -3,17 +3,29 @@ package firrtlTests import org.scalatest._ -import firrtl.{Parser, CircuitState, UnknownForm, Transform} +import firrtl.{CircuitState, Parser, Transform, UnknownForm} import firrtl.ir.Circuit -import firrtl.passes.{Pass,ToWorkingIR,CheckHighForm,ResolveKinds,InferTypes,CheckTypes,PassException,InferWidths,CheckWidths,ResolveFlows,CheckFlows} +import firrtl.passes.{ + CheckFlows, + CheckHighForm, + CheckTypes, + CheckWidths, + InferTypes, + InferWidths, + Pass, + PassException, + ResolveFlows, + ResolveKinds, + ToWorkingIR +} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers class CheckSpec extends AnyFlatSpec with Matchers { val defaultPasses = Seq(ToWorkingIR, CheckHighForm) def checkHighInput(input: String) = { - defaultPasses.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + defaultPasses.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } @@ -44,9 +56,7 @@ class CheckSpec extends AnyFlatSpec with Matchers { } "Memories with zero write latency" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) + val passes = Seq(ToWorkingIR, CheckHighForm) val input = """circuit Unit : | module Unit : @@ -56,8 +66,8 @@ class CheckSpec extends AnyFlatSpec with Matchers { | read-latency => 0 | write-latency => 0""".stripMargin intercept[CheckHighForm.IllegalMemLatencyException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } } @@ -181,90 +191,81 @@ class CheckSpec extends AnyFlatSpec with Matchers { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """ - |circuit TheRealTop : - | - | module Top : - | output io : {flip debug_clk : Clock} - | - | extmodule BlackBoxTop : - | input jtag : {TCK : Clock} - | - | module TheRealTop : - | input clock : Clock - | input reset : UInt<1> - | output io : {flip jtag : {TCK : Clock}} - | - | io is invalid - | inst sub of Top - | sub.io is invalid - | inst bb of BlackBoxTop - | bb.jtag is invalid - | bb.jtag <- io.jtag - | - | sub.io.debug_clk <= io.jtag.TCK - | - |""".stripMargin + |circuit TheRealTop : + | + | module Top : + | output io : {flip debug_clk : Clock} + | + | extmodule BlackBoxTop : + | input jtag : {TCK : Clock} + | + | module TheRealTop : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip jtag : {TCK : Clock}} + | + | io is invalid + | inst sub of Top + | sub.io is invalid + | inst bb of BlackBoxTop + | bb.jtag is invalid + | bb.jtag <- io.jtag + | + | sub.io.debug_clk <= io.jtag.TCK + | + |""".stripMargin passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { (c: CircuitState, p: Transform) => p.runTransform(c) } } "Clocks with types other than ClockType" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """ - |circuit Top : - | - | module Top : - | input clk : UInt<1> - | input i : UInt<1> - | output o : UInt<1> - | - | reg r : UInt<1>, clk - | r <= i - | o <= r - | - |""".stripMargin + |circuit Top : + | + | module Top : + | input clk : UInt<1> + | input i : UInt<1> + | output o : UInt<1> + | + | reg r : UInt<1>, clk + | r <= i + | o <= r + | + |""".stripMargin intercept[CheckTypes.RegReqClk] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Illegal reset type" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """ - |circuit Top : - | - | module Top : - | input clk : Clock - | input reset : UInt<2> - | input i : UInt<1> - | output o : UInt<1> - | - | reg r : UInt<1>, clk with : (reset => (reset, UInt<1>("h00"))) - | r <= i - | o <= r - | - |""".stripMargin + |circuit Top : + | + | module Top : + | input clk : Clock + | input reset : UInt<2> + | input i : UInt<1> + | output o : UInt<1> + | + | reg r : UInt<1>, clk with : (reset => (reset, UInt<1>("h00"))) + | r <= i + | o <= r + | + |""".stripMargin intercept[CheckTypes.IllegalResetType] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } } } @@ -281,7 +282,7 @@ class CheckSpec extends AnyFlatSpec with Matchers { val exception = intercept[PassException] { checkHighInput(input) } - exception.getMessage should include (s"Primop $op argument $amount < 0") + exception.getMessage should include(s"Primop $op argument $amount < 0") } } @@ -301,11 +302,11 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } - behavior of "Uniqueness" + behavior.of("Uniqueness") for ((description, input) <- CheckSpec.nonUniqueExamples) { it should s"be asserted for $description" in { assertThrows[CheckHighForm.NotUniqueException] { - Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } + Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)) { case (c, tx) => tx.run(c) } } } } @@ -400,7 +401,7 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } - behavior of "CheckHighForm running on circuits containing ExtModules" + behavior.of("CheckHighForm running on circuits containing ExtModules") it should "throw an exception if parameterless ExtModules have the same ports, but different widths" in { val input = @@ -539,19 +540,17 @@ class CheckSpec extends AnyFlatSpec with Matchers { object CheckSpec { val nonUniqueExamples = List( - ("two ports with the same name", - """|circuit Top: - | module Top: - | input a: UInt<1> - | input a: UInt<1>""".stripMargin), - ("two nodes with the same name", - """|circuit Top: - | module Top: - | node a = UInt<1>("h0") - | node a = UInt<1>("h0")""".stripMargin), - ("a port and a node with the same name", - """|circuit Top: - | module Top: - | input a: UInt<1> - | node a = UInt<1>("h0") """.stripMargin) ) - } + ("two ports with the same name", """|circuit Top: + | module Top: + | input a: UInt<1> + | input a: UInt<1>""".stripMargin), + ("two nodes with the same name", """|circuit Top: + | module Top: + | node a = UInt<1>("h0") + | node a = UInt<1>("h0")""".stripMargin), + ("a port and a node with the same name", """|circuit Top: + | module Top: + | input a: UInt<1> + | node a = UInt<1>("h0") """.stripMargin) + ) +} diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index 372ba53b..11a27d65 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -16,37 +16,37 @@ class ChirrtlMemSpec extends LowFirrtlTransformSpec { type Netlist = collection.mutable.HashMap[String, Expression] def buildNetlist(netlist: Netlist)(s: Statement): Statement = { s match { - case s: Connect => Utils.kind(s.loc) match { - case MemKind => netlist(s.loc.serialize) = s.expr - case _ => - } + case s: Connect => + Utils.kind(s.loc) match { + case MemKind => netlist(s.loc.serialize) = s.expr + case _ => + } case _ => } - s map buildNetlist(netlist) + s.map(buildNetlist(netlist)) } // walks on memories and checks whether or not read enables are high def checkStmt(netlist: Netlist)(s: Statement): Boolean = s match { - case s: DefMemory if s.name == "mem" && s.readers.size == 1=> + case s: DefMemory if s.name == "mem" && s.readers.size == 1 => val en = MemPortUtils.memPortField(s, s.readers.head, "en") // memory read enable ?= 1 WrappedExpression.weq(netlist(en.serialize), Utils.one) case s: Block => - s.stmts exists checkStmt(netlist) + s.stmts.exists(checkStmt(netlist)) case _ => false } - def run (c: Circuit) = { + def run(c: Circuit) = { val errors = new Errors - val check = c.modules exists { + val check = c.modules.exists { case m: Module => val netlist = new Netlist checkStmt(netlist)(buildNetlist(netlist)(m.body)) case m: ExtModule => false } if (!check) { - errors append new PassException( - "Enable signal for the read port is incorrect!") + errors.append(new PassException("Enable signal for the read port is incorrect!")) errors.trigger } c @@ -105,18 +105,18 @@ circuit foo : "An mport that refers to an undefined memory" should "have a helpful error message" in { val input = """circuit testTestModule : - | module testTestModule : - | input clock : Clock - | input reset : UInt<1> - | output io : {flip in : UInt<10>, out : UInt<10>} - | - | node _T_10 = bits(io.in, 1, 0) - | read mport _T_11 = m[_T_10], clock - | io.out <= _T_11""".stripMargin + | module testTestModule : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip in : UInt<10>, out : UInt<10>} + | + | node _T_10 = bits(io.in, 1, 0) + | read mport _T_11 = m[_T_10], clock + | io.out <= _T_11""".stripMargin - intercept[PassException]{ + intercept[PassException] { compile(parse(input)) - }.getMessage should startWith ("Undefined memory m referenced by mport _T_11") + }.getMessage should startWith("Undefined memory m referenced by mport _T_11") } ignore should "Memories should not have validif on port clocks when declared in a when" in { @@ -167,9 +167,19 @@ circuit foo : | io.dataOut <= out @[Stack.scala 31:14] """.stripMargin val res = compile(parse(input)) - assert(res search { - case Connect(_, WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_35",_, _), "clk", _, _), WRef("clock", _, _, _)) => true - case Connect(_, WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_17",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + assert(res.search { + case Connect( + _, + WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_35", _, _), "clk", _, _), + WRef("clock", _, _, _) + ) => + true + case Connect( + _, + WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_17", _, _), "clk", _, _), + WRef("clock", _, _, _) + ) => + true }) } @@ -188,8 +198,9 @@ circuit foo : | out <= bar |""".stripMargin val res = compile(parse(input)) - assert(res search { - case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + assert(res.search { + case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar", _, _), "clk", _, _), WRef("clock", _, _, _)) => + true }) } @@ -209,8 +220,9 @@ circuit foo : | out <= bar |""".stripMargin val res = compile(parse(input)) - assert(res search { - case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + assert(res.search { + case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar", _, _), "clk", _, _), WRef("clock", _, _, _)) => + true }) } @@ -230,12 +242,16 @@ circuit foo : | out <= bar |""".stripMargin val res = new LowFirrtlCompiler().compile(CircuitState(parse(input), ChirrtlForm), Seq()).circuit - assert(res search { - case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), DoPrim(AsClock, Seq(WRef("clock", _, _, _)), Nil, _)) => true + assert(res.search { + case Connect( + _, + WSubField(WSubField(WRef("mem", _, _, _), "bar", _, _), "clk", _, _), + DoPrim(AsClock, Seq(WRef("clock", _, _, _)), Nil, _) + ) => + true }) } - ignore should "Mem non-local nested clock port assignment should be ok" in { val input = """circuit foo : @@ -251,8 +267,13 @@ circuit foo : | out <= bar |""".stripMargin val res = (new HighFirrtlCompiler).compile(CircuitState(parse(input), ChirrtlForm), Seq()).circuit - assert(res search { - case Connect(_, SubField(SubField(Reference("mem", _, _, _), "bar", _, _), "clk", _, _), DoPrim(AsClock, Seq(Reference("clock", _, _, _)), _, _)) => true + assert(res.search { + case Connect( + _, + SubField(SubField(Reference("mem", _, _, _), "bar", _, _), "clk", _, _), + DoPrim(AsClock, Seq(Reference("clock", _, _, _)), _, _) + ) => + true }) } } diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index dcc8b872..2d13c835 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -30,49 +30,49 @@ class ChirrtlSpec extends FirrtlFlatSpec { "Chirrtl memories" should "allow ports with clocks defined after the memory" in { val input = - """circuit Unit : - | module Unit : - | input clock : Clock - | smem ram : UInt<32>[128] - | node newClock = clock - | infer mport x = ram[UInt(2)], newClock - | x <= UInt(3) - | when UInt(1) : - | infer mport y = ram[UInt(4)], newClock - | y <= UInt(5) + """circuit Unit : + | module Unit : + | input clock : Clock + | smem ram : UInt<32>[128] + | node newClock = clock + | infer mport x = ram[UInt(2)], newClock + | x <= UInt(3) + | when UInt(1) : + | infer mport y = ram[UInt(4)], newClock + | y <= UInt(5) """.stripMargin val circuit = Parser.parse(input.split("\n").toIterator) - transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } "Chirrtl" should "catch undeclared wires" in { val input = - """circuit Unit : - | module Unit : - | input clock : Clock - | smem ram : UInt<32>[128] - | node newClock = clock - | infer mport x = ram[UInt(2)], newClock - | x <= UInt(3) - | when UInt(1) : - | infer mport y = ram[UInt(4)], newClock - | y <= z + """circuit Unit : + | module Unit : + | input clock : Clock + | smem ram : UInt<32>[128] + | node newClock = clock + | infer mport x = ram[UInt(2)], newClock + | x <= UInt(3) + | when UInt(1) : + | infer mport y = ram[UInt(4)], newClock + | y <= z """.stripMargin intercept[PassException] { val circuit = Parser.parse(input.split("\n").toIterator) - transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } } } - behavior of "Uniqueness" + behavior.of("Uniqueness") for ((description, input) <- CheckSpec.nonUniqueExamples) { it should s"be asserted for $description" in { assertThrows[CheckHighForm.NotUniqueException] { - Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)){ case (c, tx) => tx.run(c) } + Seq(ToWorkingIR, CheckHighForm).foldLeft(Parser.parse(input)) { case (c, tx) => tx.run(c) } } } } diff --git a/src/test/scala/firrtlTests/ClockListTests.scala b/src/test/scala/firrtlTests/ClockListTests.scala index 9233d4d5..c547448b 100644 --- a/src/test/scala/firrtlTests/ClockListTests.scala +++ b/src/test/scala/firrtlTests/ClockListTests.scala @@ -11,12 +11,12 @@ import clocklist._ class ClockListTests extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Pass) => + p.run(c) } - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -69,19 +69,21 @@ class ClockListTests extends FirrtlFlatSpec { | output clk2: Clock | output clk3: Clock |""".stripMargin - val check = - """Sourcelist: List(h$clkGen$clk1, h$clkGen$clk2, h$clkGen$clk3, clock) - |Good Origin of clock is clock - |Good Origin of h.clock is h$clkGen.clk1 - |Good Origin of h$b.clock is h$clkGen.clk2 - |Good Origin of h$c.clock is h$clkGen.clk3 - |""".stripMargin - val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit + val check = + """Sourcelist: List(h$clkGen$clk1, h$clkGen$clk2, h$clkGen$clk3, clock) + |Good Origin of clock is clock + |Good Origin of h.clock is h$clkGen.clk1 + |Good Origin of h$b.clock is h$clkGen.clk2 + |Good Origin of h$c.clock is h$clkGen.clk3 + |""".stripMargin + val c = passes + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit val writer = new StringWriter() val retC = new ClockList("HTop", writer).run(c) - (writer.toString) should be (check) + (writer.toString) should be(check) } "A->B->C, and A.clock == C.clock" should "still emit C.clock origin" in { val input = @@ -101,18 +103,20 @@ class ClockListTests extends FirrtlFlatSpec { | input clock: Clock | reg r: UInt<5>, clock |""".stripMargin - val check = - """Sourcelist: List(clock, clkB) - |Good Origin of clock is clock - |Good Origin of b.clock is clkB - |Good Origin of b$c.clock is clock - |""".stripMargin - val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit + val check = + """Sourcelist: List(clock, clkB) + |Good Origin of clock is clock + |Good Origin of b.clock is clkB + |Good Origin of b$c.clock is clock + |""".stripMargin + val c = passes + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit val writer = new StringWriter() val retC = new ClockList("A", writer).run(c) - (writer.toString) should be (check) + (writer.toString) should be(check) } "Have not circuit main be top of clocklist pass" should "still work" in { val input = @@ -136,15 +140,17 @@ class ClockListTests extends FirrtlFlatSpec { | input clock: Clock |""".stripMargin val check = - """Sourcelist: List(clock, clkC) - |Good Origin of clock is clock - |Good Origin of c.clock is clkC - |""".stripMargin - val c = passes.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit + """Sourcelist: List(clock, clkC) + |Good Origin of clock is clock + |Good Origin of c.clock is clkC + |""".stripMargin + val c = passes + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit val writer = new StringWriter() val retC = new ClockList("B", writer).run(c) - (writer.toString) should be (check) + (writer.toString) should be(check) } } diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala index dfa796c4..129ff8f5 100644 --- a/src/test/scala/firrtlTests/CompilerTests.scala +++ b/src/test/scala/firrtlTests/CompilerTests.scala @@ -12,36 +12,36 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers /** - * An example methodology for testing Firrtl compilers. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler - * should be compared against the check string. - */ + * An example methodology for testing Firrtl compilers. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler + * should be compared against the check string. + */ abstract class CompilerSpec(emitter: Dependency[firrtl.Emitter]) extends LeanTransformSpec(Seq(emitter)) { - def input: String - def getOutput: String = compile(input).getEmittedCircuit.value + def input: String + def getOutput: String = compile(input).getEmittedCircuit.value } /** - * An example test for testing the HighFirrtlCompiler. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler - * is parsed again and compared (in-memory) to the parsed - * input. - */ + * An example test for testing the HighFirrtlCompiler. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler + * is parsed again and compared (in-memory) to the parsed + * input. + */ class HighFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.HighFirrtlEmitter]) with Matchers { - val input = -"""circuit Top : + val input = + """circuit Top : module Top : input a : UInt<1>[2] node x = a """ - val check = input - "Any circuit" should "match exactly to its input" in { - (parse(getOutput)) should be (parse(check)) - } + val check = input + "Any circuit" should "match exactly to its input" in { + (parse(getOutput)) should be(parse(check)) + } } /** @@ -53,8 +53,8 @@ class HighFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.HighFirrtlEm * string compared to the correct lowered circuit. */ class MiddleFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.MiddleFirrtlEmitter]) with Matchers { - val input = - """ + val input = + """ circuit Top : module Top : input reset : UInt<1> @@ -64,77 +64,77 @@ circuit Top : when reset : b <= UInt(0) """ - // Verify that Vecs are retained, but widths are inferred and whens are expanded. - val check = Seq( - "circuit Top :", - " module Top :", - " input reset : UInt<1>", - " input a : UInt<1>[2]", - " wire b : UInt<1>", - " node _GEN_0 = mux(reset, UInt<1>(\"h0\"), a[0])", - " b <= _GEN_0\n\n" - ).reduce(_ + "\n" + _) - "A circuit" should "match exactly to its MidForm state" in { - (parse(getOutput)) should be (parse(check)) - } + // Verify that Vecs are retained, but widths are inferred and whens are expanded. + val check = Seq( + "circuit Top :", + " module Top :", + " input reset : UInt<1>", + " input a : UInt<1>[2]", + " wire b : UInt<1>", + " node _GEN_0 = mux(reset, UInt<1>(\"h0\"), a[0])", + " b <= _GEN_0\n\n" + ).reduce(_ + "\n" + _) + "A circuit" should "match exactly to its MidForm state" in { + (parse(getOutput)) should be(parse(check)) + } } /** - * An example test for testing the LoweringCompiler. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler is - * a lowered version of the input circuit. The output is - * string compared to the correct lowered circuit. - */ + * An example test for testing the LoweringCompiler. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler is + * a lowered version of the input circuit. The output is + * string compared to the correct lowered circuit. + */ class LowFirrtlCompilerSpec extends CompilerSpec(Dependency[firrtl.LowFirrtlEmitter]) with Matchers { - val input = -""" + val input = + """ circuit Top : module Top : input a : UInt<1>[2] node x = a """ - val check = Seq( - "circuit Top :", - " module Top :", - " input a_0 : UInt<1>", - " input a_1 : UInt<1>", - " node x_0 = a_0", - " node x_1 = a_1\n\n" - ).reduce(_ + "\n" + _) - "A circuit" should "match exactly to its lowered state" in { - (parse(getOutput)) should be (parse(check)) - } + val check = Seq( + "circuit Top :", + " module Top :", + " input a_0 : UInt<1>", + " input a_1 : UInt<1>", + " node x_0 = a_0", + " node x_1 = a_1\n\n" + ).reduce(_ + "\n" + _) + "A circuit" should "match exactly to its lowered state" in { + (parse(getOutput)) should be(parse(check)) + } } /** - * An example test for testing the VerilogCompiler. - * - * Given an input Firrtl circuit (expressed as a string), - * the compiler is executed. The output of the compiler is - * the corresponding Verilog. The output is string compared - * to the correct Verilog. - */ + * An example test for testing the VerilogCompiler. + * + * Given an input Firrtl circuit (expressed as a string), + * the compiler is executed. The output of the compiler is + * the corresponding Verilog. The output is string compared + * to the correct Verilog. + */ class VerilogCompilerSpec extends CompilerSpec(Dependency[firrtl.VerilogEmitter]) with Matchers { - val input = """circuit Top : - | module Top : - | input a : UInt<1>[2] - | output b : UInt<1>[2] - | b <= a""".stripMargin - val check = """module Top( - | input a_0, - | input a_1, - | output b_0, - | output b_1 - |); - | assign b_0 = a_0; - | assign b_1 = a_1; - |endmodule - |""".stripMargin - "A circuit's verilog output" should "match the given string and not have RANDOMIZE if no invalids" in { - getOutput should be (check) - } + val input = """circuit Top : + | module Top : + | input a : UInt<1>[2] + | output b : UInt<1>[2] + | b <= a""".stripMargin + val check = """module Top( + | input a_0, + | input a_1, + | output b_0, + | output b_1 + |); + | assign b_0 = a_0; + | assign b_1 = a_1; + |endmodule + |""".stripMargin + "A circuit's verilog output" should "match the given string and not have RANDOMIZE if no invalids" in { + getOutput should be(check) + } } class MinimumVerilogCompilerSpec extends CompilerSpec(Dependency[firrtl.MinimumVerilogEmitter]) with Matchers { @@ -166,6 +166,6 @@ class MinimumVerilogCompilerSpec extends CompilerSpec(Dependency[firrtl.MinimumV |endmodule |""".stripMargin "A circuit's minimum Verilog output" should "pad signed RHSes but not reflect any const-prop or DCE" in { - getOutput should be (check) + getOutput should be(check) } } diff --git a/src/test/scala/firrtlTests/CompilerUtilsSpec.scala b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala index bfb53ce1..530a036a 100644 --- a/src/test/scala/firrtlTests/CompilerUtilsSpec.scala +++ b/src/test/scala/firrtlTests/CompilerUtilsSpec.scala @@ -30,39 +30,39 @@ class CompilerUtilsSpec extends FirrtlFlatSpec { val lowToLowTwo = genTransform(LowForm, LowForm) - behavior of "mergeTransforms" + behavior.of("mergeTransforms") it should "do nothing if there are no custom transforms" in { - mergeTransforms(chirrtlToLowList, List.empty) should be (chirrtlToLowList) + mergeTransforms(chirrtlToLowList, List.empty) should be(chirrtlToLowList) } it should "insert transforms at the correct place" in { mergeTransforms(chirrtlToLowList, List(chirrtlToChirrtl)) should be - (chirrtlToChirrtl +: chirrtlToLowList) + (chirrtlToChirrtl +: chirrtlToLowList) mergeTransforms(chirrtlToLowList, List(highToHigh)) should be - (List(chirrtlToHigh, highToHigh, highToMid, midToLow)) + (List(chirrtlToHigh, highToHigh, highToMid, midToLow)) mergeTransforms(chirrtlToLowList, List(midToMid)) should be - (List(chirrtlToHigh, highToMid, midToMid, midToLow)) + (List(chirrtlToHigh, highToMid, midToMid, midToLow)) mergeTransforms(chirrtlToLowList, List(lowToLow)) should be - (chirrtlToLowList :+ lowToLow) + (chirrtlToLowList :+ lowToLow) } it should "insert transforms at the last legal location" in { lowToLow should not be (lowToLowTwo) // sanity check - mergeTransforms(chirrtlToLowList :+ lowToLow, List(lowToLowTwo)).last should be (lowToLowTwo) + mergeTransforms(chirrtlToLowList :+ lowToLow, List(lowToLowTwo)).last should be(lowToLowTwo) } it should "insert multiple transforms correctly" in { mergeTransforms(chirrtlToLowList, List(highToHigh, lowToLow)) should be - (List(chirrtlToHigh, highToHigh, highToMid, midToLow, lowToLow)) + (List(chirrtlToHigh, highToHigh, highToMid, midToLow, lowToLow)) } it should "handle transforms that raise the form" in { mergeTransforms(chirrtlToLowList, List(lowToHigh)) match { case chirrtlToHigh :: highToMid :: midToLow :: lowToHigh :: remainder => // Remainder will be the actual Firrtl lowering transforms - remainder.head.inputForm should be (HighForm) - remainder.last.outputForm should be (LowForm) + remainder.head.inputForm should be(HighForm) + remainder.last.outputForm should be(LowForm) case _ => fail() } } @@ -70,8 +70,7 @@ class CompilerUtilsSpec extends FirrtlFlatSpec { // Order is not always maintained, see note on function Scaladoc it should "maintain order of custom tranforms" in { mergeTransforms(chirrtlToLowList, List(lowToLow, lowToLowTwo)) should be - (chirrtlToLowList ++ List(lowToLow, lowToLowTwo)) + (chirrtlToLowList ++ List(lowToLow, lowToLowTwo)) } } - diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index efe85e48..6ab54159 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -9,24 +9,22 @@ import firrtl.testutils._ import firrtl.annotations.Annotation class ConstantPropagationSpec extends FirrtlFlatSpec { - val transforms: Seq[Transform] = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - new ConstantPropagation) + val transforms: Seq[Transform] = + Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, new ConstantPropagation) protected def exec(input: String, annos: Seq[Annotation] = Nil) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit.serialize + transforms + .foldLeft(CircuitState(parse(input), UnknownForm, AnnotationSeq(annos))) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit + .serialize } } class ConstantPropagationMultiModule extends ConstantPropagationSpec { - "ConstProp" should "propagate constant inputs" in { - val input = -"""circuit Top : + "ConstProp" should "propagate constant inputs" in { + val input = + """circuit Top : module Child : input in0 : UInt<1> input in1 : UInt<1> @@ -40,8 +38,8 @@ class ConstantPropagationMultiModule extends ConstantPropagationSpec { c.in1 <= UInt<1>(1) z <= c.out """ - val check = -"""circuit Top : + val check = + """circuit Top : module Child : input in0 : UInt<1> input in1 : UInt<1> @@ -55,12 +53,12 @@ class ConstantPropagationMultiModule extends ConstantPropagationSpec { c.in1 <= UInt<1>(1) z <= c.out """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in { - def circuit(allSame: Boolean) = -s"""circuit Top : + "ConstProp" should "propagate constant inputs ONLY if ALL instance inputs get the same value" in { + def circuit(allSame: Boolean) = + s"""circuit Top : module Bottom : input in : UInt<1> output out : UInt<1> @@ -83,8 +81,8 @@ s"""circuit Top : z <= and(and(b0.out, b1.out), c.out) """ - val resultFromAllSame = -"""circuit Top : + val resultFromAllSame = + """circuit Top : module Bottom : input in : UInt<1> output out : UInt<1> @@ -104,14 +102,14 @@ s"""circuit Top : b1.in <= UInt(1) z <= UInt(1) """ - (parse(exec(circuit(false)))) should be (parse(circuit(false))) - (parse(exec(circuit(true)))) should be (parse(resultFromAllSame)) - } - - // ============================= - "ConstProp" should "do nothing on unrelated modules" in { - val input = -"""circuit foo : + (parse(exec(circuit(false)))) should be(parse(circuit(false))) + (parse(exec(circuit(true)))) should be(parse(resultFromAllSame)) + } + + // ============================= + "ConstProp" should "do nothing on unrelated modules" in { + val input = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -120,14 +118,14 @@ s"""circuit Top : input dummy : UInt<1> skip """ - val check = input - (parse(exec(input))) should be (parse(check)) - } - - // ============================= - "ConstProp" should "propagate module chains not connected to the top" in { - val input = -"""circuit foo : + val check = input + (parse(exec(input))) should be(parse(check)) + } + + // ============================= + "ConstProp" should "propagate module chains not connected to the top" in { + val input = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -151,8 +149,8 @@ s"""circuit Top : output test : UInt<1> test <= UInt<1>(0) """ - val check = -"""circuit foo : + val check = + """circuit foo : module foo : input dummy : UInt<1> skip @@ -176,8 +174,8 @@ s"""circuit Top : output test : UInt<1> test <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } } // Tests the following cases for constant propagation: @@ -188,332 +186,332 @@ s"""circuit Top : // 3) Values are always greater than a number smaller // than their minimum value class ConstantPropagationSingleModule extends ConstantPropagationSpec { - // ============================= - "The rule x >= 0 " should " always be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule x >= 0 " should " always be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= geq(x, UInt(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>("h1") """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x < 0 " should " never be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule x < 0 " should " never be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= lt(x, UInt(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 0 <= x " should " always be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule 0 <= x " should " always be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= leq(UInt(0),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 0 > x " should " never be true if x is a UInt" in { - val input = -"""circuit Top : + // ============================= + "The rule 0 > x " should " never be true if x is a UInt" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= gt(UInt(0),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 1 < 3 " should " always be true" in { - val input = -"""circuit Top : + // ============================= + "The rule 1 < 3 " should " always be true" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= lt(UInt(0),UInt(3)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x < 8 " should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x < 8 " should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= lt(x,UInt(8)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x <= 7 " should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x <= 7 " should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= leq(x,UInt(7)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 8 > x" should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 8 > x" should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= gt(UInt(8),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 7 >= x" should " always be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 7 >= x" should " always be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= geq(UInt(7),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 10 == 10" should " always be true" in { - val input = -"""circuit Top : + // ============================= + "The rule 10 == 10" should " always be true" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= eq(UInt(10),UInt(10)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(1) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x == z " should " not be true even if they have the same number of bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x == z " should " not be true even if they have the same number of bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> input z : UInt<3> output y : UInt<1> y <= eq(x,z) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> input z : UInt<3> output y : UInt<1> y <= eq(x,z) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 10 != 10 " should " always be false" in { - val input = -"""circuit Top : + // ============================= + "The rule 10 != 10 " should " always be false" in { + val input = + """circuit Top : module Top : output y : UInt<1> y <= neq(UInt(10),UInt(10)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : output y : UInt<1> y <= UInt(0) """ - (parse(exec(input))) should be (parse(check)) - } - // ============================= - "The rule 1 >= 3 " should " always be false" in { - val input = -"""circuit Top : + (parse(exec(input))) should be(parse(check)) + } + // ============================= + "The rule 1 >= 3 " should " always be false" in { + val input = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= geq(UInt(1),UInt(3)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<5> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x >= 8 " should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x >= 8 " should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= geq(x,UInt(8)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule x > 7 " should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule x > 7 " should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= gt(x,UInt(7)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 8 <= x" should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 8 <= x" should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= leq(UInt(8),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "The rule 7 < x" should " never be true if x only has 3 bits" in { - val input = -"""circuit Top : + // ============================= + "The rule 7 < x" should " never be true if x only has 3 bits" in { + val input = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= lt(UInt(7),x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<3> output y : UInt<1> y <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "work across wires" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "work across wires" in { + val input = + """circuit Top : module Top : input x : UInt<1> output y : UInt<1> @@ -521,8 +519,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { y <= z z <= mux(x, UInt<1>(0), UInt<1>(0)) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> output y : UInt<1> @@ -530,13 +528,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { y <= UInt<1>(0) z <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary nodes that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary nodes that drive them" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -545,8 +543,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node n = _T_1 z <= and(n, x) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -555,13 +553,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node _T_1 = n z <= and(n, x) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary wires that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary wires that drive them" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -571,8 +569,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n _T_1 <= and(x, y) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -582,13 +580,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n n <= and(x, y) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "swap named nodes with temporary registers that drive them" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "swap named nodes with temporary registers that drive them" in { + val input = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -598,8 +596,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n _T_1 <= x """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -609,13 +607,13 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { z <= n n <= x """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - // ============================= - "ConstProp" should "only swap a given name with one other name" in { - val input = -"""circuit Top : + // ============================= + "ConstProp" should "only swap a given name with one other name" in { + val input = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -625,8 +623,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node m = _T_1 z <= add(n, m) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input x : UInt<1> input y : UInt<1> @@ -636,12 +634,12 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { node m = n z <= add(n, n) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "NOT swap wire names with node names" in { - val input = -"""circuit Top : + "ConstProp" should "NOT swap wire names with node names" in { + val input = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -653,8 +651,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { hit <= _T_2 z <= hit """ - val check = -"""circuit Top : + val check = + """circuit Top : module Top : input clock : Clock input x : UInt<1> @@ -666,12 +664,12 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { hit <= or(x, y) z <= hit """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant outputs" in { - val input = -"""circuit Top : + "ConstProp" should "propagate constant outputs" in { + val input = + """circuit Top : module Child : output out : UInt<1> out <= UInt<1>(0) @@ -681,8 +679,8 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { inst c of Child z <= and(x, c.out) """ - val check = -"""circuit Top : + val check = + """circuit Top : module Child : output out : UInt<1> out <= UInt<1>(0) @@ -692,10 +690,10 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { inst c of Child z <= UInt<1>(0) """ - (parse(exec(input))) should be (parse(check)) - } + (parse(exec(input))) should be(parse(check)) + } - "ConstProp" should "propagate constant addition" in { + "ConstProp" should "propagate constant addition" in { val input = """circuit Top : | module Top : @@ -717,7 +715,7 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { (parse(exec(input))) should be(parse(check)) } - "ConstProp" should "propagate addition with zero" in { + "ConstProp" should "propagate addition with zero" in { val input = """circuit Top : | module Top : @@ -779,20 +777,20 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec { def castCheck(tpe: String, cast: String): Unit = { val input = - s"""circuit Top : - | module Top : - | input x : $tpe - | output z : $tpe - | z <= $cast(x) + s"""circuit Top : + | module Top : + | input x : $tpe + | output z : $tpe + | z <= $cast(x) """.stripMargin val check = - s"""circuit Top : - | module Top : - | input x : $tpe - | output z : $tpe - | z <= x + s"""circuit Top : + | module Top : + | input x : $tpe + | output z : $tpe + | z <= x """.stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } it should "optimize unnecessary casts" in { castCheck("UInt<4>", "asUInt") @@ -807,218 +805,217 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { def transform = new LowFirrtlOptimization "ConstProp" should "NOT optimize across dontTouch on nodes" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin - val check = input + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } it should "NOT optimize across nodes marked dontTouch by other annotations" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin - val check = input - val dontTouchRT = annotations.ModuleTarget("Top", "Top").ref("z") + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin + val check = input + val dontTouchRT = annotations.ModuleTarget("Top", "Top").ref("z") execute(input, check, Seq(AnnotationWithDontTouches(dontTouchRT))) } it should "NOT optimize across dontTouch on registers" in { - val input = - """circuit Top : - | module Top : - | input clk : Clock - | input reset : UInt<1> - | output y : UInt<1> - | reg z : UInt<1>, clk - | y <= z - | z <= mux(reset, UInt<1>("h0"), z)""".stripMargin - val check = input + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input reset : UInt<1> + | output y : UInt<1> + | reg z : UInt<1>, clk + | y <= z + | z <= mux(reset, UInt<1>("h0"), z)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Top.z"))) } - it should "NOT optimize across dontTouch on wires" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | wire z : UInt<1> - | y <= z - | z <= x""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | output y : UInt<1> - | node z = x - | y <= z""".stripMargin + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | wire z : UInt<1> + | y <= z + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output y : UInt<1> + | node z = x + | y <= z""".stripMargin execute(input, check, Seq(dontTouch("Top.z"))) } it should "NOT optimize across dontTouch on output ports" in { val input = """circuit Top : - | module Child : - | output out : UInt<1> - | out <= UInt<1>(0) - | module Top : - | input x : UInt<1> - | output z : UInt<1> - | inst c of Child - | z <= and(x, c.out)""".stripMargin - val check = input + | module Child : + | output out : UInt<1> + | out <= UInt<1>(0) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= and(x, c.out)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Child.out"))) } it should "NOT optimize across dontTouch on input ports" in { val input = """circuit Top : - | module Child : - | input in0 : UInt<1> - | input in1 : UInt<1> - | output out : UInt<1> - | out <= and(in0, in1) - | module Top : - | input x : UInt<1> - | output z : UInt<1> - | inst c of Child - | z <= c.out - | c.in0 <= x - | c.in1 <= UInt<1>(1)""".stripMargin - val check = input + | module Child : + | input in0 : UInt<1> + | input in1 : UInt<1> + | output out : UInt<1> + | out <= and(in0, in1) + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst c of Child + | z <= c.out + | c.in0 <= x + | c.in1 <= UInt<1>(1)""".stripMargin + val check = input execute(input, check, Seq(dontTouch("Child.in1"))) } it should "still propagate constants even when there is name swapping" in { - val input = - """circuit Top : - | module Top : - | input x : UInt<1> - | input y : UInt<1> - | output z : UInt<1> - | node _T_1 = and(and(x, y), UInt<1>(0)) - | node n = _T_1 - | z <= n""".stripMargin - val check = - """circuit Top : - | module Top : - | input x : UInt<1> - | input y : UInt<1> - | output z : UInt<1> - | z <= UInt<1>(0)""".stripMargin + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | node _T_1 = and(and(x, y), UInt<1>(0)) + | node n = _T_1 + | z <= n""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | z <= UInt<1>(0)""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to wires when propagating" in { - val input = - """circuit Top : - | module Top : - | output z : UInt<16> - | wire w : { a : UInt<8>, b : UInt<8> } - | w.a <= UInt<2>("h3") - | w.b <= UInt<2>("h3") - | z <= cat(w.a, w.b)""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Top : + | output z : UInt<16> + | wire w : { a : UInt<8>, b : UInt<8> } + | w.a <= UInt<2>("h3") + | w.b <= UInt<2>("h3") + | z <= cat(w.a, w.b)""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to registers when propagating" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | reg r : { a : UInt<8>, b : UInt<8> }, clock - | r.a <= UInt<2>("h3") - | r.b <= UInt<2>("h3") - | z <= cat(r.a, r.b)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | reg r : { a : UInt<8>, b : UInt<8> }, clock + | r.a <= UInt<2>("h3") + | r.b <= UInt<2>("h3") + | z <= cat(r.a, r.b)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad zero when constant propping a register replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | reg r : UInt<8>, clock - | r <= or(r, UInt(0)) - | node n = UInt("hab") - | z <= cat(n, r)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<16> - | z <= UInt<16>("hab00")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | reg r : UInt<8>, clock + | r <= or(r, UInt(0)) + | node n = UInt("hab") + | z <= cat(n, r)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<16> + | z <= UInt<16>("hab00")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to outputs when propagating" in { - val input = - """circuit Top : - | module Child : - | output x : UInt<8> - | x <= UInt<2>("h3") - | module Top : - | output z : UInt<16> - | inst c of Child - | z <= cat(UInt<2>("h3"), c.x)""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Child : + | output x : UInt<8> + | x <= UInt<2>("h3") + | module Top : + | output z : UInt<16> + | inst c of Child + | z <= cat(UInt<2>("h3"), c.x)""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "pad constant connections to submodule inputs when propagating" in { - val input = - """circuit Top : - | module Child : - | input x : UInt<8> - | output y : UInt<16> - | y <= cat(UInt<2>("h3"), x) - | module Top : - | output z : UInt<16> - | inst c of Child - | c.x <= UInt<2>("h3") - | z <= c.y""".stripMargin - val check = - """circuit Top : - | module Top : - | output z : UInt<16> - | z <= UInt<16>("h303")""".stripMargin + val input = + """circuit Top : + | module Child : + | input x : UInt<8> + | output y : UInt<16> + | y <= cat(UInt<2>("h3"), x) + | module Top : + | output z : UInt<16> + | inst c of Child + | c.x <= UInt<2>("h3") + | z <= c.y""".stripMargin + val check = + """circuit Top : + | module Top : + | output z : UInt<16> + | z <= UInt<16>("h303")""".stripMargin execute(input, check, Seq.empty) } it should "remove pads if the width is <= the width of the argument" in { def input(w: Int) = - s"""circuit Top : - | module Top : - | input x : UInt<8> - | output z : UInt<8> - | z <= pad(x, $w)""".stripMargin + s"""circuit Top : + | module Top : + | input x : UInt<8> + | output z : UInt<8> + | z <= pad(x, $w)""".stripMargin val check = """circuit Top : | module Top : @@ -1029,247 +1026,246 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input(8), check, Seq.empty) } - "Registers with no reset or connections" should "be replaced with constant zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<8> - | reg r : UInt<8>, clock - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output z : UInt<8> - | z <= UInt<8>(0)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<8> + | reg r : UInt<8>, clock + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output z : UInt<8> + | z <= UInt<8>(0)""".stripMargin execute(input, check, Seq.empty) } "Registers with ONLY constant reset" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Registers async reset and a constant connection" should "NOT be removed" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : AsyncReset - | input en : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | when en : - | r <= UInt<4>("h0") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : AsyncReset - | input en : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : - | reset => (reset, UInt<8>("hb")) - | z <= r - | r <= mux(en, UInt<8>("h0"), r)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : AsyncReset + | input en : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | when en : + | r <= UInt<4>("h0") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : AsyncReset + | input en : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : + | reset => (reset, UInt<8>("hb")) + | z <= r + | r <= mux(en, UInt<8>("h0"), r)""".stripMargin execute(input, check, Seq.empty) } "Registers with constant reset and connection to the same constant" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cond : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | when cond : - | r <= UInt<4>("hb") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cond : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | when cond : + | r <= UInt<4>("hb") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cond : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Const prop of registers" should "do limited speculative expansion of optimized muxes to absorb bigger cones" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input en : UInt<1> - | output out : UInt<1> - | reg r1 : UInt<1>, clock - | reg r2 : UInt<1>, clock - | when en : - | r1 <= UInt<1>(1) - | r2 <= UInt<1>(0) - | when en : - | r2 <= r2 - | out <= xor(r1, r2)""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input en : UInt<1> - | output out : UInt<1> - | out <= UInt<1>("h1")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input en : UInt<1> + | output out : UInt<1> + | reg r1 : UInt<1>, clock + | reg r2 : UInt<1>, clock + | when en : + | r1 <= UInt<1>(1) + | r2 <= UInt<1>(0) + | when en : + | r2 <= r2 + | out <= xor(r1, r2)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input en : UInt<1> + | output out : UInt<1> + | out <= UInt<1>("h1")""".stripMargin execute(input, check, Seq.empty) } "A register with constant reset and all connection to either itself or the same constant" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cmd : UInt<3> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7"))) - | r <= r - | when eq(cmd, UInt<3>("h0")) : - | r <= UInt<3>("h7") - | else : - | when eq(cmd, UInt<3>("h1")) : - | r <= r - | else : - | when eq(cmd, UInt<3>("h2")) : - | r <= UInt<4>("h7") - | else : - | r <= r - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | input cmd : UInt<3> - | output z : UInt<8> - | z <= UInt<8>("h7")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("h7"))) + | r <= r + | when eq(cmd, UInt<3>("h0")) : + | r <= UInt<3>("h7") + | else : + | when eq(cmd, UInt<3>("h1")) : + | r <= r + | else : + | when eq(cmd, UInt<3>("h2")) : + | r <= UInt<4>("h7") + | else : + | r <= r + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | input cmd : UInt<3> + | output z : UInt<8> + | z <= UInt<8>("h7")""".stripMargin execute(input, check, Seq.empty) } "Registers with ONLY constant connection" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : SInt<8> - | reg r : SInt<8>, clock - | r <= SInt<4>(-5) - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : SInt<8> - | z <= SInt<8>(-5)""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : SInt<8> + | reg r : SInt<8>, clock + | r <= SInt<4>(-5) + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : SInt<8> + | z <= SInt<8>(-5)""".stripMargin execute(input, check, Seq.empty) } "Registers with identical constant reset and connection" should "be replaced with that constant" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) - | r <= UInt<4>("hb") - | z <= r""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | input reset : UInt<1> - | output z : UInt<8> - | z <= UInt<8>("hb")""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (reset, UInt<4>("hb"))) + | r <= UInt<4>("hb") + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | input reset : UInt<1> + | output z : UInt<8> + | z <= UInt<8>("hb")""".stripMargin execute(input, check, Seq.empty) } "Connections to a node reference" should "be replaced with the rhs of that node" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<8> - | input b : UInt<8> - | input c : UInt<1> - | output z : UInt<8> - | node x = mux(c, a, b) - | z <= x""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<8> - | input b : UInt<8> - | input c : UInt<1> - | output z : UInt<8> - | z <= mux(c, a, b)""".stripMargin + val input = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | node x = mux(c, a, b) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<8> + | input b : UInt<8> + | input c : UInt<1> + | output z : UInt<8> + | z <= mux(c, a, b)""".stripMargin execute(input, check, Seq.empty) } "Registers connected only to themselves" should "be replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | reg ra : UInt<8>, clock - | ra <= ra - | a <= ra - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | a <= UInt<8>(0) - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | reg ra : UInt<8>, clock + | ra <= ra + | a <= ra + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | a <= UInt<8>(0) + |""".stripMargin execute(input, check, Seq.empty) } "Registers connected only to themselves from constant propagation" should "be replaced with zero" in { - val input = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | reg ra : UInt<8>, clock - | ra <= or(ra, UInt(0)) - | a <= ra - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input clock : Clock - | output a : UInt<8> - | a <= UInt<8>(0) - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | reg ra : UInt<8>, clock + | ra <= or(ra, UInt(0)) + | a <= ra + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input clock : Clock + | output a : UInt<8> + | a <= UInt<8>(0) + |""".stripMargin execute(input, check, Seq.empty) } @@ -1290,7 +1286,7 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq.empty) } - behavior of "ConstProp" + behavior.of("ConstProp") it should "optimize shl of constants" in { val input = @@ -1381,30 +1377,30 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { it should "optimize some binary operations when arguments match" in { // Signedness matters - matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) - matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """ ) - matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """ ) - matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """ ) - matchingArgs("and", "UInt<8>", "UInt<8>", """ i """ ) - matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) + matchingArgs("sub", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("sub", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """) + matchingArgs("div", "UInt<8>", "UInt<8>", """ UInt<8>("h1") """) + matchingArgs("div", "SInt<8>", "SInt<8>", """ SInt<8>("h1") """) + matchingArgs("rem", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("rem", "SInt<8>", "SInt<8>", """ SInt<8>("h0") """) + matchingArgs("and", "UInt<8>", "UInt<8>", """ i """) + matchingArgs("and", "SInt<8>", "UInt<8>", """ asUInt(i) """) // Signedness doesn't matter - matchingArgs("or", "UInt<8>", "UInt<8>", """ i """ ) - matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """ ) - matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) - matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """ ) + matchingArgs("or", "UInt<8>", "UInt<8>", """ i """) + matchingArgs("or", "SInt<8>", "UInt<8>", """ asUInt(i) """) + matchingArgs("xor", "UInt<8>", "UInt<8>", """ UInt<8>("h0") """) + matchingArgs("xor", "SInt<8>", "UInt<8>", """ UInt<8>("h0") """) // Always true - matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) - matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) - matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """ ) + matchingArgs("eq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) + matchingArgs("leq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) + matchingArgs("geq", "UInt<8>", "UInt<1>", """ UInt<1>("h1") """) // Never true - matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) - matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) - matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """ ) + matchingArgs("neq", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) + matchingArgs("lt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) + matchingArgs("gt", "UInt<8>", "UInt<1>", """ UInt<1>("h0") """) } - behavior of "Reduction operators" + behavior.of("Reduction operators") it should "optimize andr of a literal" in { val input = @@ -1534,7 +1530,6 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { } - class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { private val srcDir = "/constant_propagation_tests" private val transforms = Seq(new ConstantPropagation) @@ -1642,15 +1637,15 @@ class ConstantPropagationEquivalenceSpec extends FirrtlFlatSpec { firrtlEquivalenceTest(input, transforms) } - "addition of negative literals" should "be propagated" in { - val input = - s"""circuit AddTester : - | module AddTester : - | output ref : SInt<2> - | ref <= add(SInt<1>("h-1"), SInt<1>("h-1")) - |""".stripMargin - firrtlEquivalenceTest(input, transforms) - } + "addition of negative literals" should "be propagated" in { + val input = + s"""circuit AddTester : + | module AddTester : + | output ref : SInt<2> + | ref <= add(SInt<1>("h-1"), SInt<1>("h-1")) + |""".stripMargin + firrtlEquivalenceTest(input, transforms) + } "propagation of signed expressions" should "have the correct signs" in { val input = diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala index 3e5fd254..6edd212d 100644 --- a/src/test/scala/firrtlTests/CustomTransformSpec.scala +++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala @@ -19,28 +19,28 @@ object CustomTransformSpec { class ReplaceExtModuleTransform extends SeqTransform with FirrtlMatchers { // Simple module val delayModuleString = """ - |circuit Delay : - | module Delay : - | input clock : Clock - | input reset : UInt<1> - | input a : UInt<32> - | input en : UInt<1> - | output b : UInt<32> - | - | reg r : UInt<32>, clock - | r <= r - | when en : - | r <= a - | b <= r - |""".stripMargin + |circuit Delay : + | module Delay : + | input clock : Clock + | input reset : UInt<1> + | input a : UInt<32> + | input en : UInt<1> + | output b : UInt<32> + | + | reg r : UInt<32>, clock + | r <= r + | when en : + | r <= a + | b <= r + |""".stripMargin val delayModuleCircuit = parse(delayModuleString) val delayModule = delayModuleCircuit.modules.find(_.name == delayModuleCircuit.main).get class ReplaceExtModule extends Pass { def run(c: Circuit): Circuit = c.copy( - modules = c.modules map { + modules = c.modules.map { case ExtModule(_, "Delay", _, _, _) => delayModule - case other => other + case other => other } ) } @@ -50,10 +50,10 @@ object CustomTransformSpec { } val input = """ - |circuit test : - | module test : - | output out : UInt - | out <= UInt(123)""".stripMargin + |circuit test : + | module test : + | output out : UInt + | out <= UInt(123)""".stripMargin val errorString = "My Custom Transform failed!" class ErroringTransform extends Transform { def inputForm = HighForm @@ -122,7 +122,7 @@ class CustomTransformSpec extends FirrtlFlatSpec { import CustomTransformSpec._ - behavior of "Custom Transforms" + behavior.of("Custom Transforms") they should "be able to introduce high firrtl" in { runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform)) @@ -130,22 +130,24 @@ class CustomTransformSpec extends FirrtlFlatSpec { they should "not cause \"Internal Errors\"" in { val optionsManager = new ExecutionOptionsManager("test") with HasFirrtlOptions { - firrtlOptions = FirrtlExecutionOptions( - firrtlSource = Some(input), - customTransforms = List(new ErroringTransform)) + firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), customTransforms = List(new ErroringTransform)) } - (the [java.lang.IllegalArgumentException] thrownBy { + (the[java.lang.IllegalArgumentException] thrownBy { Driver.execute(optionsManager) - }).getMessage should include (errorString) + }).getMessage should include(errorString) } they should "preserve the input order" in { - runFirrtlTest("CustomTransform", "/features", customTransforms = List( - new FirstTransform, - new SecondTransform, - new ThirdTransform, - new ReplaceExtModuleTransform - )) + runFirrtlTest( + "CustomTransform", + "/features", + customTransforms = List( + new FirstTransform, + new SecondTransform, + new ThirdTransform, + new ReplaceExtModuleTransform + ) + ) } they should "run right before the emitter* when inputForm=LowForm" in { @@ -159,11 +161,10 @@ class CustomTransformSpec extends FirrtlFlatSpec { val custom = Dependency[IdentityLowForm] val tm = new firrtl.stage.transforms.Compiler(custom :: emitter :: Nil) info(s"when using ${emitter.getObject.name}") - tm - .flattenedTransformOrder + tm.flattenedTransformOrder .map(Dependency.fromTransform) .sliding(2) - .toList should contain (Seq(custom, emitter)) + .toList should contain(Seq(custom, emitter)) } } diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala index b309467a..a9084f0b 100644 --- a/src/test/scala/firrtlTests/DCETests.scala +++ b/src/test/scala/firrtlTests/DCETests.scala @@ -13,7 +13,8 @@ import java.io.File import java.nio.file.Paths case class AnnotationWithDontTouches(target: ReferenceTarget) - extends SingleTargetAnnotation[ReferenceTarget] with HasDontTouches { + extends SingleTargetAnnotation[ReferenceTarget] + with HasDontTouches { def targets = Seq(target) def duplicate(n: ReferenceTarget) = this.copy(n) def dontTouches: Seq[ReferenceTarget] = targets @@ -31,9 +32,9 @@ class DCETests extends FirrtlFlatSpec { val finalState = (new LowFirrtlCompiler).compileAndEmit(state, customTransforms) val res = finalState.getEmittedCircuit.value // Convert to sets for comparison - val resSet = Set(parse(res).serialize.split("\n"):_*) - val checkSet = Set(parse(check).serialize.split("\n"):_*) - resSet should be (checkSet) + val resSet = Set(parse(res).serialize.split("\n"): _*) + val checkSet = Set(parse(check).serialize.split("\n"): _*) + resSet should be(checkSet) } "Unread wire" should "be deleted" in { @@ -418,7 +419,7 @@ class DCETests extends FirrtlFlatSpec { exec(input, check) } // This currently does NOT work - behavior of "Single dead instances" + behavior.of("Single dead instances") ignore should "should be deleted" in { val input = """circuit Top : @@ -469,9 +470,9 @@ class DCETests extends FirrtlFlatSpec { val result = (new VerilogCompiler).compileAndEmit(state, List.empty) val verilog = result.getEmittedCircuit.value // Check that mux is removed! - verilog shouldNot include regex ("""a \? x : r;""") + (verilog shouldNot include).regex("""a \? x : r;""") // Check for register update - verilog should include regex ("""(?m)if \(a\) begin\n\s*r <= x;\s*end""") + (verilog should include).regex("""(?m)if \(a\) begin\n\s*r <= x;\s*end""") } "Emitted Verilog" should "not contain dead print or stop statements" in { @@ -487,8 +488,8 @@ class DCETests extends FirrtlFlatSpec { val state = CircuitState(input, ChirrtlForm) val result = (new VerilogCompiler).compileAndEmit(state, List.empty) val verilog = result.getEmittedCircuit.value - verilog shouldNot include regex ("""fwrite""") - verilog shouldNot include regex ("""fatal""") + (verilog shouldNot include).regex("""fwrite""") + (verilog shouldNot include).regex("""fatal""") } } @@ -502,7 +503,7 @@ class DCECommandLineSpec extends FirrtlFlatSpec { "Dead Code Elimination" should "run by default" in { firrtl.Driver.execute(args) match { case FirrtlExecutionSuccess(_, verilog) => - verilog should not include regex ("wire +a") + (verilog should not).include(regex("wire +a")) case _ => fail("Unexpected compilation failure") } } @@ -510,7 +511,7 @@ class DCECommandLineSpec extends FirrtlFlatSpec { it should "not run when given --no-dce option" in { firrtl.Driver.execute(args :+ "--no-dce") match { case FirrtlExecutionSuccess(_, verilog) => - verilog should include regex ("wire +a") + (verilog should include).regex("wire +a") case _ => fail("Unexpected compilation failure") } } diff --git a/src/test/scala/firrtlTests/DriverSpec.scala b/src/test/scala/firrtlTests/DriverSpec.scala index 400bf314..5352fadf 100644 --- a/src/test/scala/firrtlTests/DriverSpec.scala +++ b/src/test/scala/firrtlTests/DriverSpec.scala @@ -85,15 +85,13 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit optionsManager.commonOptions.programArgs should be("fox" :: "tardigrade" :: "stomatopod" :: Nil) optionsManager.commonOptions = CommonOptions() - optionsManager.parse( - Array("dog", "stomatopod")) should be(true) + optionsManager.parse(Array("dog", "stomatopod")) should be(true) info(s"programArgs ${optionsManager.commonOptions.programArgs}") optionsManager.commonOptions.programArgs.length should be(2) optionsManager.commonOptions.programArgs should be("dog" :: "stomatopod" :: Nil) optionsManager.commonOptions = CommonOptions() - optionsManager.parse( - Array("fox", "--top-name", "dog", "tardigrade", "stomatopod")) should be(true) + optionsManager.parse(Array("fox", "--top-name", "dog", "tardigrade", "stomatopod")) should be(true) info(s"programArgs ${optionsManager.commonOptions.programArgs}") optionsManager.commonOptions.programArgs.length should be(3) optionsManager.commonOptions.programArgs should be("fox" :: "tardigrade" :: "stomatopod" :: Nil) @@ -130,11 +128,11 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit outputFileName should be("carol.v") } val input = """ - |circuit Top : - | module Top : - | input x : UInt<8> - | output y : UInt<8> - | y <= x""".stripMargin + |circuit Top : + | module Top : + | input x : UInt<8> + | output y : UInt<8> + | y <= x""".stripMargin val circuit = Parser.parse(input.split("\n").toIterator) "firrtl source can be provided directly" in { val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { @@ -153,18 +151,15 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "Only one of inputFileNameOverride, firrtlSource, and firrtlCircuit can be used at a time" in { val manager1 = new ExecutionOptionsManager("test") with HasFirrtlOptions { commonOptions = CommonOptions(topName = "Top") - firrtlOptions = FirrtlExecutionOptions(firrtlCircuit = Some(circuit), - firrtlSource = Some(input)) + firrtlOptions = FirrtlExecutionOptions(firrtlCircuit = Some(circuit), firrtlSource = Some(input)) } val manager2 = new ExecutionOptionsManager("test") with HasFirrtlOptions { commonOptions = CommonOptions(topName = "Top") - firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", - firrtlSource = Some(input)) + firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", firrtlSource = Some(input)) } val manager3 = new ExecutionOptionsManager("test") with HasFirrtlOptions { commonOptions = CommonOptions(topName = "Top") - firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", - firrtlCircuit = Some(circuit)) + firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "hi", firrtlCircuit = Some(circuit)) } assert(firrtl.Driver.getCircuit(manager1).isFailure) assert(firrtl.Driver.getCircuit(manager2).isFailure) @@ -273,26 +268,25 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "verilog" -> "./Foo.v", "mverilog" -> "./Foo.v", "sverilog" -> "./Foo.sv" - ).foreach { case (compilerName, expectedOutputFileName) => - info(s"$compilerName -> $expectedOutputFileName") - val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { - commonOptions = CommonOptions(topName = "Foo") - firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), compilerName = compilerName) - } - - firrtl.Driver.execute(manager) match { - case success: FirrtlExecutionSuccess => - success.emitted.size should not be (0) - success.circuitState.annotations.length should be > (0) - case a: FirrtlExecutionFailure => - fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${a.message}") - } - - - - val file = new File(expectedOutputFileName) - file.exists() should be(true) - file.delete() + ).foreach { + case (compilerName, expectedOutputFileName) => + info(s"$compilerName -> $expectedOutputFileName") + val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { + commonOptions = CommonOptions(topName = "Foo") + firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), compilerName = compilerName) + } + + firrtl.Driver.execute(manager) match { + case success: FirrtlExecutionSuccess => + success.emitted.size should not be (0) + success.circuitState.annotations.length should be > (0) + case a: FirrtlExecutionFailure => + fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${a.message}") + } + + val file = new File(expectedOutputFileName) + file.exists() should be(true) + file.delete() } } "To a single file per module if OneFilePerModule is specified" in { @@ -304,27 +298,30 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "verilog" -> Seq("./Top.v", "./Child.v"), "mverilog" -> Seq("./Top.v", "./Child.v"), "sverilog" -> Seq("./Top.sv", "./Child.sv") - ).foreach { case (compilerName, expectedOutputFileNames) => - info(s"$compilerName -> $expectedOutputFileNames") - val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { - firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), - compilerName = compilerName, - emitOneFilePerModule = true) - } - - firrtl.Driver.execute(manager) match { - case success: FirrtlExecutionSuccess => - success.emitted.size should not be (0) - success.circuitState.annotations.length should be > (0) - case failure: FirrtlExecutionFailure => - fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${failure.message}") - } - - for (name <- expectedOutputFileNames) { - val file = new File(name) - file.exists() should be(true) - file.delete() - } + ).foreach { + case (compilerName, expectedOutputFileNames) => + info(s"$compilerName -> $expectedOutputFileNames") + val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { + firrtlOptions = FirrtlExecutionOptions( + firrtlSource = Some(input), + compilerName = compilerName, + emitOneFilePerModule = true + ) + } + + firrtl.Driver.execute(manager) match { + case success: FirrtlExecutionSuccess => + success.emitted.size should not be (0) + success.circuitState.annotations.length should be > (0) + case failure: FirrtlExecutionFailure => + fail(s"Got a FirrtlExecutionFailure! Expected FirrtlExecutionSuccess. Full message:\n${failure.message}") + } + + for (name <- expectedOutputFileNames) { + val file = new File(name) + file.exists() should be(true) + file.delete() + } } } } @@ -348,7 +345,7 @@ class DriverSpec extends AnyFreeSpec with Matchers with BackendCompilationUtilit "Both paths do the same thing" in { val s1 = FileUtils.getText(verilogFromFir) val s2 = FileUtils.getText(verilogFromPb) - s1 should equal (s2) + s1 should equal(s2) } } @@ -378,12 +375,12 @@ class VcdSuppressionSpec extends FirrtlFlatSpec { copyResourceToFile(cppHarnessResourceName, harness) verilogToCpp(prefix, testDir, Seq.empty, harness, suppress) #&& - cppToExe(prefix, testDir) ! loggingProcessLogger + cppToExe(prefix, testDir) ! loggingProcessLogger assert(executeExpectingSuccess(prefix, testDir)) val vcdFile = new File(s"$testDir/dump.vcd") - vcdFile.exists() should be(! suppress) + vcdFile.exists() should be(!suppress) } testIfVcdCreated(suppress = false) diff --git a/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala b/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala index 863b6900..9f756927 100644 --- a/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala +++ b/src/test/scala/firrtlTests/ExecutionOptionsManagerSpec.scala @@ -10,32 +10,36 @@ class ExecutionOptionsManagerSpec extends AnyFreeSpec with Matchers { "ExecutionOptionsManager is a container for one more more ComposableOptions Block" - { "It has a default CommonOptionsBlock" in { val manager = new ExecutionOptionsManager("test") - manager.topName should be ("") - manager.targetDirName should be (".") - manager.commonOptions.topName should be ("") - manager.commonOptions.targetDirName should be (".") + manager.topName should be("") + manager.targetDirName should be(".") + manager.commonOptions.topName should be("") + manager.commonOptions.targetDirName should be(".") } "But can override defaults like this" in { - val manager = new ExecutionOptionsManager("test") { commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") } - manager.commonOptions shouldBe a [CommonOptions] - manager.topName should be ("dog") - manager.targetDirName should be ("a/b/c") - manager.commonOptions.topName should be ("dog") - manager.commonOptions.targetDirName should be ("a/b/c") + val manager = new ExecutionOptionsManager("test") { + commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") + } + manager.commonOptions shouldBe a[CommonOptions] + manager.topName should be("dog") + manager.targetDirName should be("a/b/c") + manager.commonOptions.topName should be("dog") + manager.commonOptions.targetDirName should be("a/b/c") } "The add method should put a new version of a given type the manager" in { - val manager = new ExecutionOptionsManager("test") { commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") } + val manager = new ExecutionOptionsManager("test") { + commonOptions = CommonOptions(topName = "dog", targetDirName = "a/b/c") + } val initialCommon = manager.commonOptions - initialCommon.topName should be ("dog") - initialCommon.targetDirName should be ("a/b/c") + initialCommon.topName should be("dog") + initialCommon.targetDirName should be("a/b/c") manager.commonOptions = CommonOptions(topName = "cat", targetDirName = "d/e/f") val afterCommon = manager.commonOptions - afterCommon.topName should be ("cat") - afterCommon.targetDirName should be ("d/e/f") - initialCommon.topName should be ("dog") - initialCommon.targetDirName should be ("a/b/c") + afterCommon.topName should be("cat") + afterCommon.targetDirName should be("d/e/f") + initialCommon.topName should be("dog") + initialCommon.targetDirName should be("a/b/c") } "multiple composable blocks should be separable" in { val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { @@ -43,8 +47,8 @@ class ExecutionOptionsManagerSpec extends AnyFreeSpec with Matchers { firrtlOptions = FirrtlExecutionOptions(inputFileNameOverride = "fork") } - manager.firrtlOptions.inputFileNameOverride should be ("fork") - manager.commonOptions.topName should be ("spoon") + manager.firrtlOptions.inputFileNameOverride should be("fork") + manager.commonOptions.topName should be("spoon") } } } diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index 3616397f..6737643a 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -22,54 +22,55 @@ class ExpandWhensSpec extends FirrtlFlatSpec { PullMuxes, ExpandConnects, RemoveAccesses, - ExpandWhens) + ExpandWhens + ) private def executeTest(input: String, check: String, expected: Boolean) = { val circuit = Parser.parse(input.split("\n").toIterator) - val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } val c = result.circuit - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) if (expected) { - c.serialize.contains(check) should be (true) + c.serialize.contains(check) should be(true) } else { - lines.foreach(_.contains(check) should be (false)) + lines.foreach(_.contains(check) should be(false)) } } "Expand Whens" should "not emit INVALID" in { val input = - """|circuit Tester : - | module Tester : - | input p : UInt<1> - | when p : - | wire a : {b : UInt<64>, c : UInt<64>} - | a is invalid - | a.b <= UInt<64>("h04000000000000000")""".stripMargin + """|circuit Tester : + | module Tester : + | input p : UInt<1> + | when p : + | wire a : {b : UInt<64>, c : UInt<64>} + | a is invalid + | a.b <= UInt<64>("h04000000000000000")""".stripMargin val check = "INVALID" executeTest(input, check, false) } it should "void unwritten memory fields" in { val input = - """|circuit Tester : - | module Tester : - | input clk : Clock - | mem memory: - | data-type => UInt<32> - | depth => 32 - | reader => r0 - | writer => w0 - | read-latency => 0 - | write-latency => 1 - | read-under-write => undefined - | memory.r0.addr <= UInt<1>(1) - | memory.r0.en <= UInt<1>(1) - | memory.r0.clk <= clk - | memory.w0.addr <= UInt<1>(1) - | memory.w0.data <= UInt<1>(1) - | memory.w0.en <= UInt<1>(1) - | memory.w0.clk <= clk - | """.stripMargin + """|circuit Tester : + | module Tester : + | input clk : Clock + | mem memory: + | data-type => UInt<32> + | depth => 32 + | reader => r0 + | writer => w0 + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | memory.r0.addr <= UInt<1>(1) + | memory.r0.en <= UInt<1>(1) + | memory.r0.clk <= clk + | memory.w0.addr <= UInt<1>(1) + | memory.w0.data <= UInt<1>(1) + | memory.w0.en <= UInt<1>(1) + | memory.w0.clk <= clk + | """.stripMargin val check = "VOID" executeTest(input, check, true) } diff --git a/src/test/scala/firrtlTests/ExtModuleSpec.scala b/src/test/scala/firrtlTests/ExtModuleSpec.scala index 7379f1aa..c684e57b 100644 --- a/src/test/scala/firrtlTests/ExtModuleSpec.scala +++ b/src/test/scala/firrtlTests/ExtModuleSpec.scala @@ -4,13 +4,12 @@ package firrtlTests import firrtl.testutils._ -class SimpleExtModuleExecutionTest extends ExecutionTest("SimpleExtModuleTester", "/blackboxes", - Seq("SimpleExtModule")) -class MultiExtModuleExecutionTest extends ExecutionTest("MultiExtModuleTester", "/blackboxes", - Seq("SimpleExtModule", "AdderExtModule")) -class RenamedExtModuleExecutionTest extends ExecutionTest("RenamedExtModuleTester", "/blackboxes", - Seq("SimpleExtModule")) -class ParameterizedExtModuleExecutionTest extends ExecutionTest( - "ParameterizedExtModuleTester", "/blackboxes", Seq("ParameterizedExtModule")) +class SimpleExtModuleExecutionTest extends ExecutionTest("SimpleExtModuleTester", "/blackboxes", Seq("SimpleExtModule")) +class MultiExtModuleExecutionTest + extends ExecutionTest("MultiExtModuleTester", "/blackboxes", Seq("SimpleExtModule", "AdderExtModule")) +class RenamedExtModuleExecutionTest + extends ExecutionTest("RenamedExtModuleTester", "/blackboxes", Seq("SimpleExtModule")) +class ParameterizedExtModuleExecutionTest + extends ExecutionTest("ParameterizedExtModuleTester", "/blackboxes", Seq("ParameterizedExtModule")) class LargeParamExecutionTest extends ExecutionTest("LargeParamTester", "/blackboxes", Seq("LargeParam")) diff --git a/src/test/scala/firrtlTests/ExtModuleTests.scala b/src/test/scala/firrtlTests/ExtModuleTests.scala index 9ab3429e..5a58df2b 100644 --- a/src/test/scala/firrtlTests/ExtModuleTests.scala +++ b/src/test/scala/firrtlTests/ExtModuleTests.scala @@ -20,7 +20,6 @@ class ExtModuleTests extends FirrtlFlatSpec { | parameter TYP = 'bit' | """.stripMargin val parsed = parse(input) - (parse(parsed.serialize)) should be (parsed) + (parse(parsed.serialize)) should be(parsed) } } - diff --git a/src/test/scala/firrtlTests/FeatureSpec.scala b/src/test/scala/firrtlTests/FeatureSpec.scala index c7c8f4ac..4972eeb5 100644 --- a/src/test/scala/firrtlTests/FeatureSpec.scala +++ b/src/test/scala/firrtlTests/FeatureSpec.scala @@ -6,4 +6,3 @@ import firrtl.testutils.ExecutionTest // Miscellaneous Feature Checks class NestedSubAccessExecutionTest extends ExecutionTest("NestedSubAccessTester", "/features") - diff --git a/src/test/scala/firrtlTests/FileUtilsSpec.scala b/src/test/scala/firrtlTests/FileUtilsSpec.scala index 43d35048..5a438251 100644 --- a/src/test/scala/firrtlTests/FileUtilsSpec.scala +++ b/src/test/scala/firrtlTests/FileUtilsSpec.scala @@ -2,17 +2,16 @@ package firrtlTests - import firrtl.FileUtils import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers class FileUtilsSpec extends AnyFlatSpec with Matchers { - private val sampleAnnotations: String = "annotations/SampleAnnotations.anno.json" + private val sampleAnnotations: String = "annotations/SampleAnnotations.anno.json" private val sampleAnnotationsFileName: String = s"src/test/resources/$sampleAnnotations" - behavior of "FileUtils.getLines" + behavior.of("FileUtils.getLines") it should "read from a string filename" in { FileUtils.getLines(sampleAnnotationsFileName).size should be > 0 @@ -22,7 +21,7 @@ class FileUtilsSpec extends AnyFlatSpec with Matchers { FileUtils.getLines(new java.io.File(sampleAnnotationsFileName)).size should be > 0 } - behavior of "FileUtils.getText" + behavior.of("FileUtils.getText") it should "read from a string filename" in { FileUtils.getText(sampleAnnotationsFileName).size should be > 0 @@ -32,13 +31,13 @@ class FileUtilsSpec extends AnyFlatSpec with Matchers { FileUtils.getText(new java.io.File(sampleAnnotationsFileName)).size should be > 0 } - behavior of "FileUtils.getLinesResource" + behavior.of("FileUtils.getLinesResource") it should "read from a resource" in { FileUtils.getLinesResource(s"/$sampleAnnotations").size should be > 0 } - behavior of "FileUtils.getTextResource" + behavior.of("FileUtils.getTextResource") it should "read from a resource" in { FileUtils.getTextResource(s"/$sampleAnnotations").split("\n").size should be > 0 diff --git a/src/test/scala/firrtlTests/FlattenTests.scala b/src/test/scala/firrtlTests/FlattenTests.scala index 34edfe58..53604ee5 100644 --- a/src/test/scala/firrtlTests/FlattenTests.scala +++ b/src/test/scala/firrtlTests/FlattenTests.scala @@ -3,12 +3,12 @@ package firrtlTests import firrtl.annotations.{Annotation, CircuitName, ComponentName, ModuleName} -import firrtl.transforms.{FlattenAnnotation, Flatten, NoCircuitDedupAnnotation} +import firrtl.transforms.{Flatten, FlattenAnnotation, NoCircuitDedupAnnotation} import firrtl.testutils._ /** - * Tests deep inline transformation - */ + * Tests deep inline transformation + */ class FlattenTests extends LowTransformSpec { def transform = new Flatten def flatten(mod: String): Annotation = { @@ -19,204 +19,204 @@ class FlattenTests extends LowTransformSpec { } "The modules inside Top " should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline1 - | i.a <= a - | b <= i.b - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_b - | i_a <= a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Top"))) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline1 + | i.a <= a + | b <= i.b + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_b + | i_a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top"))) } "Two instances of the same module inside Top " should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i1 of Inline1 - | inst i2 of Inline1 - | i1.a <= a - | node tmp = i1.b - | i2.a <= tmp - | b <= i2.b - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i1_a : UInt<32> - | wire i1_b : UInt<32> - | i1_b <= i1_a - | wire i2_a : UInt<32> - | wire i2_b : UInt<32> - | i2_b <= i2_a - | node tmp = i1_b - | b <= i2_b - | i1_a <= a - | i2_a <= tmp - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Top"))) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i1 of Inline1 + | inst i2 of Inline1 + | i1.a <= a + | node tmp = i1.b + | i2.a <= tmp + | b <= i2.b + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i1_a : UInt<32> + | wire i1_b : UInt<32> + | i1_b <= i1_a + | wire i2_a : UInt<32> + | wire i2_b : UInt<32> + | i2_b <= i2_a + | node tmp = i1_b + | b <= i2_b + | i1_a <= a + | i2_a <= tmp + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top"))) } "The module instance i in Top " should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | inst i of Inline1 - | inst ni of NotInline1 - | i.a <= a - | b <= i.b - | ni.a <= na - | nb <= ni.b - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | wire i_i_a : UInt<32> - | wire i_i_b : UInt<32> - | i_i_b <= i_i_a - | i_b <= i_i_a - | i_i_a <= i_a - | inst ni of NotInline1 - | b <= i_b - | nb <= ni.b - | i_a <= a - | ni.a <= na - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | b <= i.a - | i.a <= a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | b <= i.a - | i.a <= a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Top.i"), NoCircuitDedupAnnotation)) + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | i.a <= a + | b <= i.b + | ni.a <= na + | nb <= ni.b + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_i_a : UInt<32> + | wire i_i_b : UInt<32> + | i_i_b <= i_i_a + | i_b <= i_i_a + | i_i_a <= i_a + | inst ni of NotInline1 + | b <= i_b + | nb <= ni.b + | i_a <= a + | ni.a <= na + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Top.i"), NoCircuitDedupAnnotation)) } "The module Inline1" should "be inlined" in { val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | inst i of Inline1 - | inst ni of NotInline1 - | i.a <= a - | b <= i.b - | ni.a <= na - | nb <= ni.b - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | i.a <= a - | b <= i.a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | input na : UInt<32> - | output b : UInt<32> - | output nb : UInt<32> - | inst i of Inline1 - | inst ni of NotInline1 - | b <= i.b - | nb <= ni.b - | i.a <= a - | ni.a <= na - | module NotInline1 : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline2 - | b <= i.a - | i.a <= a - | module Inline1 : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_a - | i_a <= a - | module Inline2 : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(flatten("Inline1"), NoCircuitDedupAnnotation)) + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | i.a <= a + | b <= i.b + | ni.a <= na + | nb <= ni.b + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | i.a <= a + | b <= i.a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | input na : UInt<32> + | output b : UInt<32> + | output nb : UInt<32> + | inst i of Inline1 + | inst ni of NotInline1 + | b <= i.b + | nb <= ni.b + | i.a <= a + | ni.a <= na + | module NotInline1 : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline2 + | b <= i.a + | i.a <= a + | module Inline1 : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_a + | i_a <= a + | module Inline2 : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(flatten("Inline1"), NoCircuitDedupAnnotation)) } - "The Flatten transform" should "do nothing if no flatten annotations are present" in{ + "The Flatten transform" should "do nothing if no flatten annotations are present" in { val input = """|circuit Foo: | module Foo: @@ -229,46 +229,46 @@ class FlattenTests extends LowTransformSpec { "The Flatten transform" should "ignore extmodules" in { val input = """ - |circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst i of ExternalMod - | i.a <= a - | b <= i.b - | extmodule ExternalMod : - | input a : UInt<32> - | output b : UInt<32> - | defname = ExternalMod + |circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst i of ExternalMod + | i.a <= a + | b <= i.b + | extmodule ExternalMod : + | input a : UInt<32> + | output b : UInt<32> + | defname = ExternalMod """.stripMargin val check = """ - |circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | inst i_i of ExternalMod - | i_b <= i_i.b - | i_i.a <= i_a - | b <= i_b - | i_a <= a - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst i of ExternalMod - | b <= i.b - | i.a <= a - | extmodule ExternalMod : - | input a : UInt<32> - | output b : UInt<32> - | defname = ExternalMod + |circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | inst i_i of ExternalMod + | i_b <= i_i.b + | i_i.a <= i_a + | b <= i_b + | i_a <= a + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst i of ExternalMod + | b <= i.b + | i.a <= a + | extmodule ExternalMod : + | input a : UInt<32> + | output b : UInt<32> + | defname = ExternalMod """.stripMargin execute(input, check, Seq(flatten("Top"))) } diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index e8be70ad..81f2df33 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -10,8 +10,7 @@ import firrtl.testutils._ import firrtl.testutils.FirrtlCheckers._ class InferReadWriteSpec extends SimpleTransformSpec { - class InferReadWriteCheckException extends PassException( - "Readwrite ports are not found!") + class InferReadWriteCheckException extends PassException("Readwrite ports are not found!") object InferReadWriteCheck extends Pass { override def prerequisites = Forms.MidForm @@ -23,18 +22,18 @@ class InferReadWriteSpec extends SimpleTransformSpec { case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 => s.name == "mem" && s.readwriters.head == "rw" case s: Block => - s.stmts exists findReadWrite + s.stmts.exists(findReadWrite) case _ => false } - def run (c: Circuit) = { + def run(c: Circuit) = { val errors = new Errors - val foundReadWrite = c.modules exists { - case m: Module => findReadWrite(m.body) + val foundReadWrite = c.modules.exists { + case m: Module => findReadWrite(m.body) case m: ExtModule => false } if (!foundReadWrite) { - errors append new InferReadWriteCheckException + errors.append(new InferReadWriteCheckException) errors.trigger } c @@ -176,6 +175,6 @@ circuit sram6t : val annos = Seq(memlib.InferReadWriteAnnotation) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl - res should containLine (s"mem.rw.wmode <= wen") + res should containLine(s"mem.rw.wmode <= wen") } } diff --git a/src/test/scala/firrtlTests/InferResetsSpec.scala b/src/test/scala/firrtlTests/InferResetsSpec.scala index b607fb46..057fb3b0 100644 --- a/src/test/scala/firrtlTests/InferResetsSpec.scala +++ b/src/test/scala/firrtlTests/InferResetsSpec.scala @@ -4,7 +4,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ -import firrtl.passes.{CheckHighForm, CheckTypes, CheckInitialization} +import firrtl.passes.{CheckHighForm, CheckInitialization, CheckTypes} import firrtl.transforms.{CheckCombLoops, InferResets} import firrtl.testutils._ import firrtl.testutils.FirrtlCheckers._ @@ -16,95 +16,93 @@ class InferResetsSpec extends FirrtlFlatSpec { def compile(input: String, compiler: Compiler = new MiddleFirrtlCompiler): CircuitState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) - behavior of "ResetType" + behavior.of("ResetType") val BoolType = UIntType(IntWidth(1)) it should "support casting to other types" in { val result = compile(s""" - |circuit top: - | module top: - | input a : UInt<1> - | output v : UInt<1> - | output w : SInt<1> - | output x : Clock - | output y : Fixed<1><<0>> - | output z : AsyncReset - | wire r : Reset - | r <= a - | v <= asUInt(r) - | w <= asSInt(r) - | x <= asClock(r) - | y <= asFixedPoint(r, 0) - | z <= asAsyncReset(r)""".stripMargin - ) - result should containLine ("wire r : UInt<1>") - result should containLine ("r <= a") - result should containLine ("v <= asUInt(r)") - result should containLine ("w <= asSInt(r)") - result should containLine ("x <= asClock(r)") - result should containLine ("y <= asSInt(r)") - result should containLine ("z <= asAsyncReset(r)") + |circuit top: + | module top: + | input a : UInt<1> + | output v : UInt<1> + | output w : SInt<1> + | output x : Clock + | output y : Fixed<1><<0>> + | output z : AsyncReset + | wire r : Reset + | r <= a + | v <= asUInt(r) + | w <= asSInt(r) + | x <= asClock(r) + | y <= asFixedPoint(r, 0) + | z <= asAsyncReset(r)""".stripMargin) + result should containLine("wire r : UInt<1>") + result should containLine("r <= a") + result should containLine("v <= asUInt(r)") + result should containLine("w <= asSInt(r)") + result should containLine("x <= asClock(r)") + result should containLine("y <= asSInt(r)") + result should containLine("z <= asAsyncReset(r)") } it should "work across Module boundaries" in { val result = compile(s""" - |circuit top : - | module child : - | input clock : Clock - | input childReset : Reset - | input x : UInt<8> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) - | r <= x - | z <= r - | module top : - | input clock : Clock - | input reset : UInt<1> - | input x : UInt<8> - | output z : UInt<8> - | inst c of child - | c.clock <= clock - | c.childReset <= reset - | c.x <= x - | z <= c.z - |""".stripMargin - ) + |circuit top : + | module child : + | input clock : Clock + | input childReset : Reset + | input x : UInt<8> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) + | r <= x + | z <= r + | module top : + | input clock : Clock + | input reset : UInt<1> + | input x : UInt<8> + | output z : UInt<8> + | inst c of child + | c.clock <= clock + | c.childReset <= reset + | c.x <= x + | z <= c.z + |""".stripMargin) result should containTree { case Port(_, "childReset", Input, BoolType) => true } } it should "work across multiple Module boundaries" in { val result = compile(s""" - |circuit top : - | module child : - | input resetIn : Reset - | output resetOut : Reset - | resetOut <= resetIn - | module top : - | input clock : Clock - | input reset : UInt<1> - | input x : UInt<8> - | output z : UInt<8> - | inst c of child - | c.resetIn <= reset - | reg r : UInt<8>, clock with : (reset => (c.resetOut, UInt(123))) - | r <= x - | z <= r - |""".stripMargin - ) + |circuit top : + | module child : + | input resetIn : Reset + | output resetOut : Reset + | resetOut <= resetIn + | module top : + | input clock : Clock + | input reset : UInt<1> + | input x : UInt<8> + | output z : UInt<8> + | inst c of child + | c.resetIn <= reset + | reg r : UInt<8>, clock with : (reset => (c.resetOut, UInt(123))) + | r <= x + | z <= r + |""".stripMargin) result should containTree { case Port(_, "resetIn", Input, BoolType) => true } result should containTree { case Port(_, "resetOut", Output, BoolType) => true } } it should "work in nested and flipped aggregates with regular and partial connect" in { - val result = compile(s""" - |circuit top : - | module top : - | output fizz : { flip foo : { a : AsyncReset, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset }[2] } - | output buzz : { flip foo : { a : AsyncReset, c: UInt<1>, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset, c: UInt<8> }[2] } - | fizz.bar <= fizz.foo - | buzz.bar <- buzz.foo - |""".stripMargin, + val result = compile( + s""" + |circuit top : + | module top : + | output fizz : { flip foo : { a : AsyncReset, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset }[2] } + | output buzz : { flip foo : { a : AsyncReset, c: UInt<1>, flip b: Reset }[2], bar : { a : Reset, flip b: AsyncReset, c: UInt<8> }[2] } + | fizz.bar <= fizz.foo + | buzz.bar <- buzz.foo + |""".stripMargin, new LowFirrtlCompiler ) result should containTree { case Port(_, "fizz_foo_0_a", Input, AsyncResetType) => true } @@ -126,386 +124,370 @@ class InferResetsSpec extends FirrtlFlatSpec { } it should "not crash if a ResetType has no drivers" in { - a [CheckInitialization.RefNotInitializedException] shouldBe thrownBy { + a[CheckInitialization.RefNotInitializedException] shouldBe thrownBy { compile(s""" - |circuit test : - | module test : - | output out : Reset - | wire w : Reset - | out <= w - | out <= UInt(1) - |""".stripMargin - ) + |circuit test : + | module test : + | output out : Reset + | wire w : Reset + | out <= w + | out <= UInt(1) + |""".stripMargin) } } it should "NOT allow last connect semantics to pick the right type for Reset" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : UInt<1> - | output out : Reset - | wire w0 : Reset - | wire w1 : Reset - | w0 <= reset0 - | w1 <= reset1 - | out <= w0 - | out <= w1 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : UInt<1> + | output out : Reset + | wire w0 : Reset + | wire w1 : Reset + | w0 <= reset0 + | w1 <= reset1 + | out <= w0 + | out <= w1 + |""".stripMargin) } } it should "NOT support last connect semantics across whens" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : AsyncReset - | input reset2 : UInt<1> - | input en : UInt<1> - | output out : Reset - | wire w0 : Reset - | wire w1 : Reset - | wire w2 : Reset - | w0 <= reset0 - | w1 <= reset1 - | w2 <= reset2 - | out <= w2 - | when en : - | out <= w0 - | else : - | out <= w1 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : AsyncReset + | input reset2 : UInt<1> + | input en : UInt<1> + | output out : Reset + | wire w0 : Reset + | wire w1 : Reset + | wire w2 : Reset + | w0 <= reset0 + | w1 <= reset1 + | w2 <= reset2 + | out <= w2 + | when en : + | out <= w0 + | else : + | out <= w1 + |""".stripMargin) } } it should "not allow different Reset Types to drive a single Reset" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : UInt<1> - | input en : UInt<1> - | output out : Reset - | wire w1 : Reset - | wire w2 : Reset - | w1 <= reset0 - | w2 <= reset1 - | out <= w1 - | when en : - | out <= w2 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : UInt<1> + | input en : UInt<1> + | output out : Reset + | wire w1 : Reset + | wire w2 : Reset + | w1 <= reset0 + | w2 <= reset1 + | out <= w1 + | when en : + | out <= w2 + |""".stripMargin) } } it should "allow concrete reset types to overrule invalidation" in { val result = compile(s""" - |circuit test : - | module test : - | input in : AsyncReset - | output out : Reset - | out is invalid - | out <= in - |""".stripMargin) + |circuit test : + | module test : + | input in : AsyncReset + | output out : Reset + | out is invalid + | out <= in + |""".stripMargin) result should containTree { case Port(_, "out", Output, AsyncResetType) => true } } it should "default to BoolType for Resets that are only invalidated" in { val result = compile(s""" - |circuit test : - | module test : - | output out : Reset - | out is invalid - |""".stripMargin) + |circuit test : + | module test : + | output out : Reset + | out is invalid + |""".stripMargin) result should containTree { case Port(_, "out", Output, BoolType) => true } } it should "not error if component of ResetType is invalidated and connected to an AsyncResetType" in { val result = compile(s""" - |circuit test : - | module test : - | input cond : UInt<1> - | input in : AsyncReset - | output out : Reset - | out is invalid - | when cond : - | out <= in - |""".stripMargin) + |circuit test : + | module test : + | input cond : UInt<1> + | input in : AsyncReset + | output out : Reset + | out is invalid + | when cond : + | out <= in + |""".stripMargin) result should containTree { case Port(_, "out", Output, AsyncResetType) => true } } it should "allow ResetType to drive AsyncResets or UInt<1>" in { val result1 = compile(s""" - |circuit top : - | module top : - | input in : UInt<1> - | output out : UInt<1> - | wire w : Reset - | w <= in - | out <= w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : UInt<1> + | output out : UInt<1> + | wire w : Reset + | w <= in + | out <= w + |""".stripMargin) result1 should containTree { case DefWire(_, "w", BoolType) => true } val result2 = compile(s""" - |circuit top : - | module top : - | output foo : { flip a : UInt<1> } - | input bar : { flip a : UInt<1> } - | wire w : { flip a : Reset } - | foo <= w - | w <= bar - |""".stripMargin - ) + |circuit top : + | module top : + | output foo : { flip a : UInt<1> } + | input bar : { flip a : UInt<1> } + | wire w : { flip a : Reset } + | foo <= w + | w <= bar + |""".stripMargin) val AggType = BundleType(Seq(Field("a", Flip, BoolType))) result2 should containTree { case DefWire(_, "w", AggType) => true } val result3 = compile(s""" - |circuit top : - | module top : - | input in : UInt<1> - | output out : UInt<1> - | wire w : Reset - | w <- in - | out <- w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : UInt<1> + | output out : UInt<1> + | wire w : Reset + | w <- in + | out <- w + |""".stripMargin) result3 should containTree { case DefWire(_, "w", BoolType) => true } } it should "error if a ResetType driving UInt<1> infers to AsyncReset" in { - an [Exception] shouldBe thrownBy { + an[Exception] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : UInt<1> - | wire w : Reset - | w <= in - | out <= w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : AsyncReset + | output out : UInt<1> + | wire w : Reset + | w <= in + | out <= w + |""".stripMargin) } } it should "error if a ResetType driving AsyncReset infers to UInt<1>" in { - an [Exception] shouldBe thrownBy { + an[Exception] shouldBe thrownBy { compile(s""" - |circuit top : - | module top : - | input in : UInt<1> - | output out : AsyncReset - | wire w : Reset - | w <= in - | out <= w - |""".stripMargin - ) + |circuit top : + | module top : + | input in : UInt<1> + | output out : AsyncReset + | wire w : Reset + | w <= in + | out <= w + |""".stripMargin) } } it should "not allow ResetType as an Input or ExtModule output" in { // TODO what exception should be thrown here? - an [CheckHighForm.ResetInputException] shouldBe thrownBy { + an[CheckHighForm.ResetInputException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input in : { foo : Reset } - | output out : Reset - | out <= in.foo - |""".stripMargin - ) + |circuit top : + | module top : + | input in : { foo : Reset } + | output out : Reset + | out <= in.foo + |""".stripMargin) } - an [CheckHighForm.ResetExtModuleOutputException] shouldBe thrownBy { + an[CheckHighForm.ResetExtModuleOutputException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | extmodule ext : - | output out : { foo : Reset } - | module top : - | output out : Reset - | inst e of ext - | out <= e.out.foo - |""".stripMargin - ) + |circuit top : + | extmodule ext : + | output out : { foo : Reset } + | module top : + | output out : Reset + | inst e of ext + | out <= e.out.foo + |""".stripMargin) } } it should "not allow Vecs to infer different Reset Types" in { - an [CheckTypes.InvalidConnect] shouldBe thrownBy { + an[CheckTypes.InvalidConnect] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input reset0 : AsyncReset - | input reset1 : UInt<1> - | output out : Reset[2] - | out[0] <= reset0 - | out[1] <= reset1 - |""".stripMargin - ) + |circuit top : + | module top : + | input reset0 : AsyncReset + | input reset1 : UInt<1> + | output out : Reset[2] + | out[0] <= reset0 + | out[1] <= reset1 + |""".stripMargin) } } // Or is this actually an error? The behavior is that out is inferred as AsyncReset[2] ignore should "not allow Vecs only be partially inferred" in { // Some exception should be thrown, TODO figure out which one - an [Exception] shouldBe thrownBy { + an[Exception] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input reset : AsyncReset - | output out : Reset[2] - | out is invalid - | out[0] <= reset - |""".stripMargin - ) + |circuit top : + | module top : + | input reset : AsyncReset + | output out : Reset[2] + | out is invalid + | out[0] <= reset + |""".stripMargin) } } - it should "support inferring modules that would dedup differently" in { val result = compile(s""" - |circuit top : - | module child : - | input clock : Clock - | input childReset : Reset - | input x : UInt<8> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) - | r <= x - | z <= r - | module child_1 : - | input clock : Clock - | input childReset : Reset - | input x : UInt<8> - | output z : UInt<8> - | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) - | r <= x - | z <= r - | module top : - | input clock : Clock - | input reset1 : UInt<1> - | input reset2 : AsyncReset - | input x : UInt<8>[2] - | output z : UInt<8>[2] - | inst c of child - | c.clock <= clock - | c.childReset <= reset1 - | c.x <= x[0] - | z[0] <= c.z - | inst c2 of child_1 - | c2.clock <= clock - | c2.childReset <= reset2 - | c2.x <= x[1] - | z[1] <= c2.z - |""".stripMargin - ) + |circuit top : + | module child : + | input clock : Clock + | input childReset : Reset + | input x : UInt<8> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) + | r <= x + | z <= r + | module child_1 : + | input clock : Clock + | input childReset : Reset + | input x : UInt<8> + | output z : UInt<8> + | reg r : UInt<8>, clock with : (reset => (childReset, UInt(123))) + | r <= x + | z <= r + | module top : + | input clock : Clock + | input reset1 : UInt<1> + | input reset2 : AsyncReset + | input x : UInt<8>[2] + | output z : UInt<8>[2] + | inst c of child + | c.clock <= clock + | c.childReset <= reset1 + | c.x <= x[0] + | z[0] <= c.z + | inst c2 of child_1 + | c2.clock <= clock + | c2.childReset <= reset2 + | c2.x <= x[1] + | z[1] <= c2.z + |""".stripMargin) result should containTree { case Port(_, "childReset", Input, BoolType) => true } result should containTree { case Port(_, "childReset", Input, AsyncResetType) => true } } it should "infer based on what a component *drives* not just what drives it" in { val result = compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : Reset - | wire w : Reset - | w is invalid - | out <= w - | out <= in - |""".stripMargin) + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | wire w : Reset + | w is invalid + | out <= w + | out <= in + |""".stripMargin) result should containTree { case DefWire(_, "w", AsyncResetType) => true } } it should "infer from connections, ignoring the fact that the invalidation wins" in { val result = compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : Reset - | out <= in - | out is invalid - |""".stripMargin) + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | out <= in + | out is invalid + |""".stripMargin) result should containTree { case Port(_, "out", Output, AsyncResetType) => true } } // The backwards type propagation constrains `w` to be the same as both `out0` and `out1` it should "not allow an invalidated Wire to drive both a UInt<1> and an AsyncReset" in { - an [InferResets.InferResetsException] shouldBe thrownBy { + an[InferResets.InferResetsException] shouldBe thrownBy { val result = compile(s""" - |circuit top : - | module top : - | input in0 : AsyncReset - | input in1 : UInt<1> - | output out0 : Reset - | output out1 : Reset - | wire w : Reset - | w is invalid - | out0 <= w - | out1 <= w - | out0 <= in0 - | out1 <= in1 - |""".stripMargin - ) + |circuit top : + | module top : + | input in0 : AsyncReset + | input in1 : UInt<1> + | output out0 : Reset + | output out1 : Reset + | wire w : Reset + | w is invalid + | out0 <= w + | out1 <= w + | out0 <= in0 + | out1 <= in1 + |""".stripMargin) } } it should "not propagate type info from downstream across a cast" in { val result = compile(s""" - |circuit top : - | module top : - | input in0 : AsyncReset - | input in1 : UInt<1> - | output out0 : Reset - | output out1 : Reset - | wire w : Reset - | w is invalid - | out0 <= asAsyncReset(w) - | out1 <= w - | out0 <= in0 - | out1 <= in1 - |""".stripMargin - ) + |circuit top : + | module top : + | input in0 : AsyncReset + | input in1 : UInt<1> + | output out0 : Reset + | output out1 : Reset + | wire w : Reset + | w is invalid + | out0 <= asAsyncReset(w) + | out1 <= w + | out0 <= in0 + | out1 <= in1 + |""".stripMargin) result should containTree { case Port(_, "out0", Output, AsyncResetType) => true } } // This tests for a bug unrelated to support or lackthereof for last connect in inference it should "take into account both internal and external constraints on Module port types" in { val result = compile(s""" - |circuit top : - | module child : - | input i : AsyncReset - | output o : Reset - | o <= i - | module top : - | input in : AsyncReset - | output out : AsyncReset - | inst c of child - | c.o is invalid - | c.i <= in - | out <= c.o - |""".stripMargin) + |circuit top : + | module child : + | input i : AsyncReset + | output o : Reset + | o <= i + | module top : + | input in : AsyncReset + | output out : AsyncReset + | inst c of child + | c.o is invalid + | c.i <= in + | out <= c.o + |""".stripMargin) result should containTree { case Port(_, "o", Output, AsyncResetType) => true } } it should "not crash on combinational loops" in { - a [CheckCombLoops.CombLoopException] shouldBe thrownBy { - val result = compile(s""" - |circuit top : - | module top : - | input in : AsyncReset - | output out : Reset - | wire w0 : Reset - | wire w1 : Reset - | w0 <= in - | w0 <= w1 - | w1 <= w0 - | out <= in - |""".stripMargin, + a[CheckCombLoops.CombLoopException] shouldBe thrownBy { + val result = compile( + s""" + |circuit top : + | module top : + | input in : AsyncReset + | output out : Reset + | wire w0 : Reset + | wire w1 : Reset + | w0 <= in + | w0 <= w1 + | w1 <= w0 + | out <= in + |""".stripMargin, compiler = new LowFirrtlCompiler ) } diff --git a/src/test/scala/firrtlTests/InfoSpec.scala b/src/test/scala/firrtlTests/InfoSpec.scala index a2410f9d..172ddfb9 100644 --- a/src/test/scala/firrtlTests/InfoSpec.scala +++ b/src/test/scala/firrtlTests/InfoSpec.scala @@ -13,9 +13,9 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } @@ -27,156 +27,151 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { "Source locators on module ports" should "be propagated to Verilog" in { val result = compileBody(s""" - |input x : UInt<8> $Info1 - |output y : UInt<8> $Info2 - |y <= x""".stripMargin - ) + |input x : UInt<8> $Info1 + |output y : UInt<8> $Info2 + |y <= x""".stripMargin) result should containTree { case Port(Info1, "x", Input, _) => true } - result should containLine (s"input [7:0] x, //$Info1") + result should containLine(s"input [7:0] x, //$Info1") result should containTree { case Port(Info2, "y", Output, _) => true } - result should containLine (s"output [7:0] y //$Info2") + result should containLine(s"output [7:0] y //$Info2") } "Source locators on aggregates" should "be propagated to Verilog" in { val result = compileBody(s""" - |input io : { x : UInt<8>, flip y : UInt<8> } $Info1 - |io.y <= io.x""".stripMargin - ) + |input io : { x : UInt<8>, flip y : UInt<8> } $Info1 + |io.y <= io.x""".stripMargin) result should containTree { case Port(Info1, "io_x", Input, _) => true } - result should containLine (s"input [7:0] io_x, //$Info1") + result should containLine(s"input [7:0] io_x, //$Info1") result should containTree { case Port(Info1, "io_y", Output, _) => true } - result should containLine (s"output [7:0] io_y //$Info1") + result should containLine(s"output [7:0] io_y //$Info1") } "Source locators" should "be propagated on declarations" in { val result = compileBody(s""" - |input clock : Clock - |input x : UInt<8> - |output y : UInt<8> - |reg r : UInt<8>, clock $Info1 - |wire w : UInt<8> $Info2 - |node n = or(w, x) $Info3 - |w <= and(x, r) - |r <= or(n, r) - |y <= r""".stripMargin - ) - result should containTree { case DefRegister(Info1, "r", _,_,_,_) => true } - result should containLine (s"reg [7:0] r; //$Info1") + |input clock : Clock + |input x : UInt<8> + |output y : UInt<8> + |reg r : UInt<8>, clock $Info1 + |wire w : UInt<8> $Info2 + |node n = or(w, x) $Info3 + |w <= and(x, r) + |r <= or(n, r) + |y <= r""".stripMargin) + result should containTree { case DefRegister(Info1, "r", _, _, _, _) => true } + result should containLine(s"reg [7:0] r; //$Info1") result should containTree { case DefNode(Info2, "w", _) => true } - result should containLine (s"wire [7:0] w = x & r; //$Info2") // Node "w" declaration in Verilog + result should containLine(s"wire [7:0] w = x & r; //$Info2") // Node "w" declaration in Verilog result should containTree { case DefNode(Info3, "n", _) => true } - result should containLine (s"wire [7:0] n = w | x; //$Info3") + result should containLine(s"wire [7:0] n = w | x; //$Info3") } it should "be propagated on memories" in { val result = compileBody(s""" - |input clock : Clock - |input addr : UInt<5> - |output z : UInt<8> - |mem m: $Info1 - | data-type => UInt<8> - | depth => 32 - | read-latency => 0 - | write-latency => 1 - | reader => r - | writer => w - |m.r.clk <= clock - |m.r.addr <= addr - |m.r.en <= UInt(1) - |m.w.clk <= clock - |m.w.addr <= addr - |m.w.en <= UInt(0) - |m.w.data <= UInt(0) - |m.w.mask <= UInt(0) - |z <= m.r.data - |""".stripMargin - ) + |input clock : Clock + |input addr : UInt<5> + |output z : UInt<8> + |mem m: $Info1 + | data-type => UInt<8> + | depth => 32 + | read-latency => 0 + | write-latency => 1 + | reader => r + | writer => w + |m.r.clk <= clock + |m.r.addr <= addr + |m.r.en <= UInt(1) + |m.w.clk <= clock + |m.w.addr <= addr + |m.w.en <= UInt(0) + |m.w.data <= UInt(0) + |m.w.mask <= UInt(0) + |z <= m.r.data + |""".stripMargin) - result should containTree { case DefMemory(Info1, "m", _,_,_,_,_,_,_,_) => true } - result should containLine (s"reg [7:0] m [0:31]; //$Info1") - result should containLine (s"wire [7:0] m_r_data; //$Info1") - result should containLine (s"wire [4:0] m_r_addr; //$Info1") - result should containLine (s"wire [7:0] m_w_data; //$Info1") - result should containLine (s"wire [4:0] m_w_addr; //$Info1") - result should containLine (s"wire m_w_mask; //$Info1") - result should containLine (s"wire m_w_en; //$Info1") - result should containLine (s"assign m_r_data = m[m_r_addr]; //$Info1") - result should containLine (s"m[m_w_addr] <= m_w_data; //$Info1") + result should containTree { case DefMemory(Info1, "m", _, _, _, _, _, _, _, _) => true } + result should containLine(s"reg [7:0] m [0:31]; //$Info1") + result should containLine(s"wire [7:0] m_r_data; //$Info1") + result should containLine(s"wire [4:0] m_r_addr; //$Info1") + result should containLine(s"wire [7:0] m_w_data; //$Info1") + result should containLine(s"wire [4:0] m_w_addr; //$Info1") + result should containLine(s"wire m_w_mask; //$Info1") + result should containLine(s"wire m_w_en; //$Info1") + result should containLine(s"assign m_r_data = m[m_r_addr]; //$Info1") + result should containLine(s"m[m_w_addr] <= m_w_data; //$Info1") } it should "be propagated on instances" in { val result = compile(s""" - |circuit Test : - | module Child : - | output io : { flip in : UInt<8>, out : UInt<8> } - | io.out <= io.in - | module Test : - | output io : { flip in : UInt<8>, out : UInt<8> } - | inst c of Child $Info1 - | io <= c.io - |""".stripMargin - ) + |circuit Test : + | module Child : + | output io : { flip in : UInt<8>, out : UInt<8> } + | io.out <= io.in + | module Test : + | output io : { flip in : UInt<8>, out : UInt<8> } + | inst c of Child $Info1 + | io <= c.io + |""".stripMargin) result should containTree { case WDefInstance(Info1, "c", "Child", _) => true } - result should containLine (s"Child c ( //$Info1") + result should containLine(s"Child c ( //$Info1") } it should "be propagated across direct node assignments and connections" in { val result = compile(s""" - |circuit Test : - | module Test : - | input in : UInt<8> - | output out : UInt<8> - | node a = in $Info1 - | node b = a - | out <= b - |""".stripMargin - ) - result should containTree { case Connect(Info1, Reference("out", _,_,_), Reference("in", _,_,_)) => true } - result should containLine (s"assign out = in; //$Info1") + |circuit Test : + | module Test : + | input in : UInt<8> + | output out : UInt<8> + | node a = in $Info1 + | node b = a + | out <= b + |""".stripMargin) + result should containTree { case Connect(Info1, Reference("out", _, _, _), Reference("in", _, _, _)) => true } + result should containLine(s"assign out = in; //$Info1") } "source locators" should "be propagated through ExpandWhens" in { - val input = """ - |;buildInfoPackage: chisel3, version: 3.1-SNAPSHOT, scalaVersion: 2.11.7, sbtVersion: 0.13.11, builtAtString: 2016-11-26 18:48:38.030, builtAtMillis: 1480186118030 - |circuit GCD : - | module GCD : - | input clock : Clock - | input reset : UInt<1> - | output io : {flip a : UInt<32>, flip b : UInt<32>, flip e : UInt<1>, z : UInt<32>, v : UInt<1>} - | - | io is invalid - | io is invalid - | reg x : UInt<32>, clock @[GCD.scala 15:14] - | reg y : UInt<32>, clock @[GCD.scala 16:14] - | node _T_14 = gt(x, y) @[GCD.scala 17:11] - | when _T_14 : @[GCD.scala 17:18] - | node _T_15 = sub(x, y) @[GCD.scala 17:27] - | node _T_16 = tail(_T_15, 1) @[GCD.scala 17:27] - | x <= _T_16 @[GCD.scala 17:22] - | skip @[GCD.scala 17:18] - | node _T_18 = eq(_T_14, UInt<1>("h00")) @[GCD.scala 17:18] - | when _T_18 : @[GCD.scala 18:18] - | node _T_19 = sub(y, x) @[GCD.scala 18:27] - | node _T_20 = tail(_T_19, 1) @[GCD.scala 18:27] - | y <= _T_20 @[GCD.scala 18:22] - | skip @[GCD.scala 18:18] - | when io.e : @[GCD.scala 19:15] - | x <= io.a @[GCD.scala 19:19] - | y <= io.b @[GCD.scala 19:30] - | skip @[GCD.scala 19:15] - | io.z <= x @[GCD.scala 20:8] - | node _T_22 = eq(y, UInt<1>("h00")) @[GCD.scala 21:13] - | io.v <= _T_22 @[GCD.scala 21:8] - | + val input = + """ + |;buildInfoPackage: chisel3, version: 3.1-SNAPSHOT, scalaVersion: 2.11.7, sbtVersion: 0.13.11, builtAtString: 2016-11-26 18:48:38.030, builtAtMillis: 1480186118030 + |circuit GCD : + | module GCD : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip a : UInt<32>, flip b : UInt<32>, flip e : UInt<1>, z : UInt<32>, v : UInt<1>} + | + | io is invalid + | io is invalid + | reg x : UInt<32>, clock @[GCD.scala 15:14] + | reg y : UInt<32>, clock @[GCD.scala 16:14] + | node _T_14 = gt(x, y) @[GCD.scala 17:11] + | when _T_14 : @[GCD.scala 17:18] + | node _T_15 = sub(x, y) @[GCD.scala 17:27] + | node _T_16 = tail(_T_15, 1) @[GCD.scala 17:27] + | x <= _T_16 @[GCD.scala 17:22] + | skip @[GCD.scala 17:18] + | node _T_18 = eq(_T_14, UInt<1>("h00")) @[GCD.scala 17:18] + | when _T_18 : @[GCD.scala 18:18] + | node _T_19 = sub(y, x) @[GCD.scala 18:27] + | node _T_20 = tail(_T_19, 1) @[GCD.scala 18:27] + | y <= _T_20 @[GCD.scala 18:22] + | skip @[GCD.scala 18:18] + | when io.e : @[GCD.scala 19:15] + | x <= io.a @[GCD.scala 19:19] + | y <= io.b @[GCD.scala 19:30] + | skip @[GCD.scala 19:15] + | io.z <= x @[GCD.scala 20:8] + | node _T_22 = eq(y, UInt<1>("h00")) @[GCD.scala 21:13] + | io.v <= _T_22 @[GCD.scala 21:8] + | """.stripMargin val result = (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) - result should containLine ("node _GEN_0 = mux(_T_14, _T_16, x) @[GCD.scala 17:18 GCD.scala 17:22 GCD.scala 15:14]") - result should containLine ("node _GEN_2 = mux(io_e, io_a, _GEN_0) @[GCD.scala 19:15 GCD.scala 19:19]") - result should containLine ("x <= _GEN_2") - result should containLine ("node _GEN_1 = mux(_T_18, _T_20, y) @[GCD.scala 18:18 GCD.scala 18:22 GCD.scala 16:14]") - result should containLine ("node _GEN_3 = mux(io_e, io_b, _GEN_1) @[GCD.scala 19:15 GCD.scala 19:30]") - result should containLine ("y <= _GEN_3") + result should containLine("node _GEN_0 = mux(_T_14, _T_16, x) @[GCD.scala 17:18 GCD.scala 17:22 GCD.scala 15:14]") + result should containLine("node _GEN_2 = mux(io_e, io_a, _GEN_0) @[GCD.scala 19:15 GCD.scala 19:19]") + result should containLine("x <= _GEN_2") + result should containLine("node _GEN_1 = mux(_T_18, _T_20, y) @[GCD.scala 18:18 GCD.scala 18:22 GCD.scala 16:14]") + result should containLine("node _GEN_3 = mux(io_e, io_b, _GEN_1) @[GCD.scala 19:15 GCD.scala 19:30]") + result should containLine("y <= _GEN_3") } "source locators for append option" should "use multiinfo" in { @@ -195,71 +190,68 @@ class InfoSpec extends FirrtlFlatSpec with FirrtlMatchers { "source locators for basic register updates" should "be propagated to Verilog" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |output io : { flip in : UInt<8>, out : UInt<8>} - |reg r : UInt<8>, clock - |r <= io.in $Info1 - |io.out <= r - |""".stripMargin - ) - result should containLine (s"r <= io_in; //$Info1") + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, out : UInt<8>} + |reg r : UInt<8>, clock + |r <= io.in $Info1 + |io.out <= r + |""".stripMargin) + result should containLine(s"r <= io_in; //$Info1") } "source locators for register reset" should "be propagated to Verilog" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |output io : { flip in : UInt<8>, out : UInt<8>} - |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info3 - |r <= io.in $Info1 - |io.out <= r - |""".stripMargin - ) - result should containLine (s"if (reset) begin //$Info3") - result should containLine (s"r <= 8'h0; //$Info3") - result should containLine (s"r <= io_in; //$Info1") + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, out : UInt<8>} + |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info3 + |r <= io.in $Info1 + |io.out <= r + |""".stripMargin) + result should containLine(s"if (reset) begin //$Info3") + result should containLine(s"r <= 8'h0; //$Info3") + result should containLine(s"r <= io_in; //$Info1") } "source locators for complex register updates" should "be propagated to Verilog" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |output io : { flip in : UInt<8>, flip a : UInt<1>, out : UInt<8>} - |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info1 - |r <= UInt<2>(2) $Info2 - |when io.a : $Info3 - | r <= io.in $Info4 - |io.out <= r - |""".stripMargin - ) - result should containLine (s"if (reset) begin //$Info1") - result should containLine (s"r <= 8'h0; //$Info1") - result should containLine (s"end else if (io_a) begin //$Info3") - result should containLine (s"r <= io_in; //$Info4") - result should containLine (s"r <= 8'h2; //$Info2") + |input clock : Clock + |input reset : UInt<1> + |output io : { flip in : UInt<8>, flip a : UInt<1>, out : UInt<8>} + |reg r : UInt<8>, clock with : (reset => (reset, UInt<8>("h0"))) $Info1 + |r <= UInt<2>(2) $Info2 + |when io.a : $Info3 + | r <= io.in $Info4 + |io.out <= r + |""".stripMargin) + result should containLine(s"if (reset) begin //$Info1") + result should containLine(s"r <= 8'h0; //$Info1") + result should containLine(s"end else if (io_a) begin //$Info3") + result should containLine(s"r <= io_in; //$Info4") + result should containLine(s"r <= 8'h2; //$Info2") } "FileInfo" should "be able to contain a escaped characters" in { def input(info: String): String = s"""circuit m: @[$info] - | module m: - | skip - |""".stripMargin + | module m: + | skip + |""".stripMargin def parseInfo(info: String): FileInfo = { firrtl.Parser.parse(input(info)).info.asInstanceOf[FileInfo] } - parseInfo("test\\ntest").escaped should be ("test\\ntest") - parseInfo("test\\ntest").unescaped should be ("test\ntest") - parseInfo("test\\ttest").escaped should be ("test\\ttest") - parseInfo("test\\ttest").unescaped should be ("test\ttest") - parseInfo("test\\\\test").escaped should be ("test\\\\test") - parseInfo("test\\\\test").unescaped should be ("test\\test") - parseInfo("test\\]test").escaped should be ("test\\]test") - parseInfo("test\\]test").unescaped should be ("test]test") - parseInfo("test[\\][\\]test").escaped should be ("test[\\][\\]test") - parseInfo("test[\\][\\]test").unescaped should be ("test[][]test") + parseInfo("test\\ntest").escaped should be("test\\ntest") + parseInfo("test\\ntest").unescaped should be("test\ntest") + parseInfo("test\\ttest").escaped should be("test\\ttest") + parseInfo("test\\ttest").unescaped should be("test\ttest") + parseInfo("test\\\\test").escaped should be("test\\\\test") + parseInfo("test\\\\test").unescaped should be("test\\test") + parseInfo("test\\]test").escaped should be("test\\]test") + parseInfo("test\\]test").unescaped should be("test]test") + parseInfo("test[\\][\\]test").escaped should be("test[\\][\\]test") + parseInfo("test[\\][\\]test").unescaped should be("test[][]test") } } diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 27102785..e4f711ed 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -12,8 +12,8 @@ import firrtl.stage.TransformManager import firrtl.options.Dependency /** - * Tests inline instances transformation - */ + * Tests inline instances transformation + */ class InlineInstancesTests extends LowTransformSpec { def transform = new InlineInstances def inline(mod: String): Annotation = { @@ -22,181 +22,181 @@ class InlineInstancesTests extends LowTransformSpec { val name = if (parts.size == 1) modName else ComponentName(parts.tail.mkString("."), modName) InlineAnnotation(name) } - // Set this to debug, this will apply to all tests - // Logger.setLevel(this.getClass, Debug) - "The module Inline" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_b - | i_a <= a""".stripMargin - execute(input, check, Seq(inline("Inline"))) - } + // Set this to debug, this will apply to all tests + // Logger.setLevel(this.getClass, Debug) + "The module Inline" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_b + | i_a <= a""".stripMargin + execute(input, check, Seq(inline("Inline"))) + } - "The all instances of Simple" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i0 of Simple - | inst i1 of Simple - | i0.a <= a - | i1.a <= i0.b - | b <= i1.b - | module Simple : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i0_a : UInt<32> - | wire i0_b : UInt<32> - | i0_b <= i0_a - | wire i1_a : UInt<32> - | wire i1_b : UInt<32> - | i1_b <= i1_a - | b <= i1_b - | i0_a <= a - | i1_a <= i0_b""".stripMargin - execute(input, check, Seq(inline("Simple"))) - } + "The all instances of Simple" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i0 of Simple + | inst i1 of Simple + | i0.a <= a + | i1.a <= i0.b + | b <= i1.b + | module Simple : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i0_a : UInt<32> + | wire i0_b : UInt<32> + | i0_b <= i0_a + | wire i1_a : UInt<32> + | wire i1_b : UInt<32> + | i1_b <= i1_a + | b <= i1_b + | i0_a <= a + | i1_a <= i0_b""".stripMargin + execute(input, check, Seq(inline("Simple"))) + } - "Only one instance of Simple" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i0 of Simple - | inst i1 of Simple - | i0.a <= a - | i1.a <= i0.b - | b <= i1.b - | module Simple : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i0_a : UInt<32> - | wire i0_b : UInt<32> - | i0_b <= i0_a - | inst i1 of Simple - | b <= i1.b - | i0_a <= a - | i1.a <= i0_b - | module Simple : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(inline("Top.i0"))) - } + "Only one instance of Simple" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i0 of Simple + | inst i1 of Simple + | i0.a <= a + | i1.a <= i0.b + | b <= i1.b + | module Simple : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i0_a : UInt<32> + | wire i0_b : UInt<32> + | i0_b <= i0_a + | inst i1 of Simple + | b <= i1.b + | i0_a <= a + | i1.a <= i0_b + | module Simple : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(inline("Top.i0"))) + } - "All instances of A" should "be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i0 of A - | inst i1 of B - | i0.a <= a - | i1.a <= i0.b - | b <= i1.b - | module A : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - | module B : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i0_a : UInt<32> - | wire i0_b : UInt<32> - | i0_b <= i0_a - | inst i1 of B - | b <= i1.b - | i0_a <= a - | i1.a <= i0_b - | module B : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | i_b <= i_a - | b <= i_b - | i_a <= a""".stripMargin - execute(input, check, Seq(inline("A"))) - } + "All instances of A" should "be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i0 of A + | inst i1 of B + | i0.a <= a + | i1.a <= i0.b + | b <= i1.b + | module A : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module B : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i0_a : UInt<32> + | wire i0_b : UInt<32> + | i0_b <= i0_a + | inst i1 of B + | b <= i1.b + | i0_a <= a + | i1.a <= i0_b + | module B : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | i_b <= i_a + | b <= i_b + | i_a <= a""".stripMargin + execute(input, check, Seq(inline("A"))) + } - "Non-inlined instances" should "still prepend prefix" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b - | module A : - | input a : UInt<32> - | output b : UInt<32> - | inst i of B - | i.a <= a - | b <= i.b - | module B : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | inst i_i of B - | i_b <= i_i.b - | i_i.a <= i_a - | b <= i_b - | i_a <= a - | module B : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - execute(input, check, Seq(inline("A"))) - } + "Non-inlined instances" should "still prepend prefix" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b + | module A : + | input a : UInt<32> + | output b : UInt<32> + | inst i of B + | i.a <= a + | b <= i.b + | module B : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | inst i_i of B + | i_b <= i_i.b + | i_i.a <= i_a + | b <= i_b + | i_a <= a + | module B : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + execute(input, check, Seq(inline("A"))) + } "A module with nested inlines" should "still prepend prefixes" in { val input = @@ -291,57 +291,57 @@ class InlineInstancesTests extends LowTransformSpec { execute(input, check, Seq(inline("Foo"), inline("Foo.bar"))) } - // ---- Errors ---- - // 1) ext module - "External module" should "not be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b - | extmodule A : - | input a : UInt<32> - | output b : UInt<32>""".stripMargin - failingexecute(input, Seq(inline("A"))) - } - // 2) ext instance - "External instance" should "not be inlined" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of A - | i.a <= a - | b <= i.b - | extmodule A : - | input a : UInt<32> - | output b : UInt<32>""".stripMargin - failingexecute(input, Seq(inline("A"))) - } - // 3) no module - "Inlined module" should "exist" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - failingexecute(input, Seq(inline("A"))) - } - // 4) no inst - "Inlined instance" should "exist" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | b <= a""".stripMargin - failingexecute(input, Seq(inline("A"))) - } + // ---- Errors ---- + // 1) ext module + "External module" should "not be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b + | extmodule A : + | input a : UInt<32> + | output b : UInt<32>""".stripMargin + failingexecute(input, Seq(inline("A"))) + } + // 2) ext instance + "External instance" should "not be inlined" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of A + | i.a <= a + | b <= i.b + | extmodule A : + | input a : UInt<32> + | output b : UInt<32>""".stripMargin + failingexecute(input, Seq(inline("A"))) + } + // 3) no module + "Inlined module" should "exist" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + failingexecute(input, Seq(inline("A"))) + } + // 4) no inst + "Inlined instance" should "exist" in { + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | b <= a""".stripMargin + failingexecute(input, Seq(inline("A"))) + } "Jack's Bug" should "not fail" in { @@ -384,163 +384,167 @@ class InlineInstancesTests extends LowTransformSpec { override def duplicate(n: ReferenceTarget): Annotation = DummyAnno(n) } "annotations" should "be renamed" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst foo of NestedInline - | inst bar of NestedNoInline - | foo.a <= a - | bar.a <= foo.b - | b <= bar.b - | module NestedInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | wire i_foo_a : UInt<32> - | wire i_foo_b : UInt<32> - | i_foo_b <= i_foo_a - | inst i_bar of NestedNoInline - | i_b <= i_bar.b - | i_foo_a <= i_a - | i_bar.a <= i_foo_b - | b <= i_b - | i_a <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | inst bar of NestedNoInline + | foo.a <= a + | bar.a <= foo.b + | b <= bar.b + | module NestedInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_foo_a : UInt<32> + | wire i_foo_b : UInt<32> + | i_foo_b <= i_foo_a + | inst i_bar of NestedNoInline + | i_b <= i_bar.b + | i_foo_a <= i_a + | i_bar.a <= i_foo_b + | b <= i_b + | i_a <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + |""".stripMargin val top = CircuitTarget("Top").module("Top") val inlined = top.instOf("i", "Inline") val nestedInlined = top.instOf("i", "Inline").instOf("foo", "NestedInline") val nestedNotInlined = top.instOf("i", "Inline").instOf("bar", "NestedNoInline") - executeWithAnnos(input, check, - Seq( - inline("Inline"), - inline("NestedInline"), - NoCircuitDedupAnnotation, - DummyAnno(inlined.ref("a")), - DummyAnno(inlined.ref("b")), - DummyAnno(nestedInlined.ref("a")), - DummyAnno(nestedInlined.ref("b")), - DummyAnno(nestedNotInlined.ref("a")), - DummyAnno(nestedNotInlined.ref("b")) - ), - Seq( - DummyAnno(top.ref("i_a")), - DummyAnno(top.ref("i_b")), - DummyAnno(top.ref("i_foo_a")), - DummyAnno(top.ref("i_foo_b")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")) - ) - ) + executeWithAnnos( + input, + check, + Seq( + inline("Inline"), + inline("NestedInline"), + NoCircuitDedupAnnotation, + DummyAnno(inlined.ref("a")), + DummyAnno(inlined.ref("b")), + DummyAnno(nestedInlined.ref("a")), + DummyAnno(nestedInlined.ref("b")), + DummyAnno(nestedNotInlined.ref("a")), + DummyAnno(nestedNotInlined.ref("b")) + ), + Seq( + DummyAnno(top.ref("i_a")), + DummyAnno(top.ref("i_b")), + DummyAnno(top.ref("i_foo_a")), + DummyAnno(top.ref("i_foo_b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")) + ) + ) } "inlining both grandparent and grandchild" should "should work" in { - val input = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | inst i of Inline - | i.a <= a - | b <= i.b - | module Inline : - | input a : UInt<32> - | output b : UInt<32> - | inst foo of NestedInline - | inst bar of NestedNoInline - | foo.a <= a - | bar.a <= foo.b - | b <= bar.b - | module NestedInline : - | input a : UInt<32> - | output b : UInt<32> - | b <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | inst foo of NestedInline - | foo.a <= a - | b <= foo.b - |""".stripMargin - val check = - """circuit Top : - | module Top : - | input a : UInt<32> - | output b : UInt<32> - | wire i_a : UInt<32> - | wire i_b : UInt<32> - | wire i_foo_a : UInt<32> - | wire i_foo_b : UInt<32> - | i_foo_b <= i_foo_a - | inst i_bar of NestedNoInline - | i_b <= i_bar.b - | i_foo_a <= i_a - | i_bar.a <= i_foo_b - | b <= i_b - | i_a <= a - | module NestedNoInline : - | input a : UInt<32> - | output b : UInt<32> - | wire foo_a : UInt<32> - | wire foo_b : UInt<32> - | foo_b <= foo_a - | b <= foo_b - | foo_a <= a - |""".stripMargin + val input = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | inst i of Inline + | i.a <= a + | b <= i.b + | module Inline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | inst bar of NestedNoInline + | foo.a <= a + | bar.a <= foo.b + | b <= bar.b + | module NestedInline : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of NestedInline + | foo.a <= a + | b <= foo.b + |""".stripMargin + val check = + """circuit Top : + | module Top : + | input a : UInt<32> + | output b : UInt<32> + | wire i_a : UInt<32> + | wire i_b : UInt<32> + | wire i_foo_a : UInt<32> + | wire i_foo_b : UInt<32> + | i_foo_b <= i_foo_a + | inst i_bar of NestedNoInline + | i_b <= i_bar.b + | i_foo_a <= i_a + | i_bar.a <= i_foo_b + | b <= i_b + | i_a <= a + | module NestedNoInline : + | input a : UInt<32> + | output b : UInt<32> + | wire foo_a : UInt<32> + | wire foo_b : UInt<32> + | foo_b <= foo_a + | b <= foo_b + | foo_a <= a + |""".stripMargin val top = CircuitTarget("Top").module("Top") val inlined = top.instOf("i", "Inline") val nestedInlined = inlined.instOf("foo", "NestedInline") val nestedNotInlined = inlined.instOf("bar", "NestedNoInline") val innerNestedInlined = nestedNotInlined.instOf("foo", "NestedInline") - executeWithAnnos(input, check, - Seq( - inline("Inline"), - inline("NestedInline"), - DummyAnno(inlined.ref("a")), - DummyAnno(inlined.ref("b")), - DummyAnno(nestedInlined.ref("a")), - DummyAnno(nestedInlined.ref("b")), - DummyAnno(nestedNotInlined.ref("a")), - DummyAnno(nestedNotInlined.ref("b")), - DummyAnno(innerNestedInlined.ref("a")), - DummyAnno(innerNestedInlined.ref("b")) - ), - Seq( - DummyAnno(top.ref("i_a")), - DummyAnno(top.ref("i_b")), - DummyAnno(top.ref("i_foo_a")), - DummyAnno(top.ref("i_foo_b")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_a")), - DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_b")) - ) - ) + executeWithAnnos( + input, + check, + Seq( + inline("Inline"), + inline("NestedInline"), + DummyAnno(inlined.ref("a")), + DummyAnno(inlined.ref("b")), + DummyAnno(nestedInlined.ref("a")), + DummyAnno(nestedInlined.ref("b")), + DummyAnno(nestedNotInlined.ref("a")), + DummyAnno(nestedNotInlined.ref("b")), + DummyAnno(innerNestedInlined.ref("a")), + DummyAnno(innerNestedInlined.ref("b")) + ), + Seq( + DummyAnno(top.ref("i_a")), + DummyAnno(top.ref("i_b")), + DummyAnno(top.ref("i_foo_a")), + DummyAnno(top.ref("i_foo_b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("b")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_a")), + DummyAnno(top.instOf("i_bar", "NestedNoInline").ref("foo_b")) + ) + ) } "InlineInstances" should "properly invalidate ResolveKinds" in { @@ -562,7 +566,7 @@ class InlineInstancesTests extends LowTransformSpec { val result = manager.execute(state) result shouldNot containTree { case WRef("i_a", _, PortKind, _) => true } - result should containTree { case WRef("i_a", _, WireKind, _) => true } + result should containTree { case WRef("i_a", _, WireKind, _) => true } } } diff --git a/src/test/scala/firrtlTests/IntegrationSpec.scala b/src/test/scala/firrtlTests/IntegrationSpec.scala index b399923f..ff21a90b 100644 --- a/src/test/scala/firrtlTests/IntegrationSpec.scala +++ b/src/test/scala/firrtlTests/IntegrationSpec.scala @@ -23,10 +23,11 @@ class GCDSplitEmissionExecutionTest extends FirrtlFlatSpec { val optionsManager = new ExecutionOptionsManager("GCDTesterSplitEmission") with HasFirrtlOptions { commonOptions = CommonOptions(topName = top, targetDirName = testDir.getPath) firrtlOptions = FirrtlExecutionOptions( - inputFileNameOverride = sourceFile.getPath, - compilerName = "verilog", - infoModeName = "ignore", - emitOneFilePerModule = true) + inputFileNameOverride = sourceFile.getPath, + compilerName = "verilog", + infoModeName = "ignore", + emitOneFilePerModule = true + ) } firrtl.Driver.execute(optionsManager) @@ -42,7 +43,7 @@ class GCDSplitEmissionExecutionTest extends FirrtlFlatSpec { // topFile will be compiled by Verilator command by default but we need to also include dutFile verilogToCpp(top, testDir, Seq(dutFile), harness) #&& - cppToExe(top, testDir) ! loggingProcessLogger + cppToExe(top, testDir) ! loggingProcessLogger assert(executeExpectingSuccess(top, testDir)) } } @@ -53,14 +54,14 @@ class ICacheCompilationTest extends CompilationTest("ICache", "/regress") class FPUCompilationTest extends CompilationTest("FPU", "/regress") class HwachaSequencerCompilationTest extends CompilationTest("HwachaSequencer", "/regress") -abstract class CommonSubexprEliminationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(firrtl.passes.CommonSubexpressionElimination), name, dir) -abstract class DeadCodeEliminationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(new firrtl.transforms.DeadCodeElimination), name, dir) -abstract class ConstantPropagationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(new firrtl.transforms.ConstantPropagation), name, dir) -abstract class LowFirrtlOptimizationEquivTest(name: String, dir: String) extends - EquivalenceTest(Seq(new LowFirrtlOptimization), name, dir) +abstract class CommonSubexprEliminationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(firrtl.passes.CommonSubexpressionElimination), name, dir) +abstract class DeadCodeEliminationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(new firrtl.transforms.DeadCodeElimination), name, dir) +abstract class ConstantPropagationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(new firrtl.transforms.ConstantPropagation), name, dir) +abstract class LowFirrtlOptimizationEquivTest(name: String, dir: String) + extends EquivalenceTest(Seq(new LowFirrtlOptimization), name, dir) class OpsCommonSubexprEliminationTest extends CommonSubexprEliminationEquivTest("Ops", "/regress") class OpsDeadCodeEliminationTest extends DeadCodeEliminationEquivTest("Ops", "/regress") diff --git a/src/test/scala/firrtlTests/LegalizeSpec.scala b/src/test/scala/firrtlTests/LegalizeSpec.scala index 22fef730..aa6458ba 100644 --- a/src/test/scala/firrtlTests/LegalizeSpec.scala +++ b/src/test/scala/firrtlTests/LegalizeSpec.scala @@ -5,4 +5,3 @@ package firrtlTests import firrtl.testutils.ExecutionTest class LegalizeExecutionTest extends ExecutionTest("Legalize", "/passes/Legalize") - diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 648c6b36..0d020252 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -21,14 +21,14 @@ class LowerTypesSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String]) = { val fir = Parser.parse(input.split("\n").toIterator) val c = compiler.runTransform(CircuitState(fir, Seq())).circuit - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } - behavior of "Lower Types" + behavior.of("Lower Types") it should "lower ports" in { val input = @@ -39,9 +39,23 @@ class LowerTypesSpec extends FirrtlFlatSpec { | input y : UInt<1>[4] | input z : { c : { d : UInt<1>, e : UInt<1>}, f : UInt<1>[2] }[2] """.stripMargin - val expected = Seq("w", "x_a", "x_b", "y_0", "y_1", "y_2", "y_3", "z_0_c_d", - "z_0_c_e", "z_0_f_0", "z_0_f_1", "z_1_c_d", "z_1_c_e", "z_1_f_0", - "z_1_f_1") map (x => s"input $x : UInt<1>") map normalized + val expected = Seq( + "w", + "x_a", + "x_b", + "y_0", + "y_1", + "y_2", + "y_3", + "z_0_c_d", + "z_0_c_e", + "z_0_f_0", + "z_0_f_1", + "z_1_c_d", + "z_1_c_e", + "z_1_f_0", + "z_1_f_1" + ).map(x => s"input $x : UInt<1>").map(normalized) executeTest(input, expected) } @@ -56,7 +70,7 @@ class LowerTypesSpec extends FirrtlFlatSpec { val expected = Seq( "output foo_0_a : UInt<1>", "input foo_0_b : UInt<1>" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -72,29 +86,47 @@ class LowerTypesSpec extends FirrtlFlatSpec { | reg y : UInt<1>[4], clock | reg z : { c : { d : UInt<1>, e : UInt<1>}, f : UInt<1>[2] }[2], clock """.stripMargin - val expected = Seq("w", "x_a", "x_b", "y_0", "y_1", "y_2", "y_3", "z_0_c_d", - "z_0_c_e", "z_0_f_0", "z_0_f_1", "z_1_c_d", "z_1_c_e", "z_1_f_0", - "z_1_f_1") map (x => s"reg $x : UInt<1>, clock with :") map normalized + val expected = Seq( + "w", + "x_a", + "x_b", + "y_0", + "y_1", + "y_2", + "y_3", + "z_0_c_d", + "z_0_c_e", + "z_0_f_0", + "z_0_f_1", + "z_1_c_d", + "z_1_c_e", + "z_1_f_0", + "z_1_f_1" + ).map(x => s"reg $x : UInt<1>, clock with :").map(normalized) executeTest(input, expected) } it should "lower registers with aggregate initialization" in { val input = - """circuit Test : - | module Test : - | input clock : Clock - | input reset : UInt<1> - | input init : { a : UInt<1>, b : UInt<1>}[2] - | reg x : { a : UInt<1>, b : UInt<1>}[2], clock with : - | reset => (reset, init) + """circuit Test : + | module Test : + | input clock : Clock + | input reset : UInt<1> + | input init : { a : UInt<1>, b : UInt<1>}[2] + | reg x : { a : UInt<1>, b : UInt<1>}[2], clock with : + | reset => (reset, init) """.stripMargin val expected = Seq( - "reg x_0_a : UInt<1>, clock with :", "reset => (reset, init_0_a)", - "reg x_0_b : UInt<1>, clock with :", "reset => (reset, init_0_b)", - "reg x_1_a : UInt<1>, clock with :", "reset => (reset, init_1_a)", - "reg x_1_b : UInt<1>, clock with :", "reset => (reset, init_1_b)" - ) map normalized + "reg x_0_a : UInt<1>, clock with :", + "reset => (reset, init_0_a)", + "reg x_0_b : UInt<1>, clock with :", + "reset => (reset, init_0_b)", + "reg x_1_a : UInt<1>, clock with :", + "reset => (reset, init_1_a)", + "reg x_1_b : UInt<1>, clock with :", + "reset => (reset, init_1_b)" + ).map(normalized) executeTest(input, expected) } @@ -112,77 +144,87 @@ class LowerTypesSpec extends FirrtlFlatSpec { val expected = Seq( "reg foo : UInt<4>, clock_1 with :", "reset => (reset_a, init_3_b_1_d)" - ) map normalized + ).map(normalized) executeTest(input, expected) } it should "lower DefInstances (but not too far!)" in { val input = - """circuit Test : - | module Other : - | input a : { b : UInt<1>, c : UInt<1>} - | output d : UInt<1>[2] - | d[0] <= a.b - | d[1] <= a.c - | module Test : - | input x : UInt<1> - | inst mod of Other - | mod.a.b <= x - | mod.a.c <= x - | node y = mod.d[0] + """circuit Test : + | module Other : + | input a : { b : UInt<1>, c : UInt<1>} + | output d : UInt<1>[2] + | d[0] <= a.b + | d[1] <= a.c + | module Test : + | input x : UInt<1> + | inst mod of Other + | mod.a.b <= x + | mod.a.c <= x + | node y = mod.d[0] """.stripMargin - val expected = Seq( - "mod.a_b <= x", - "mod.a_c <= x", - "node y = mod.d_0") map normalized + val expected = Seq("mod.a_b <= x", "mod.a_c <= x", "node y = mod.d_0").map(normalized) executeTest(input, expected) } it should "lower aggregate memories" in { val input = - """circuit Test : - | module Test : - | input clock : Clock - | mem m : - | data-type => { a : UInt<8>, b : UInt<8>}[2] - | depth => 32 - | read-latency => 0 - | write-latency => 1 - | reader => read - | writer => write - | m.read.clk <= clock - | m.read.en <= UInt<1>(1) - | m.read.addr is invalid - | node x = m.read.data - | node y = m.read.data[0].b - | - | m.write.clk <= clock - | m.write.en <= UInt<1>(0) - | m.write.mask is invalid - | m.write.addr is invalid - | wire w : { a : UInt<8>, b : UInt<8>}[2] - | w[0].a <= UInt<4>(2) - | w[0].b <= UInt<4>(3) - | w[1].a <= UInt<4>(4) - | w[1].b <= UInt<4>(5) - | m.write.data <= w + """circuit Test : + | module Test : + | input clock : Clock + | mem m : + | data-type => { a : UInt<8>, b : UInt<8>}[2] + | depth => 32 + | read-latency => 0 + | write-latency => 1 + | reader => read + | writer => write + | m.read.clk <= clock + | m.read.en <= UInt<1>(1) + | m.read.addr is invalid + | node x = m.read.data + | node y = m.read.data[0].b + | + | m.write.clk <= clock + | m.write.en <= UInt<1>(0) + | m.write.mask is invalid + | m.write.addr is invalid + | wire w : { a : UInt<8>, b : UInt<8>}[2] + | w[0].a <= UInt<4>(2) + | w[0].b <= UInt<4>(3) + | w[1].a <= UInt<4>(4) + | w[1].b <= UInt<4>(5) + | m.write.data <= w """.stripMargin val expected = Seq( - "mem m_0_a :", "mem m_0_b :", "mem m_1_a :", "mem m_1_b :", - "m_0_a.read.clk <= clock", "m_0_b.read.clk <= clock", - "m_1_a.read.clk <= clock", "m_1_b.read.clk <= clock", - "m_0_a.read.addr is invalid", "m_0_b.read.addr is invalid", - "m_1_a.read.addr is invalid", "m_1_b.read.addr is invalid", - "node x_0_a = m_0_a.read.data", "node x_0_b = m_0_b.read.data", - "node x_1_a = m_1_a.read.data", "node x_1_b = m_1_b.read.data", - "m_0_a.write.mask is invalid", "m_0_b.write.mask is invalid", - "m_1_a.write.mask is invalid", "m_1_b.write.mask is invalid", - "m_0_a.write.data <= w_0_a", "m_0_b.write.data <= w_0_b", - "m_1_a.write.data <= w_1_a", "m_1_b.write.data <= w_1_b" - ) map normalized + "mem m_0_a :", + "mem m_0_b :", + "mem m_1_a :", + "mem m_1_b :", + "m_0_a.read.clk <= clock", + "m_0_b.read.clk <= clock", + "m_1_a.read.clk <= clock", + "m_1_b.read.clk <= clock", + "m_0_a.read.addr is invalid", + "m_0_b.read.addr is invalid", + "m_1_a.read.addr is invalid", + "m_1_b.read.addr is invalid", + "node x_0_a = m_0_a.read.data", + "node x_0_b = m_0_b.read.data", + "node x_1_a = m_1_a.read.data", + "node x_1_b = m_1_b.read.data", + "m_0_a.write.mask is invalid", + "m_0_b.write.mask is invalid", + "m_1_a.write.mask is invalid", + "m_1_b.write.mask is invalid", + "m_0_a.write.data <= w_0_a", + "m_0_b.write.data <= w_0_b", + "m_1_a.write.data <= w_1_a", + "m_1_b.write.data <= w_1_b" + ).map(normalized) executeTest(input, expected) } @@ -192,12 +234,17 @@ class LowerTypesSpec extends FirrtlFlatSpec { class LowerTypesUniquifySpec extends FirrtlFlatSpec { private val compiler = new TransformManager(Seq(Dependency(firrtl.passes.LowerTypes))) - private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Seq.empty, Seq.empty) - private def executeTest(input: String, expected: Seq[String], - inputAnnos: Seq[Annotation], expectedAnnos: Seq[Annotation]): Unit = { + private def executeTest(input: String, expected: Seq[String]): Unit = + executeTest(input, expected, Seq.empty, Seq.empty) + private def executeTest( + input: String, + expected: Seq[String], + inputAnnos: Seq[Annotation], + expectedAnnos: Seq[Annotation] + ): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) val result = compiler.runTransform(CircuitState(circuit, inputAnnos)) - val lines = result.circuit.serialize.split("\n") map normalized + val lines = result.circuit.serialize.split("\n").map(normalized) expected.map(normalized).foreach { e => assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}") @@ -206,7 +253,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { result.annotations.toSeq should equal(expectedAnnos) } - behavior of "LowerTypes" + behavior.of("LowerTypes") it should "rename colliding ports" in { val input = @@ -221,17 +268,16 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { "input a___0_c__0_d : UInt<2>", "output a___0_c__0_e : UInt<3>", "output a_0_c_ : UInt<5>", - "output a__0 : UInt<6>") + "output a__0 : UInt<6>" + ) val m = CircuitTarget("Test").module("Test") val inputAnnos = Seq( DontTouchAnnotation(m.ref("a").index(0).field("b")), - DontTouchAnnotation(m.ref("a").index(0).field("c").index(0).field("e"))) - - val expectedAnnos = Seq( - DontTouchAnnotation(m.ref("a___0_b")), - DontTouchAnnotation(m.ref("a___0_c__0_e"))) + DontTouchAnnotation(m.ref("a").index(0).field("c").index(0).field("e")) + ) + val expectedAnnos = Seq(DontTouchAnnotation(m.ref("a___0_b")), DontTouchAnnotation(m.ref("a___0_c__0_e"))) executeTest(input, expected, inputAnnos, expectedAnnos) } @@ -250,7 +296,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { "reg a___1_c__1_e : UInt<3>, clock with :", "reg a___0_c_1_e : UInt<4>, clock with :", "reg a_0_c_ : UInt<5>, clock with :", - "reg a__0 : UInt<6>, clock with :") + "reg a__0 : UInt<6>, clock with :" + ) executeTest(input, expected) } @@ -274,7 +321,6 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { executeTest(input, expected) } - it should "rename DefRegister expressions: clock, reset, and init" in { val input = """circuit Test : @@ -368,9 +414,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | node foo = data.a | node bar = data.b[1] """.stripMargin - val expected = Seq( - "node foo = data___a", - "node bar = data___b_1") + val expected = Seq("node foo = data___a", "node bar = data___b_1") executeTest(input, expected) } @@ -439,7 +483,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { "mem mem__0_b_0 :", "node mem_0_b_0 = mem__0_b_0.read.data", "node mem_0_b_1 = mem__0_b_1.read.data", - "mem__0_b_0.read.addr is invalid") + "mem__0_b_0.read.addr is invalid" + ) executeTest(input, expected) } @@ -467,12 +512,8 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | mem.write.en <= UInt(0) | mem.write.clk <= clock """.stripMargin - val expected = Seq( - "mem mem_a :", - "mem mem_b__0 :", - "mem mem_b__1 :", - "mem mem_b_0 :", - "node x = mem_b__0.read.data") + val expected = + Seq("mem mem_a :", "mem mem_b__0 :", "mem mem_b__1 :", "mem mem_b_0 :", "node x = mem_b__0.read.data") executeTest(input, expected) } @@ -492,11 +533,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | mod.a.c <= x | node mod_a_b = mod.a_b """.stripMargin - val expected = Seq( - "inst mod_ of Other", - "mod_.a__b <= x", - "mod_.a__c <= x", - "node mod_a_b = mod_.a_b") + val expected = Seq("inst mod_ of Other", "mod_.a__b <= x", "mod_.a__c <= x", "node mod_a_b = mod_.a_b") executeTest(input, expected) } @@ -515,7 +552,7 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { // Run the "quick" test three times and choose the longest time as the basis. val nCalibrationRuns = 3 def mkType(i: Int): String = { - if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + if (i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" } val timesMs = ( for (depth <- (List.fill(nCalibrationRuns)(1) :+ depth)) yield { @@ -528,12 +565,11 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { val (ms, _) = Utils.time(compileToVerilog(input)) ms } - ).toArray + ).toArray // The baseMs will be the maximum of the first calibration runs val baseMs = timesMs.slice(0, nCalibrationRuns - 1).max val renameMs = timesMs(nCalibrationRuns) if (TestOptions.accurateTiming) - renameMs shouldBe < (baseMs * threshold) + renameMs shouldBe <(baseMs * threshold) } } - diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala index f0f2042e..46416619 100644 --- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala +++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala @@ -21,115 +21,127 @@ object Transforms { } import firrtl.{ChirrtlForm => C, HighForm => H, MidForm => M, LowForm => L, UnknownForm => U} class ChirrtlToChirrtl extends IdentityTransformDiff(C, C) - class HighToChirrtl extends IdentityTransformDiff(H, C) - class HighToHigh extends IdentityTransformDiff(H, H) - class MidToMid extends IdentityTransformDiff(M, M) - class MidToChirrtl extends IdentityTransformDiff(M, C) - class MidToHigh extends IdentityTransformDiff(M, H) - class LowToChirrtl extends IdentityTransformDiff(L, C) - class LowToHigh extends IdentityTransformDiff(L, H) - class LowToMid extends IdentityTransformDiff(L, M) - class LowToLow extends IdentityTransformDiff(L, L) + class HighToChirrtl extends IdentityTransformDiff(H, C) + class HighToHigh extends IdentityTransformDiff(H, H) + class MidToMid extends IdentityTransformDiff(M, M) + class MidToChirrtl extends IdentityTransformDiff(M, C) + class MidToHigh extends IdentityTransformDiff(M, H) + class LowToChirrtl extends IdentityTransformDiff(L, C) + class LowToHigh extends IdentityTransformDiff(L, H) + class LowToMid extends IdentityTransformDiff(L, M) + class LowToLow extends IdentityTransformDiff(L, L) } class LoweringCompilersSpec extends AnyFlatSpec with Matchers { def legacyTransforms(a: CoreTransform): Seq[Transform] = a match { - case _: ChirrtlToHighFirrtl => Seq( - new firrtl.stage.transforms.CheckScalaVersion, - firrtl.passes.CheckChirrtl, - firrtl.passes.CInferTypes, - firrtl.passes.CInferMDir, - firrtl.passes.RemoveCHIRRTL) + case _: ChirrtlToHighFirrtl => + Seq( + new firrtl.stage.transforms.CheckScalaVersion, + firrtl.passes.CheckChirrtl, + firrtl.passes.CInferTypes, + firrtl.passes.CInferMDir, + firrtl.passes.RemoveCHIRRTL + ) case _: IRToWorkingIR => Seq(firrtl.passes.ToWorkingIR) - case _: ResolveAndCheck => Seq( - firrtl.passes.CheckHighForm, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.CheckTypes, - firrtl.passes.Uniquify, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.ResolveFlows, - firrtl.passes.CheckFlows, - new firrtl.passes.InferBinaryPoints, - new firrtl.passes.TrimIntervals, - new firrtl.passes.InferWidths, - firrtl.passes.CheckWidths, - new firrtl.transforms.InferResets) - case _: HighFirrtlToMiddleFirrtl => Seq( - firrtl.passes.PullMuxes, - firrtl.passes.ReplaceAccesses, - firrtl.passes.ExpandConnects, - firrtl.passes.ZeroLengthVecs, - firrtl.passes.RemoveAccesses, - firrtl.passes.Uniquify, - firrtl.passes.ExpandWhens, - firrtl.passes.CheckInitialization, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.CheckTypes, - firrtl.passes.ResolveFlows, - new firrtl.passes.InferWidths, - firrtl.passes.CheckWidths, - new firrtl.passes.RemoveIntervals, - firrtl.passes.ConvertFixedToSInt, - firrtl.passes.ZeroWidth, - firrtl.passes.InferTypes) - case _: MiddleFirrtlToLowFirrtl => Seq( - firrtl.passes.LowerTypes, - firrtl.passes.ResolveKinds, - firrtl.passes.InferTypes, - firrtl.passes.ResolveFlows, - new firrtl.passes.InferWidths, - firrtl.passes.Legalize, - firrtl.transforms.RemoveReset, - firrtl.passes.ResolveFlows, - new firrtl.transforms.CheckCombLoops, - new checks.CheckResets, - new firrtl.transforms.RemoveWires) - case _: LowFirrtlOptimization => Seq( - firrtl.passes.RemoveValidIf, - new firrtl.transforms.ConstantPropagation, - firrtl.passes.PadWidths, - new firrtl.transforms.ConstantPropagation, - firrtl.passes.Legalize, - firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - new firrtl.transforms.ConstantPropagation, - firrtl.passes.SplitExpressions, - new firrtl.transforms.CombineCats, - firrtl.passes.CommonSubexpressionElimination, - new firrtl.transforms.DeadCodeElimination) - case _: MinimumLowFirrtlOptimization => Seq( - firrtl.passes.RemoveValidIf, - firrtl.passes.PadWidths, - firrtl.passes.Legalize, - firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter - firrtl.passes.SplitExpressions) + case _: ResolveAndCheck => + Seq( + firrtl.passes.CheckHighForm, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.CheckTypes, + firrtl.passes.Uniquify, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.ResolveFlows, + firrtl.passes.CheckFlows, + new firrtl.passes.InferBinaryPoints, + new firrtl.passes.TrimIntervals, + new firrtl.passes.InferWidths, + firrtl.passes.CheckWidths, + new firrtl.transforms.InferResets + ) + case _: HighFirrtlToMiddleFirrtl => + Seq( + firrtl.passes.PullMuxes, + firrtl.passes.ReplaceAccesses, + firrtl.passes.ExpandConnects, + firrtl.passes.ZeroLengthVecs, + firrtl.passes.RemoveAccesses, + firrtl.passes.Uniquify, + firrtl.passes.ExpandWhens, + firrtl.passes.CheckInitialization, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.CheckTypes, + firrtl.passes.ResolveFlows, + new firrtl.passes.InferWidths, + firrtl.passes.CheckWidths, + new firrtl.passes.RemoveIntervals, + firrtl.passes.ConvertFixedToSInt, + firrtl.passes.ZeroWidth, + firrtl.passes.InferTypes + ) + case _: MiddleFirrtlToLowFirrtl => + Seq( + firrtl.passes.LowerTypes, + firrtl.passes.ResolveKinds, + firrtl.passes.InferTypes, + firrtl.passes.ResolveFlows, + new firrtl.passes.InferWidths, + firrtl.passes.Legalize, + firrtl.transforms.RemoveReset, + firrtl.passes.ResolveFlows, + new firrtl.transforms.CheckCombLoops, + new checks.CheckResets, + new firrtl.transforms.RemoveWires + ) + case _: LowFirrtlOptimization => + Seq( + firrtl.passes.RemoveValidIf, + new firrtl.transforms.ConstantPropagation, + firrtl.passes.PadWidths, + new firrtl.transforms.ConstantPropagation, + firrtl.passes.Legalize, + firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + new firrtl.transforms.ConstantPropagation, + firrtl.passes.SplitExpressions, + new firrtl.transforms.CombineCats, + firrtl.passes.CommonSubexpressionElimination, + new firrtl.transforms.DeadCodeElimination + ) + case _: MinimumLowFirrtlOptimization => + Seq( + firrtl.passes.RemoveValidIf, + firrtl.passes.PadWidths, + firrtl.passes.Legalize, + firrtl.passes.memlib.VerilogMemDelays, // TODO move to Verilog emitter + firrtl.passes.SplitExpressions + ) } def compare(a: Seq[Transform], b: TransformManager, patches: Seq[PatchAction] = Seq.empty): Unit = { info(s"""Transform Order:\n${b.prettyPrint(" ")}""") val m = new scala.collection.mutable.HashMap[Int, Seq[Dependency[Transform]]].withDefault(_ => Seq.empty) - a.map(Dependency.fromTransform).zipWithIndex.foreach{ case (t, idx) => m(idx) = Seq(t) } + a.map(Dependency.fromTransform).zipWithIndex.foreach { case (t, idx) => m(idx) = Seq(t) } patches.foreach { case Add(line, txs) => m(line - 1) = m(line - 1) ++ txs case Del(line) => m.remove(line - 1) } - val patched = scala.collection.immutable.TreeMap(m.toArray:_*).values.flatten + val patched = scala.collection.immutable.TreeMap(m.toArray: _*).values.flatten patched .zip(b.flattenedTransformOrder.map(Dependency.fromTransform)) - .foreach{ case (aa, bb) => bb should be (aa) } + .foreach { case (aa, bb) => bb should be(aa) } info(s"found ${b.flattenedTransformOrder.size} transforms") - patched.size should be (b.flattenedTransformOrder.size) + patched.size should be(b.flattenedTransformOrder.size) } - behavior of "ChirrtlToHighFirrtl" + behavior.of("ChirrtlToHighFirrtl") it should "replicate the old order" in { val tm = new TransformManager(Forms.MinimalHighForm, Forms.ChirrtlForm) @@ -139,26 +151,28 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { compare(legacyTransforms(new firrtl.ChirrtlToHighFirrtl), tm, patches) } - behavior of "IRToWorkingIR" + behavior.of("IRToWorkingIR") it should "replicate the old order" in { val tm = new TransformManager(Forms.WorkingIR, Forms.MinimalHighForm) compare(legacyTransforms(new firrtl.IRToWorkingIR), tm) } - behavior of "ResolveAndCheck" + behavior.of("ResolveAndCheck") it should "replicate the old order" in { val tm = new TransformManager(Forms.Resolved, Forms.WorkingIR) val patches = Seq( // Uniquify is now part of [[firrtl.passes.LowerTypes]] - Del(5), Del(6), Del(7), + Del(5), + Del(6), + Del(7), Add(14, Seq(Dependency.fromTransform(firrtl.passes.CheckTypes))) ) compare(legacyTransforms(new ResolveAndCheck), tm, patches) } - behavior of "HighFirrtlToMiddleFirrtl" + behavior.of("HighFirrtlToMiddleFirrtl") it should "replicate the old order" in { val tm = new TransformManager(Forms.MidForm, Forms.Deduped) @@ -174,56 +188,54 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { Del(11), Del(12), Del(13), - Add(12, Seq(Dependency(firrtl.passes.ResolveFlows), - Dependency[firrtl.passes.InferWidths])), + Add(12, Seq(Dependency(firrtl.passes.ResolveFlows), Dependency[firrtl.passes.InferWidths])), Del(14), - Add(15, Seq(Dependency(firrtl.passes.ResolveKinds), - Dependency(firrtl.passes.InferTypes))), + Add(15, Seq(Dependency(firrtl.passes.ResolveKinds), Dependency(firrtl.passes.InferTypes))), // TODO Add(17, Seq(Dependency[firrtl.transforms.formal.AssertSubmoduleAssumptions])) ) compare(legacyTransforms(new HighFirrtlToMiddleFirrtl), tm, patches) } - behavior of "MiddleFirrtlToLowFirrtl" + behavior.of("MiddleFirrtlToLowFirrtl") it should "replicate the old order" in { val tm = new TransformManager(Forms.LowForm, Forms.MidForm) val patches = Seq( // Uniquify is now part of [[firrtl.passes.LowerTypes]] - Del(2), Del(3), Del(5), + Del(2), + Del(3), + Del(5), // RemoveWires now visibly invalidates ResolveKinds Add(11, Seq(Dependency(firrtl.passes.ResolveKinds))) ) compare(legacyTransforms(new MiddleFirrtlToLowFirrtl), tm, patches) } - behavior of "MinimumLowFirrtlOptimization" + behavior.of("MinimumLowFirrtlOptimization") it should "replicate the old order" in { val tm = new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm) val patches = Seq( Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))), - Add(6, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], - Dependency(firrtl.passes.ResolveKinds))) + Add(6, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) ) compare(legacyTransforms(new MinimumLowFirrtlOptimization), tm, patches) } - behavior of "LowFirrtlOptimization" + behavior.of("LowFirrtlOptimization") it should "replicate the old order" in { val tm = new TransformManager(Forms.LowFormOptimized, Forms.LowForm) val patches = Seq( Add(6, Seq(Dependency(firrtl.passes.ResolveFlows))), Add(7, Seq(Dependency(firrtl.passes.Legalize))), - Add(8, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], - Dependency(firrtl.passes.ResolveKinds))) + Add(8, Seq(Dependency[firrtl.transforms.LegalizeAndReductionsTransform], Dependency(firrtl.passes.ResolveKinds))) ) compare(legacyTransforms(new LowFirrtlOptimization), tm, patches) } - behavior of "VerilogMinimumOptimized" + behavior.of("VerilogMinimumOptimized") it should "replicate the old order" in { val legacy = Seq( @@ -238,12 +250,13 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { firrtl.passes.VerilogModulusCleanup, new firrtl.transforms.VerilogRename, firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes) + new firrtl.AddDescriptionNodes + ) val tm = new TransformManager(Forms.VerilogMinimumOptimized, (new firrtl.VerilogEmitter).prerequisites) compare(legacy, tm) } - behavior of "VerilogOptimized" + behavior.of("VerilogOptimized") it should "replicate the old order" in { val legacy = Seq( @@ -259,12 +272,13 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers { firrtl.passes.VerilogModulusCleanup, new firrtl.transforms.VerilogRename, firrtl.passes.VerilogPrep, - new firrtl.AddDescriptionNodes) + new firrtl.AddDescriptionNodes + ) val tm = new TransformManager(Forms.VerilogOptimized, Forms.LowFormOptimized) compare(legacy, tm) } - behavior of "Legacy Custom Transforms" + behavior.of("Legacy Custom Transforms") it should "work for Chirrtl -> Chirrtl" in { val expected = new Transforms.ChirrtlToChirrtl :: new firrtl.ChirrtlEmitter :: Nil diff --git a/src/test/scala/firrtlTests/MemLatencySpec.scala b/src/test/scala/firrtlTests/MemLatencySpec.scala index 79986cc2..8a04eeef 100644 --- a/src/test/scala/firrtlTests/MemLatencySpec.scala +++ b/src/test/scala/firrtlTests/MemLatencySpec.scala @@ -6,8 +6,8 @@ object MemLatencySpec { case class Write(addr: Int, data: Int, mask: Option[Boolean] = None) case class Read(addr: Int, expectedValue: Int) case class MemAccess(w: Option[Write], r: Option[Read]) - def writeOnly(addr: Int, data: Int) = MemAccess(Some(Write(addr, data)), None) - def readOnly(addr: Int, expectedValue: Int) = MemAccess(None, Some(Read(addr, expectedValue))) + def writeOnly(addr: Int, data: Int) = MemAccess(Some(Write(addr, data)), None) + def readOnly(addr: Int, expectedValue: Int) = MemAccess(None, Some(Read(addr, expectedValue))) } abstract class MemLatencySpec(rLatency: Int, wLatency: Int, ruw: String) @@ -36,7 +36,7 @@ abstract class MemLatencySpec(rLatency: Int, wLatency: Int, ruw: String) def mask2Poke(m: Option[Boolean]) = m match { case Some(false) => Poke("m.w.mask", 0) - case _ => Poke("m.w.mask", 1) + case _ => Poke("m.w.mask", 1) } def wPokes = memAccesses.map { @@ -47,24 +47,25 @@ abstract class MemLatencySpec(rLatency: Int, wLatency: Int, ruw: String) def rPokes = memAccesses.map { case MemAccess(_, Some(Read(a, _))) => Seq(Poke("m.r.en", 1), Poke("m.r.addr", a)) - case _ => Seq(Poke("m.r.en", 0), Invalidate("m.r.addr")) + case _ => Seq(Poke("m.r.en", 0), Invalidate("m.r.addr")) } // Need to idle for <rLatency> cycles at the end val idle = Seq(Poke("m.w.en", 0), Poke("m.r.en", 0)) - def pokes = (wPokes zip rPokes).map { case (wp, rp) => wp ++ rp } ++ Seq.fill(rLatency)(idle) + def pokes = (wPokes.zip(rPokes)).map { case (wp, rp) => wp ++ rp } ++ Seq.fill(rLatency)(idle) // Need to delay read value expects by <rLatency> def expects = Seq.fill(rLatency)(Seq(Step(1))) ++ memAccesses.map { case MemAccess(_, Some(Read(_, expected))) => Seq(Expect("m.r.data", expected), Step(1)) - case _ => Seq(Step(1)) + case _ => Seq(Step(1)) } - def commands: Seq[SimpleTestCommand] = (pokes zip expects).flatMap { case (p, e) => p ++ e } + def commands: Seq[SimpleTestCommand] = (pokes.zip(expects)).flatMap { case (p, e) => p ++ e } } trait ToggleMaskAndEnable { import MemLatencySpec._ + /** * A canonical sequence of memory accesses for sanity checking memories of different latencies. * The shortest true "RAW" hazard is reading address 14 two accesses after writing it. Since this @@ -76,19 +77,19 @@ trait ToggleMaskAndEnable { * @note Write-first mems should return expected values for (write-latency <= read-latency + 2) */ val memAccesses: Seq[MemAccess] = Seq( - MemAccess(Some(Write(6, 32)), None), - MemAccess(Some(Write(14, 87)), None), - MemAccess(None, None), - MemAccess(Some(Write(19, 63)), Some(Read(14, 87))), - MemAccess(Some(Write(22, 49)), None), - MemAccess(Some(Write(11, 99)), Some(Read(6, 32))), - MemAccess(Some(Write(42, 42)), None), - MemAccess(Some(Write(77, 81)), None), - MemAccess(Some(Write(6, 7)), Some(Read(19, 63))), - MemAccess(Some(Write(39, 5)), Some(Read(42, 42))), + MemAccess(Some(Write(6, 32)), None), + MemAccess(Some(Write(14, 87)), None), + MemAccess(None, None), + MemAccess(Some(Write(19, 63)), Some(Read(14, 87))), + MemAccess(Some(Write(22, 49)), None), + MemAccess(Some(Write(11, 99)), Some(Read(6, 32))), + MemAccess(Some(Write(42, 42)), None), + MemAccess(Some(Write(77, 81)), None), + MemAccess(Some(Write(6, 7)), Some(Read(19, 63))), + MemAccess(Some(Write(39, 5)), Some(Read(42, 42))), MemAccess(Some(Write(39, 6, Some(false))), Some(Read(77, 81))), // set mask to zero, should not write - MemAccess(None, Some(Read(6, 7))), // also read a twice-written address - MemAccess(None, Some(Read(39, 5))) // ensure masked writes didn't happen + MemAccess(None, Some(Read(6, 7))), // also read a twice-written address + MemAccess(None, Some(Read(39, 5))) // ensure masked writes didn't happen ) } @@ -111,20 +112,34 @@ class WriteFirstMemToggleSpec extends MemLatencySpec(rLatency = 1, wLatency = 1, class ReadFirstMemToggleSpec extends MemLatencySpec(rLatency = 1, wLatency = 1, ruw = "old") with ToggleMaskAndEnable // Read latency 2 -class WriteFirstMemToggleSpecRL2 extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "new") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL2 + extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "new") + with ToggleMaskAndEnable class ReadFirstMemToggleSpecRL2 extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "old") with ToggleMaskAndEnable // Write latency 2 -class WriteFirstMemToggleSpecWL2 extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "new") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecWL2 + extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "new") + with ToggleMaskAndEnable class ReadFirstMemToggleSpecWL2 extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "old") with ToggleMaskAndEnable // Read latency 2, write latency 2 -class WriteFirstMemToggleSpecRL2WL2 extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "new") with ToggleMaskAndEnable -class ReadFirstMemToggleSpecRL2WL2 extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "old") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL2WL2 + extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "new") + with ToggleMaskAndEnable +class ReadFirstMemToggleSpecRL2WL2 + extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "old") + with ToggleMaskAndEnable // Read latency 3, write latency 2 -class WriteFirstMemToggleSpecRL3WL2 extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "new") with ToggleMaskAndEnable -class ReadFirstMemToggleSpecRL3WL2 extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "old") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL3WL2 + extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "new") + with ToggleMaskAndEnable +class ReadFirstMemToggleSpecRL3WL2 + extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "old") + with ToggleMaskAndEnable // Read latency 2, write latency 4 -> ToggleSpec pattern only valid for write-first at this combo -class WriteFirstMemToggleSpecRL2WL4 extends MemLatencySpec(rLatency = 2, wLatency = 4, ruw = "new") with ToggleMaskAndEnable +class WriteFirstMemToggleSpecRL2WL4 + extends MemLatencySpec(rLatency = 2, wLatency = 4, ruw = "new") + with ToggleMaskAndEnable diff --git a/src/test/scala/firrtlTests/MemSpec.scala b/src/test/scala/firrtlTests/MemSpec.scala index c7ab8db7..e05aca86 100644 --- a/src/test/scala/firrtlTests/MemSpec.scala +++ b/src/test/scala/firrtlTests/MemSpec.scala @@ -50,7 +50,7 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { """.stripMargin val result = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm, List.empty)) // TODO Not great that it includes the sparse comment for VCS - result should containLine (s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize-1}];") + result should containLine(s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize - 1}];") } property("Very large CHIRRTL memories should be supported") { @@ -76,7 +76,6 @@ class MemSpec extends FirrtlPropSpec with FirrtlMatchers { """.stripMargin val result = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm, List.empty)) // TODO Not great that it includes the sparse comment for VCS - result should containLine (s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize-1}];") + result should containLine(s"reg /* sparse */ [7:0] m [0:$addrWidth'd${memSize - 1}];") } } - diff --git a/src/test/scala/firrtlTests/MemoryInitSpec.scala b/src/test/scala/firrtlTests/MemoryInitSpec.scala index 5598e58b..984bf0b4 100644 --- a/src/test/scala/firrtlTests/MemoryInitSpec.scala +++ b/src/test/scala/firrtlTests/MemoryInitSpec.scala @@ -11,37 +11,37 @@ import firrtlTests.execution._ class MemInitSpec extends FirrtlFlatSpec { def input(tpe: String): String = s""" - |circuit MemTest: - | module MemTest: - | input clock : Clock - | input rAddr : UInt<5> - | input rEnable : UInt<1> - | input wAddr : UInt<5> - | input wData : $tpe - | input wEnable : UInt<1> - | output rData : $tpe - | - | mem m: - | data-type => $tpe - | depth => 32 - | reader => r - | writer => w - | read-latency => 1 - | write-latency => 1 - | read-under-write => new - | - | m.r.clk <= clock - | m.r.addr <= rAddr - | m.r.en <= rEnable - | rData <= m.r.data - | - | m.w.clk <= clock - | m.w.addr <= wAddr - | m.w.en <= wEnable - | m.w.data <= wData - | m.w.mask is invalid - | - |""".stripMargin + |circuit MemTest: + | module MemTest: + | input clock : Clock + | input rAddr : UInt<5> + | input rEnable : UInt<1> + | input wAddr : UInt<5> + | input wData : $tpe + | input wEnable : UInt<1> + | output rData : $tpe + | + | mem m: + | data-type => $tpe + | depth => 32 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => new + | + | m.r.clk <= clock + | m.r.addr <= rAddr + | m.r.en <= rEnable + | rData <= m.r.data + | + | m.w.clk <= clock + | m.w.addr <= wAddr + | m.w.en <= wEnable + | m.w.data <= wData + | m.w.mask is invalid + | + |""".stripMargin val mRef = CircuitTarget("MemTest").module("MemTest").ref("m") def compile(annos: AnnotationSeq, tpe: String = "UInt<32>"): CircuitState = { @@ -51,13 +51,13 @@ class MemInitSpec extends FirrtlFlatSpec { "NoAnnotation" should "create a randomized initialization" in { val annos = Seq() val result = compile(annos) - result should containLine (" m[initvar] = _RAND_0[31:0];") + result should containLine(" m[initvar] = _RAND_0[31:0];") } "MemoryRandomInitAnnotation" should "create a randomized initialization" in { val annos = Seq(MemoryRandomInitAnnotation(mRef)) val result = compile(annos) - result should containLine (" m[initvar] = _RAND_0[31:0];") + result should containLine(" m[initvar] = _RAND_0[31:0];") } "MemoryScalarInitAnnotation w/ 0" should "create an initialization with all zeros" in { @@ -79,8 +79,9 @@ class MemInitSpec extends FirrtlFlatSpec { val values = Seq.tabulate(32)(ii => 2 * ii + 5).map(BigInt(_)) val annos = Seq(MemoryArrayInitAnnotation(mRef, values)) val result = compile(annos) - values.zipWithIndex.foreach { case (value, addr) => - result should containLine(s" m[$addr] = $value;") + values.zipWithIndex.foreach { + case (value, addr) => + result should containLine(s" m[$addr] = $value;") } } @@ -137,7 +138,9 @@ class MemInitSpec extends FirrtlFlatSpec { val annos = Seq(MemoryScalarInitAnnotation(mRef, 0)) compile(annos, "{real: SInt<10>, imag: SInt<10>}") } - assert(caught.getMessage.endsWith("Cannot initialize memory m of non ground type { real : SInt<10>, imag : SInt<10>}")) + assert( + caught.getMessage.endsWith("Cannot initialize memory m of non ground type { real : SInt<10>, imag : SInt<10>}") + ) } private def jsonAnno(name: String, suffix: String): String = @@ -165,39 +168,46 @@ class MemInitSpec extends FirrtlFlatSpec { } abstract class MemInitExecutionSpec(values: Seq[Int], init: ReferenceTarget => Annotation) - extends SimpleExecutionTest with VerilogExecution { + extends SimpleExecutionTest + with VerilogExecution { override val body: String = s""" - |mem m: - | data-type => UInt<32> - | depth => ${values.length} - | reader => r - | read-latency => 1 - | write-latency => 1 - | read-under-write => new - |m.r.clk <= clock - |m.r.en <= UInt<1>(1) - |""".stripMargin + |mem m: + | data-type => UInt<32> + | depth => ${values.length} + | reader => r + | read-latency => 1 + | write-latency => 1 + | read-under-write => new + |m.r.clk <= clock + |m.r.en <= UInt<1>(1) + |""".stripMargin val mRef = CircuitTarget("dut").module("dut").ref("m") override val customAnnotations: AnnotationSeq = Seq(init(mRef)) - override def commands: Seq[SimpleTestCommand] = (Seq(-1) ++ values).zipWithIndex.map { case (value, addr) => - if(value == -1) { Seq(Poke("m.r.addr", addr)) } - else if(addr >= values.length) { Seq(Expect("m.r.data", value)) } - else { Seq(Poke("m.r.addr", addr), Expect("m.r.data", value)) } + override def commands: Seq[SimpleTestCommand] = (Seq(-1) ++ values).zipWithIndex.map { + case (value, addr) => + if (value == -1) { Seq(Poke("m.r.addr", addr)) } + else if (addr >= values.length) { Seq(Expect("m.r.data", value)) } + else { Seq(Poke("m.r.addr", addr), Expect("m.r.data", value)) } }.flatMap(_ ++ Seq(Step(1))) } -class MemScalarInit0ExecutionSpec extends MemInitExecutionSpec( - Seq.tabulate(31)(_ => 0), r => MemoryScalarInitAnnotation(r, 0) -) {} - -class MemScalarInit17ExecutionSpec extends MemInitExecutionSpec( - Seq.tabulate(31)(_ => 17), r => MemoryScalarInitAnnotation(r, 17) -) {} - -class MemArrayInitExecutionSpec extends MemInitExecutionSpec( - Seq.tabulate(31)(ii => ii * 5 + 7), - r => MemoryArrayInitAnnotation(r, Seq.tabulate(31)(ii => ii * 5 + 7).map(BigInt(_))) -) {} +class MemScalarInit0ExecutionSpec + extends MemInitExecutionSpec( + Seq.tabulate(31)(_ => 0), + r => MemoryScalarInitAnnotation(r, 0) + ) {} + +class MemScalarInit17ExecutionSpec + extends MemInitExecutionSpec( + Seq.tabulate(31)(_ => 17), + r => MemoryScalarInitAnnotation(r, 17) + ) {} + +class MemArrayInitExecutionSpec + extends MemInitExecutionSpec( + Seq.tabulate(31)(ii => ii * 5 + 7), + r => MemoryArrayInitAnnotation(r, Seq.tabulate(31)(ii => ii * 5 + 7).map(BigInt(_))) + ) {} diff --git a/src/test/scala/firrtlTests/MultiThreadingSpec.scala b/src/test/scala/firrtlTests/MultiThreadingSpec.scala index c7b18624..6ec1a2bd 100644 --- a/src/test/scala/firrtlTests/MultiThreadingSpec.scala +++ b/src/test/scala/firrtlTests/MultiThreadingSpec.scala @@ -24,7 +24,8 @@ class MultiThreadingSpec extends FirrtlPropSpec { new firrtl.HighFirrtlCompiler, new firrtl.MiddleFirrtlCompiler, new firrtl.LowFirrtlCompiler, - new firrtl.VerilogCompiler) + new firrtl.VerilogCompiler + ) val inputFilePath = s"/integration/GCDTester.fir" // arbitrary val numThreads = 64 // arbitrary @@ -35,20 +36,20 @@ class MultiThreadingSpec extends FirrtlPropSpec { import ExecutionContext.Implicits.global try { // Use try-catch because error can manifest in many ways // Execute for each compiler - val compilerResults = compilers map { compiler => + val compilerResults = compilers.map { compiler => // Run compiler serially once val serialResult = runCompiler(inputStrings, compiler) Future { - val threadFutures = (0 until numThreads) map { i => - Future { - runCompiler(inputStrings, compiler) == serialResult - } + val threadFutures = (0 until numThreads).map { i => + Future { + runCompiler(inputStrings, compiler) == serialResult } + } Await.result(Future.sequence(threadFutures), Duration.Inf) } } val results = Await.result(Future.sequence(compilerResults), Duration.Inf) - assert(results.flatten reduce (_ && _)) // check all true (ie. success) + assert(results.flatten.reduce(_ && _)) // check all true (ie. success) } catch { case _: Throwable => fail("The Compiler is not thread safe") } diff --git a/src/test/scala/firrtlTests/NamespaceSpec.scala b/src/test/scala/firrtlTests/NamespaceSpec.scala index a9bb844d..bf7cb019 100644 --- a/src/test/scala/firrtlTests/NamespaceSpec.scala +++ b/src/test/scala/firrtlTests/NamespaceSpec.scala @@ -9,19 +9,19 @@ class NamespaceSpec extends FirrtlFlatSpec { "A Namespace" should "not allow collisions" in { val namespace = Namespace() - namespace.newName("foo") should be ("foo") - namespace.newName("foo") should be ("foo_0") + namespace.newName("foo") should be("foo") + namespace.newName("foo") should be("foo_0") } it should "start temps with a suffix of 0" in { - Namespace().newTemp.last should be ('0') + Namespace().newTemp.last should be('0') } it should "handle multiple prefixes with independent suffixes" in { val namespace = Namespace() - namespace.newName("foo") should be ("foo") - namespace.newName("foo") should be ("foo_0") - namespace.newName("bar") should be ("bar") - namespace.newName("bar") should be ("bar_0") + namespace.newName("foo") should be("foo") + namespace.newName("foo") should be("foo_0") + namespace.newName("bar") should be("bar") + namespace.newName("bar") should be("bar_0") } } diff --git a/src/test/scala/firrtlTests/ParserSpec.scala b/src/test/scala/firrtlTests/ParserSpec.scala index 3d377901..25e52e57 100644 --- a/src/test/scala/firrtlTests/ParserSpec.scala +++ b/src/test/scala/firrtlTests/ParserSpec.scala @@ -12,16 +12,17 @@ class ParserSpec extends FirrtlFlatSpec { private object MemTests { val prelude = Seq("circuit top :", " module top :", " mem m : ") - val fields = Map("data-type" -> "UInt<32>", - "depth" -> "4", - "read-latency" -> "1", - "write-latency" -> "1", - "reader" -> "a", - "writer" -> "b", - "readwriter" -> "c" - ) + val fields = Map( + "data-type" -> "UInt<32>", + "depth" -> "4", + "read-latency" -> "1", + "write-latency" -> "1", + "reader" -> "a", + "writer" -> "b", + "readwriter" -> "c" + ) def fieldsToSeq(m: Map[String, String]): Seq[String] = - m.map { case (k,v) => s" ${k} => ${v}" }.toSeq + m.map { case (k, v) => s" ${k} => ${v}" }.toSeq } private object RegTests { @@ -36,11 +37,51 @@ class ParserSpec extends FirrtlFlatSpec { private object KeywordTests { val prelude = Seq("circuit top :", " module top :") - val keywords = Seq("circuit", "module", "extmodule", "parameter", "input", "output", "UInt", - "SInt", "Analog", "Fixed", "flip", "Clock", "wire", "reg", "reset", "with", "mem", "depth", - "reader", "writer", "readwriter", "inst", "of", "node", "is", "invalid", "when", "else", - "stop", "printf", "skip", "old", "new", "undefined", "mux", "validif", "cmem", "smem", - "mport", "infer", "read", "write", "rdwr") ++ PrimOps.listing + val keywords = Seq( + "circuit", + "module", + "extmodule", + "parameter", + "input", + "output", + "UInt", + "SInt", + "Analog", + "Fixed", + "flip", + "Clock", + "wire", + "reg", + "reset", + "with", + "mem", + "depth", + "reader", + "writer", + "readwriter", + "inst", + "of", + "node", + "is", + "invalid", + "when", + "else", + "stop", + "printf", + "skip", + "old", + "new", + "undefined", + "mux", + "validif", + "cmem", + "smem", + "mport", + "infer", + "read", + "write", + "rdwr" + ) ++ PrimOps.listing } // ********** Memories ********** @@ -48,7 +89,7 @@ class ParserSpec extends FirrtlFlatSpec { val fields = MemTests.fieldsToSeq(MemTests.fields) val golden = firrtl.Parser.parse((MemTests.prelude ++ fields)) - fields.permutations foreach { permutation => + fields.permutations.foreach { permutation => val circuit = firrtl.Parser.parse((MemTests.prelude ++ permutation)) assert(golden === circuit) } @@ -56,13 +97,13 @@ class ParserSpec extends FirrtlFlatSpec { it should "have exactly one of each: data-type, depth, read-latency, and write-latency" in { import MemTests._ - def parseWithoutField(s: String) = firrtl.Parser.parse((prelude ++ fieldsToSeq(fields - s))) + def parseWithoutField(s: String) = firrtl.Parser.parse((prelude ++ fieldsToSeq(fields - s))) def parseWithDuplicate(k: String, v: String) = firrtl.Parser.parse((prelude ++ fieldsToSeq(fields) :+ s" ${k} => ${v}")) - Seq("data-type", "depth", "read-latency", "write-latency") foreach { field => - an [ParameterNotSpecifiedException] should be thrownBy { parseWithoutField(field) } - an [ParameterRedefinedException] should be thrownBy { parseWithDuplicate(field, fields(field)) } + Seq("data-type", "depth", "read-latency", "write-latency").foreach { field => + an[ParameterNotSpecifiedException] should be thrownBy { parseWithoutField(field) } + an[ParameterRedefinedException] should be thrownBy { parseWithDuplicate(field, fields(field)) } } } @@ -86,7 +127,7 @@ class ParserSpec extends FirrtlFlatSpec { import RegTests._ val res = firrtl.Parser.parse((prelude :+ s"${reg} with : (${reset}) $finfo" :+ " wire a : UInt")) CircuitState(res, Nil) should containTree { - case DefRegister(`fileInfo`, `regName`, _,_,_,_) => true + case DefRegister(`fileInfo`, `regName`, _, _, _, _) => true } } @@ -94,7 +135,7 @@ class ParserSpec extends FirrtlFlatSpec { import RegTests._ val res = firrtl.Parser.parse((prelude :+ s"${reg} with :\n (${reset}) $finfo")) CircuitState(res, Nil) should containTree { - case DefRegister(`fileInfo`, `regName`, _,_,_,_) => true + case DefRegister(`fileInfo`, `regName`, _, _, _, _) => true } } @@ -102,35 +143,34 @@ class ParserSpec extends FirrtlFlatSpec { import RegTests._ val res = firrtl.Parser.parse((prelude :+ s"${reg} $finfo")) CircuitState(res, Nil) should containTree { - case DefRegister(`fileInfo`, `regName`, _,_,_,_) => true + case DefRegister(`fileInfo`, `regName`, _, _, _, _) => true } } // ********** Keywords ********** "Keywords" should "be allowed as Ids" in { import KeywordTests._ - keywords foreach { keyword => + keywords.foreach { keyword => firrtl.Parser.parse((prelude :+ s" wire ${keyword} : UInt")) } } it should "be allowed on lhs in connects" in { import KeywordTests._ - keywords foreach { keyword => - firrtl.Parser.parse((prelude ++ Seq(s" wire ${keyword} : UInt", - s" ${keyword} <= ${keyword}"))) + keywords.foreach { keyword => + firrtl.Parser.parse((prelude ++ Seq(s" wire ${keyword} : UInt", s" ${keyword} <= ${keyword}"))) } } // ********** Digits as Fields ********** "Digits" should "be legal fields in bundles and in subexpressions" in { val input = """ - |circuit Test : - | module Test : - | input in : { 0 : { 0 : UInt<32>, flip 1 : UInt<32> } } - | input in2 : { 4 : { 23 : { foo : UInt<32>, bar : { flip 123 : UInt<32> } } } } - | in.0.1 <= in.0.0 - | in2.4.23.bar.123 <= in2.4.23.foo + |circuit Test : + | module Test : + | input in : { 0 : { 0 : UInt<32>, flip 1 : UInt<32> } } + | input in2 : { 4 : { 23 : { foo : UInt<32>, bar : { flip 123 : UInt<32> } } } } + | in.0.1 <= in.0.0 + | in2.4.23.bar.123 <= in2.4.23.foo """.stripMargin val c = firrtl.Parser.parse(input) firrtl.Parser.parse(c.serialize) @@ -148,7 +188,7 @@ class ParserSpec extends FirrtlFlatSpec { } def check(inFormat: String, ref: Integer): Unit = { - (circuit(inFormat)) should be (circuit(ref.toString)) + (circuit(inFormat)) should be(circuit(ref.toString)) } val checks = Map( @@ -166,25 +206,25 @@ class ParserSpec extends FirrtlFlatSpec { ) checks.foreach { case (k, v) => check(k, v) } - } + } // ********** Doubles as parameters ********** "Doubles" should "be legal parameters for extmodules" in { val nums = Seq("1.0", "7.6", "3.00004", "1.0E10", "1.0023E-17") val signs = Seq("", "+", "-") - val tests = "0.0" +: (signs flatMap (s => nums map (n => s + n))) + val tests = "0.0" +: (signs.flatMap(s => nums.map(n => s + n))) for (test <- tests) { val input = s""" - |circuit Test : - | extmodule Ext : - | input foo : UInt<32> - | - | defname = MyExtModule - | parameter REAL = $test - | - | module Test : - | input foo : UInt<32> - | output bar : UInt<32> + |circuit Test : + | extmodule Ext : + | input foo : UInt<32> + | + | defname = MyExtModule + | parameter REAL = $test + | + | module Test : + | input foo : UInt<32> + | output bar : UInt<32> """.stripMargin val c = firrtl.Parser.parse(input) firrtl.Parser.parse(c.serialize) @@ -193,16 +233,16 @@ class ParserSpec extends FirrtlFlatSpec { "Strings" should "be legal parameters for extmodules" in { val input = s""" - |circuit Test : - | extmodule Ext : - | input foo : UInt<32> - | - | defname = MyExtModule - | parameter STR = "hello=%d" - | - | module Test : - | input foo : UInt<32> - | output bar : UInt<32> + |circuit Test : + | extmodule Ext : + | input foo : UInt<32> + | + | defname = MyExtModule + | parameter STR = "hello=%d" + | + | module Test : + | input foo : UInt<32> + | output bar : UInt<32> """.stripMargin val c = firrtl.Parser.parse(input) firrtl.Parser.parse(c.serialize) @@ -210,37 +250,37 @@ class ParserSpec extends FirrtlFlatSpec { "Parsing errors" should "be reported as normal exceptions" in { val input = s""" - |circuit Test - | module Test : + |circuit Test + | module Test : - |""".stripMargin + |""".stripMargin val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input)) } - a [SyntaxErrorsException] shouldBe thrownBy { + a[SyntaxErrorsException] shouldBe thrownBy { Driver.execute(manager) } } "Trailing syntax errors" should "be caught in the parser" in { val input = s""" - |circuit Foo: - | module Bar: - | input a: UInt<1> - |output b: UInt<1> - | b <- a - | - | module Foo: - | input a: UInt<1> - | output b: UInt<1> - | inst bar of Bar - | bar.a <- a - | b <- bar.b + |circuit Foo: + | module Bar: + | input a: UInt<1> + |output b: UInt<1> + | b <- a + | + | module Foo: + | input a: UInt<1> + | output b: UInt<1> + | inst bar of Bar + | bar.a <- a + | b <- bar.b """.stripMargin val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input)) } - a [SyntaxErrorsException] shouldBe thrownBy { + a[SyntaxErrorsException] shouldBe thrownBy { Driver.execute(manager) } } @@ -250,9 +290,9 @@ class ParserSpec extends FirrtlFlatSpec { val info = ir.MultiInfo(Seq(ir.MultiInfo(Seq(ir.FileInfo("a"))), ir.FileInfo("b"), ir.FileInfo("c"))) val input = s"""circuit m:${info.serialize} - | module m: - | skip - |""".stripMargin + | module m: + | skip + |""".stripMargin val c = firrtl.Parser.parse(input) assert(c.info == ir.FileInfo("a b c")) } @@ -272,14 +312,14 @@ class ParserPropSpec extends FirrtlPropSpec { } yield (x :: xs).mkString property("Identifiers should allow [A-Za-z0-9_$] but not allow starting with a digit or $") { - forAll (identifier) { id => + forAll(identifier) { id => whenever(id.nonEmpty) { val input = s""" - |circuit Test : - | module Test : - | input $id : UInt<32> - |""".stripMargin - firrtl.Parser.parse(input split "\n") + |circuit Test : + | module Test : + | input $id : UInt<32> + |""".stripMargin + firrtl.Parser.parse(input.split("\n")) } } } @@ -289,15 +329,16 @@ class ParserPropSpec extends FirrtlPropSpec { } yield xs.mkString property("Bundle fields should allow [A-Za-z0-9_] including starting with a digit or $") { - forAll (identifier, bundleField) { case (id, field) => - whenever(id.nonEmpty && field.nonEmpty) { - val input = s""" - |circuit Test : - | module Test : - | input $id : { $field : UInt<32> } - |""".stripMargin - firrtl.Parser.parse(input split "\n") - } + forAll(identifier, bundleField) { + case (id, field) => + whenever(id.nonEmpty && field.nonEmpty) { + val input = s""" + |circuit Test : + | module Test : + | input $id : { $field : UInt<32> } + |""".stripMargin + firrtl.Parser.parse(input.split("\n")) + } } } } diff --git a/src/test/scala/firrtlTests/PresetSpec.scala b/src/test/scala/firrtlTests/PresetSpec.scala index 689a910d..9fa64647 100644 --- a/src/test/scala/firrtlTests/PresetSpec.scala +++ b/src/test/scala/firrtlTests/PresetSpec.scala @@ -13,156 +13,178 @@ class PresetSpec extends FirrtlFlatSpec { def compile(input: String, annos: AnnotationSeq): CircuitState = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos), List.empty) def compileBody(modules: ModuleSeq) = { - val annos = Seq(new PresetAnnotation(CircuitTarget("Test").module("Test").ref("reset")), firrtl.transforms.NoDCEAnnotation) + val annos = + Seq(new PresetAnnotation(CircuitTarget("Test").module("Test").ref("reset")), firrtl.transforms.NoDCEAnnotation) var str = """ - |circuit Test : - |""".stripMargin - modules foreach ((m: Mod) => { + |circuit Test : + |""".stripMargin + modules.foreach((m: Mod) => { val header = "|module " + m(0) + " :" - str += header.stripMargin.stripMargin.split("\n").mkString(" ", "\n ", "") + str += header.stripMargin.stripMargin.split("\n").mkString(" ", "\n ", "") str += m(1).split("\n").mkString(" ", "\n ", "") str += """ - |""".stripMargin + |""".stripMargin }) - compile(str,annos) + compile(str, annos) } "Preset" should """behave properly given a `Preset` annotated `AsyncReset` INPUT reset: - replace AsyncReset specific blocks by standard Register blocks - add inline declaration of all registers connected to reset - remove the useless input port""" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result shouldNot containLine ("input reset,") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result shouldNot containLine("input reset,") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } - + it should """behave properly given a `Preset` annotated `AsyncReset` WIRE reset: - replace AsyncReset specific blocks by standard Register blocks - add inline declaration of all registers connected to reset - remove the useless wire declaration and assignation""" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input x : UInt<1> - |output z : UInt<1> - |wire reset : AsyncReset - |reset <= asAsyncReset(UInt(0)) - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") - // it should also remove useless asyncReset signal, all along the path down to registers - result shouldNot containLine ("wire reset;") - result shouldNot containLine ("assign reset = 1'h0;") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") + // it should also remove useless asyncReset signal, all along the path down to registers + result shouldNot containLine("wire reset;") + result shouldNot containLine("assign reset = 1'h0;") } it should "raise TreeCleanUpOrphantException on cast of annotated AsyncReset" in { - an [firrtl.transforms.PropagatePresetAnnotations.TreeCleanUpOrphanException] shouldBe thrownBy { - compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input x : UInt<1> - |output z : UInt<1> - |output sz : UInt<1> - |wire reset : AsyncReset - |reset <= asAsyncReset(UInt(0)) - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |wire sreset : UInt<1> - |sreset <= asUInt(reset) ; this is FORBIDDEN - |reg s : UInt<1>, clock with : (reset => (sreset, UInt(0))) - |r <= x - |s <= x - |z <= r - |sz <= s""".stripMargin)) + an[firrtl.transforms.PropagatePresetAnnotations.TreeCleanUpOrphanException] shouldBe thrownBy { + compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |output sz : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |wire sreset : UInt<1> + |sreset <= asUInt(reset) ; this is FORBIDDEN + |reg s : UInt<1>, clock with : (reset => (sreset, UInt(0))) + |r <= x + |s <= x + |z <= r + |sz <= s""".stripMargin + ) + ) ) } } - + it should "propagate through bundles" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} - |bundle.in_rst <= reset - |bundle.out_rst <= bundle.in_rst - |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} + |bundle.in_rst <= reset + |bundle.out_rst <= bundle.in_rst + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } it should "propagate through vectors" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire vector : AsyncReset[2] - |vector[0] <= reset - |vector[1] <= vector[0] - |reg r : UInt<1>, clock with : (reset => (vector[1], UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire vector : AsyncReset[2] + |vector[0] <= reset + |vector[1] <= vector[0] + |reg r : UInt<1>, clock with : (reset => (vector[1], UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } - + it should "propagate through bundles of vectors" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire bundle : {in_rst: AsyncReset[2], out_rst:AsyncReset} - |bundle.in_rst[0] <= reset - |bundle.in_rst[1] <= bundle.in_rst[0] - |bundle.out_rst <= bundle.in_rst[1] - |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset[2], out_rst:AsyncReset} + |bundle.in_rst[0] <= reset + |bundle.in_rst[1] <= bundle.in_rst[0] + |bundle.out_rst <= bundle.in_rst[1] + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } it should """propagate properly accross modules: - replace AsyncReset specific blocks by standard Register blocks @@ -171,70 +193,79 @@ class PresetSpec extends FirrtlFlatSpec { - remove the useless instance connections - remove wires and assignations used in instance connections """ in { - val result = compileBody(Seq( - Seq("TestA",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) - |r <= x - |z <= r - |""".stripMargin), - Seq("Test",s""" - |input clock : Clock - |input x : UInt<1> - |output z : UInt<1> - |wire reset : AsyncReset - |reset <= asAsyncReset(UInt(0)) - |inst i of TestA - |i.clock <= clock - |i.reset <= reset - |i.x <= x - |z <= i.z""".stripMargin) - )) + val result = compileBody( + Seq( + Seq( + "TestA", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |reg r : UInt<1>, clock with : (reset => (reset, UInt(0))) + |r <= x + |z <= r + |""".stripMargin + ), + Seq( + "Test", + s""" + |input clock : Clock + |input x : UInt<1> + |output z : UInt<1> + |wire reset : AsyncReset + |reset <= asAsyncReset(UInt(0)) + |inst i of TestA + |i.clock <= clock + |i.reset <= reset + |i.x <= x + |z <= i.z""".stripMargin + ) + ) + ) // assess that all useless connections are not emitted - result shouldNot containLine ("wire i_reset;") - result shouldNot containLine (".reset(i_reset),") - result shouldNot containLine ("assign i_reset = reset;") - result shouldNot containLine ("input reset,") - - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("wire i_reset;") + result shouldNot containLine(".reset(i_reset),") + result shouldNot containLine("assign i_reset = reset;") + result shouldNot containLine("input reset,") + + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } it should "propagate even through disordonned statements" in { - val result = compileBody(Seq(Seq("Test",s""" - |input clock : Clock - |input reset : AsyncReset - |input x : UInt<1> - |output z : UInt<1> - |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} - |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) - |bundle.out_rst <= bundle.in_rst - |bundle.in_rst <= reset - |r <= x - |z <= r""".stripMargin)) + val result = compileBody( + Seq( + Seq( + "Test", + s""" + |input clock : Clock + |input reset : AsyncReset + |input x : UInt<1> + |output z : UInt<1> + |wire bundle : {in_rst: AsyncReset, out_rst:AsyncReset} + |reg r : UInt<1>, clock with : (reset => (bundle.out_rst, UInt(0))) + |bundle.out_rst <= bundle.in_rst + |bundle.in_rst <= reset + |r <= x + |z <= r""".stripMargin + ) + ) ) - result shouldNot containLine ("input reset,") - result shouldNot containLine ("always @(posedge clock or posedge reset) begin") - result shouldNot containLines ( - "if (reset) begin", - "r = 1'h0;", - "end") - result should containLine ("always @(posedge clock) begin") - result should containLine ("reg r = 1'h0;") + result shouldNot containLine("input reset,") + result shouldNot containLine("always @(posedge clock or posedge reset) begin") + result shouldNot containLines("if (reset) begin", "r = 1'h0;", "end") + result should containLine("always @(posedge clock) begin") + result should containLine("reg r = 1'h0;") } } -class PresetExecutionTest extends ExecutionTest( - "PresetTester", - "/features", - annotations = Seq(new PresetAnnotation(CircuitTarget("PresetTester").module("PresetTester").ref("preset"))) -) +class PresetExecutionTest + extends ExecutionTest( + "PresetTester", + "/features", + annotations = Seq(new PresetAnnotation(CircuitTarget("PresetTester").module("PresetTester").ref("preset"))) + ) diff --git a/src/test/scala/firrtlTests/ProtoBufSpec.scala b/src/test/scala/firrtlTests/ProtoBufSpec.scala index 3a94ec3f..7cfdc4dc 100644 --- a/src/test/scala/firrtlTests/ProtoBufSpec.scala +++ b/src/test/scala/firrtlTests/ProtoBufSpec.scala @@ -44,50 +44,50 @@ class ProtoBufSpec extends FirrtlFlatSpec { val cistream = com.google.protobuf.CodedInputStream.newInstance(istream) cistream.setRecursionLimit(Integer.MAX_VALUE) val protobuf2 = firrtl.FirrtlProtos.Firrtl.parseFrom(cistream) - protobuf2 should equal (protobuf) + protobuf2 should equal(protobuf) // Test that our faster serialization matches generated serialization val ostream2 = new java.io.ByteArrayOutputStream proto.ToProto.writeToStream(ostream2, circuit) - ostream2.toByteArray.toList should equal (ostream.toByteArray.toList) + ostream2.toByteArray.toList should equal(ostream.toByteArray.toList) } } // ********** Focused Tests ********** // The goal is to fill coverage holes left after the above - behavior of "ProtoBuf serialization and deserialization" + behavior.of("ProtoBuf serialization and deserialization") import firrtl.proto._ it should "support UnknownWidth" in { // Note that this has to be handled in the parent object so we need to test everything that has a width val uint = ir.UIntType(ir.UnknownWidth) - FromProto.convert(ToProto.convert(uint).build) should equal (uint) + FromProto.convert(ToProto.convert(uint).build) should equal(uint) val sint = ir.SIntType(ir.UnknownWidth) - FromProto.convert(ToProto.convert(sint).build) should equal (sint) + FromProto.convert(ToProto.convert(sint).build) should equal(sint) val ftpe = ir.FixedType(ir.UnknownWidth, ir.UnknownWidth) - FromProto.convert(ToProto.convert(ftpe).build) should equal (ftpe) + FromProto.convert(ToProto.convert(ftpe).build) should equal(ftpe) val atpe = ir.AnalogType(ir.UnknownWidth) - FromProto.convert(ToProto.convert(atpe).build) should equal (atpe) + FromProto.convert(ToProto.convert(atpe).build) should equal(atpe) val ulit = ir.UIntLiteral(123, ir.UnknownWidth) - FromProto.convert(ToProto.convert(ulit).build) should equal (ulit) + FromProto.convert(ToProto.convert(ulit).build) should equal(ulit) val slit = ir.SIntLiteral(-123, ir.UnknownWidth) - FromProto.convert(ToProto.convert(slit).build) should equal (slit) + FromProto.convert(ToProto.convert(slit).build) should equal(slit) val flit = ir.FixedLiteral(-123, ir.UnknownWidth, ir.UnknownWidth) - FromProto.convert(ToProto.convert(flit).build) should equal (flit) + FromProto.convert(ToProto.convert(flit).build) should equal(flit) } it should "support all Primops" in { val builtInOps = PrimOps.listing.map(PrimOps.fromString(_)) for (op <- builtInOps) { val expr = DoPrim(op, List.empty, List.empty, ir.UnknownType) - FromProto.convert(ToProto.convert(expr).build) should equal (expr) + FromProto.convert(ToProto.convert(expr).build) should equal(expr) } } @@ -103,25 +103,25 @@ class ProtoBufSpec extends FirrtlFlatSpec { RawStringParam("param4", "get some raw strings") ) val ext = ir.ExtModule(ir.NoInfo, "MyModule", ports, "DefNameHere", params) - FromProto.convert(ToProto.convert(ext).build) should equal (ext) + FromProto.convert(ToProto.convert(ext).build) should equal(ext) } it should "support FixedType" in { val ftpe = ir.FixedType(IntWidth(8), IntWidth(4)) - FromProto.convert(ToProto.convert(ftpe).build) should equal (ftpe) + FromProto.convert(ToProto.convert(ftpe).build) should equal(ftpe) } it should "support FixedLiteral" in { val flit = ir.FixedLiteral(3, IntWidth(8), IntWidth(4)) - FromProto.convert(ToProto.convert(flit).build) should equal (flit) + FromProto.convert(ToProto.convert(flit).build) should equal(flit) } it should "support Analog and Attach" in { val analog = ir.AnalogType(IntWidth(8)) - FromProto.convert(ToProto.convert(analog).build) should equal (analog) + FromProto.convert(ToProto.convert(analog).build) should equal(analog) val attach = ir.Attach(ir.NoInfo, Seq(Reference("hi", ir.UnknownType))) - FromProto.convert(ToProto.convert(attach).head.build) should equal (attach) + FromProto.convert(ToProto.convert(attach).head.build) should equal(attach) } // Regression tests were generated before Chisel could emit else @@ -129,12 +129,12 @@ class ProtoBufSpec extends FirrtlFlatSpec { val expr = Reference("hi", ir.UnknownType) val stmt = Connect(ir.NoInfo, expr, expr) val when = ir.Conditionally(ir.NoInfo, expr, stmt, stmt) - FromProto.convert(ToProto.convert(when).head.build) should equal (when) + FromProto.convert(ToProto.convert(when).head.build) should equal(when) } it should "support SIntLiteral with a width" in { val slit = ir.SIntLiteral(-123) - FromProto.convert(ToProto.convert(slit).build) should equal (slit) + FromProto.convert(ToProto.convert(slit).build) should equal(slit) } // Backwards compatibility @@ -143,18 +143,21 @@ class ProtoBufSpec extends FirrtlFlatSpec { val mem = DefMemory(NoInfo, "m", UIntType(IntWidth(8)), size, 1, 1, List("r"), List("w"), List("rw")) val builder = ToProto.convert(mem).head val defaultProto = builder.build() - val oldProto = Firrtl.Statement.newBuilder().setMemory( - builder.getMemoryBuilder.clearDepth().setUintDepth(size) - ).build() + val oldProto = Firrtl.Statement + .newBuilder() + .setMemory( + builder.getMemoryBuilder.clearDepth().setUintDepth(size) + ) + .build() // These Proto messages are not the same - defaultProto shouldNot equal (oldProto) + defaultProto shouldNot equal(oldProto) val defaultMem = FromProto.convert(defaultProto) val oldMem = FromProto.convert(oldProto) // But they both deserialize to the original! - defaultMem should equal (mem) - oldMem should equal (mem) + defaultMem should equal(mem) + oldMem should equal(mem) } // Backwards compatibility @@ -164,43 +167,46 @@ class ProtoBufSpec extends FirrtlFlatSpec { val vtpe = ToProto.convert(VectorType(UIntType(IntWidth(8)), size)) val builder = ToProto.convert(cmem).head val defaultProto = builder.build() - val oldProto = Firrtl.Statement.newBuilder().setCmemory( - builder.getCmemoryBuilder.clearTypeAndDepth().setVectorType(vtpe) - ).build() + val oldProto = Firrtl.Statement + .newBuilder() + .setCmemory( + builder.getCmemoryBuilder.clearTypeAndDepth().setVectorType(vtpe) + ) + .build() // These Proto messages are not the same - defaultProto shouldNot equal (oldProto) + defaultProto shouldNot equal(oldProto) val defaultCMem = FromProto.convert(defaultProto) val oldCMem = FromProto.convert(oldProto) // But they both deserialize to the original! - defaultCMem should equal (cmem) - oldCMem should equal (cmem) + defaultCMem should equal(cmem) + oldCMem should equal(cmem) } // readunderwrite support it should "support readunderwrite parameters" in { val m1 = DefMemory(NoInfo, "m", UIntType(IntWidth(8)), 128, 1, 1, List("r"), List("w"), Nil, ir.ReadUnderWrite.Old) - FromProto.convert(ToProto.convert(m1).head.build) should equal (m1) + FromProto.convert(ToProto.convert(m1).head.build) should equal(m1) val m2 = m1.copy(readUnderWrite = ir.ReadUnderWrite.New) - FromProto.convert(ToProto.convert(m2).head.build) should equal (m2) + FromProto.convert(ToProto.convert(m2).head.build) should equal(m2) val cm1 = CDefMemory(NoInfo, "m", UIntType(IntWidth(8)), 128, true, ir.ReadUnderWrite.Old) - FromProto.convert(ToProto.convert(cm1).head.build) should equal (cm1) + FromProto.convert(ToProto.convert(cm1).head.build) should equal(cm1) val cm2 = cm1.copy(readUnderWrite = ir.ReadUnderWrite.New) - FromProto.convert(ToProto.convert(cm2).head.build) should equal (cm2) + FromProto.convert(ToProto.convert(cm2).head.build) should equal(cm2) } it should "support AsyncResetTypes" in { val port = ir.Port(ir.NoInfo, "reset", ir.Input, ir.AsyncResetType) - FromProto.convert(ToProto.convert(port).build) should equal (port) + FromProto.convert(ToProto.convert(port).build) should equal(port) } it should "support ResetTypes" in { val port = ir.Port(ir.NoInfo, "reset", ir.Input, ir.ResetType) - FromProto.convert(ToProto.convert(port).build) should equal (port) + FromProto.convert(ToProto.convert(port).build) should equal(port) } it should "support ValidIf" in { @@ -209,7 +215,7 @@ class ProtoBufSpec extends FirrtlFlatSpec { val vi = ir.ValidIf(en, value, value.tpe) // Deserialized has almost nothing filled in val expected = ir.ValidIf(ir.Reference("en"), ir.Reference("x"), UnknownType) - FromProto.convert(ToProto.convert(vi).build) should equal (expected) + FromProto.convert(ToProto.convert(vi).build) should equal(expected) } it should "appropriately escape and unescape FileInfo strings" in { @@ -220,10 +226,11 @@ class ProtoBufSpec extends FirrtlFlatSpec { "test\\]test" -> "test]test" ) - pairs.foreach { case (escaped, unescaped) => - val info = ir.FileInfo(escaped) - ToProto.convert(info).build().getText should equal (unescaped) - FromProto.convert(ToProto.convert(info).build) should equal (info) + pairs.foreach { + case (escaped, unescaped) => + val info = ir.FileInfo(escaped) + ToProto.convert(info).build().getText should equal(unescaped) + FromProto.convert(ToProto.convert(info).build) should equal(info) } } } diff --git a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala index dfef5955..d335becc 100644 --- a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala +++ b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala @@ -22,7 +22,8 @@ object RegisterUpdateSpec { override def invalidates(a: Transform): Boolean = false def execute(state: CircuitState): CircuitState = { val emittedAnno = EmittedFirrtlCircuitAnnotation( - EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir")) + EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir") + ) val capturedState = state.copy(annotations = emittedAnno +: state.annotations) state.copy(annotations = CaptureStateAnno(capturedState) +: state.annotations) } @@ -37,64 +38,61 @@ class RegisterUpdateSpec extends FirrtlFlatSpec { } def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } "Register update logic" should "not duplicate common subtrees" in { val result = compileBody(s""" - |input clock : Clock - |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} - |reg r : UInt<8>, clock - |when io.a : - | r <= io.in - |when io.b : - | when io.c : - | r <= UInt(2) - |io.out <= r""".stripMargin - ) + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + |reg r : UInt<8>, clock + |when io.a : + | r <= io.in + |when io.b : + | when io.c : + | r <= UInt(2) + |io.out <= r""".stripMargin) // Checking intermediate state between FlattenRegUpdate and Verilog emission val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get - fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + fstate should containLine("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") // Checking the Verilog val verilog = result.getEmittedCircuit.value - result shouldNot containLine ("r <= io_in;") - verilog shouldNot include ("if (io_a) begin") - result should containLine ("r <= _GEN_0;") + result shouldNot containLine("r <= io_in;") + verilog shouldNot include("if (io_a) begin") + result should containLine("r <= _GEN_0;") } it should "not let duplicate subtrees on one register affect another" in { val result = compileBody(s""" - |input clock : Clock - |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} - |reg r : UInt<8>, clock - |reg r2 : UInt<8>, clock - |when io.a : - | r <= io.in - | r2 <= io.in - |when io.b : - | r2 <= UInt(3) - | when io.c : - | r <= UInt(2) - |io.out <= and(r, r2)""".stripMargin - ) + |reg r : UInt<8>, clock + |reg r2 : UInt<8>, clock + |when io.a : + | r <= io.in + | r2 <= io.in + |when io.b : + | r2 <= UInt(3) + | when io.c : + | r <= UInt(2) + |io.out <= and(r, r2)""".stripMargin) // Checking intermediate state between FlattenRegUpdate and Verilog emission val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get - fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") - fstate should containLine ("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""") + fstate should containLine("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + fstate should containLine("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""") // Checking the Verilog val verilog = result.getEmittedCircuit.value - result shouldNot containLine ("r <= io_in;") - result should containLine ("r <= _GEN_0;") - result should containLine ("r2 <= io_in;") - verilog should include ("if (io_a) begin") // For r2 + result shouldNot containLine("r <= io_in;") + result should containLine("r <= _GEN_0;") + result should containLine("r2 <= io_in;") + verilog should include("if (io_a) begin") // For r2 // 1 time for r2, old versions would have 3 occurences - Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be (1) + Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be(1) } } - diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala index df3ceef6..3438da67 100644 --- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -14,9 +14,9 @@ class RemoveWiresSpec extends FirrtlFlatSpec { (new LowFirrtlCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } @@ -26,7 +26,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { val nodes = mutable.ArrayBuffer.empty[DefNode] val wires = mutable.ArrayBuffer.empty[DefWire] def onStmt(stmt: Statement): Statement = { - stmt map onStmt match { + stmt.map(onStmt) match { case node: DefNode => nodes += node case wire: DefWire => wires += wire case _ => @@ -35,7 +35,7 @@ class RemoveWiresSpec extends FirrtlFlatSpec { } circuit.modules.head match { - case Module(_,_,_, body) => onStmt(body) + case Module(_, _, _, body) => onStmt(body) } (nodes.toSeq, wires.toSeq) } @@ -44,98 +44,90 @@ class RemoveWiresSpec extends FirrtlFlatSpec { require(circuit.modules.size == 1) val names = mutable.ArrayBuffer.empty[String] def onStmt(stmt: Statement): Statement = { - stmt map onStmt match { - case reg: DefRegister => names += reg.name - case wire: DefWire => names += wire.name - case node: DefNode => names += node.name + stmt.map(onStmt) match { + case reg: DefRegister => names += reg.name + case wire: DefWire => names += wire.name + case node: DefNode => names += node.name case _ => } stmt } circuit.modules.head match { - case Module(_,_,_, body) => onStmt(body) + case Module(_, _, _, body) => onStmt(body) } names.toSeq } "Remove Wires" should "turn wires and their single connect into nodes" in { val result = compileBody(s""" - |input a : UInt<8> - |output b : UInt<8> - |wire w : UInt<8> - |w <= a - |b <= w""".stripMargin - ) + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |w <= a + |b <= w""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) + wires.size should be(0) - nodes.map(_.serialize) should be (Seq("node w = a")) + nodes.map(_.serialize) should be(Seq("node w = a")) } it should "order nodes in a legal, flow-forward way" in { val result = compileBody(s""" - |input a : UInt<8> - |output b : UInt<8> - |wire w : UInt<8> - |wire x : UInt<8> - |node y = x - |x <= w - |w <= a - |b <= y""".stripMargin - ) + |input a : UInt<8> + |output b : UInt<8> + |wire w : UInt<8> + |wire x : UInt<8> + |node y = x + |x <= w + |w <= a + |b <= y""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( - Seq("node w = a", - "node x = w", - "node y = x") + wires.size should be(0) + nodes.map(_.serialize) should be( + Seq("node w = a", "node x = w", "node y = x") ) } it should "properly pad rhs of introduced nodes if necessary" in { val result = compileBody(s""" - |output b : UInt<8> - |wire w : UInt<8> - |w <= UInt(2) - |b <= w""".stripMargin - ) + |output b : UInt<8> + |wire w : UInt<8> + |w <= UInt(2) + |b <= w""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( + wires.size should be(0) + nodes.map(_.serialize) should be( Seq("""node w = pad(UInt<2>("h2"), 8)""") ) } it should "support arbitrary expression for wire connection rhs" in { val result = compileBody(s""" - |input a : UInt<8> - |input b : UInt<8> - |output c : UInt<8> - |wire w : UInt<8> - |w <= tail(add(a, b), 1) - |c <= w""".stripMargin - ) + |input a : UInt<8> + |input b : UInt<8> + |output c : UInt<8> + |wire w : UInt<8> + |w <= tail(add(a, b), 1) + |c <= w""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( + wires.size should be(0) + nodes.map(_.serialize) should be( Seq("""node w = tail(add(a, b), 1)""") ) } it should "do a reasonable job preserving input order for unrelatd logic" in { val result = compileBody(s""" - |input a : UInt<8> - |input b : UInt<8> - |output z : UInt<8> - |node x = not(a) - |node y = not(b) - |z <= and(x, y)""".stripMargin - ) + |input a : UInt<8> + |input b : UInt<8> + |output z : UInt<8> + |node x = not(a) + |node y = not(b) + |z <= and(x, y)""".stripMargin) val (nodes, wires) = getNodesAndWires(result.circuit) - wires.size should be (0) - nodes.map(_.serialize) should be ( - Seq("node x = not(a)", - "node y = not(b)") + wires.size should be(0) + nodes.map(_.serialize) should be( + Seq("node x = not(a)", "node y = not(b)") ) } @@ -148,52 +140,49 @@ class RemoveWiresSpec extends FirrtlFlatSpec { |""".stripMargin ) val names = orderedNames(result.circuit) - names should be (Seq("a", "clock2", "b")) + names should be(Seq("a", "clock2", "b")) } it should "order registers correctly" in { val result = compileBody(s""" - |input clock : Clock - |input a : UInt<8> - |output c : UInt<8> - |wire w : UInt<8> - |node n = tail(add(w, UInt(1)), 1) - |reg r : UInt<8>, clock - |w <= tail(add(r, a), 1) - |c <= n""".stripMargin - ) + |input clock : Clock + |input a : UInt<8> + |output c : UInt<8> + |wire w : UInt<8> + |node n = tail(add(w, UInt(1)), 1) + |reg r : UInt<8>, clock + |w <= tail(add(r, a), 1) + |c <= n""".stripMargin) // Check declaration before use is maintained firrtl.passes.CheckHighForm.execute(result) } it should "order registers with async reset correctly" in { val result = compileBody(s""" - |input clock : Clock - |input reset : UInt<1> - |input in : UInt<8> - |output out : UInt<8> - |wire areset : AsyncReset - |reg r : UInt<8>, clock with : (reset => (areset, UInt(0))) - |areset <= asAsyncReset(reset) - |r <= in - |out <= r - |""".stripMargin - ) + |input clock : Clock + |input reset : UInt<1> + |input in : UInt<8> + |output out : UInt<8> + |wire areset : AsyncReset + |reg r : UInt<8>, clock with : (reset => (areset, UInt(0))) + |areset <= asAsyncReset(reset) + |r <= in + |out <= r + |""".stripMargin) // Check declaration before use is maintained firrtl.passes.CheckHighForm.execute(result) } it should "order registers respecting initializations" in { - val result = compileBody( - s"""|input clock : Clock - |input foo : UInt<2> - |output bar : UInt<2> - |wire y_fault : UInt<2> - |reg y : UInt<2>, clock with : - | reset => (UInt<1>("h0"), y_fault) - |y_fault <= foo - |bar <= y - |""".stripMargin) + val result = compileBody(s"""|input clock : Clock + |input foo : UInt<2> + |output bar : UInt<2> + |wire y_fault : UInt<2> + |reg y : UInt<2>, clock with : + | reset => (UInt<1>("h0"), y_fault) + |y_fault <= foo + |bar <= y + |""".stripMargin) // Check declaration before use is maintained firrtl.passes.CheckHighForm.execute(result) } diff --git a/src/test/scala/firrtlTests/RenameMapSpec.scala b/src/test/scala/firrtlTests/RenameMapSpec.scala index 7931b94f..609d8eef 100644 --- a/src/test/scala/firrtlTests/RenameMapSpec.scala +++ b/src/test/scala/firrtlTests/RenameMapSpec.scala @@ -8,10 +8,10 @@ import firrtl.annotations._ import firrtl.testutils._ class RenameMapSpec extends FirrtlFlatSpec { - val cir = CircuitTarget("Top") - val cir2 = CircuitTarget("Pot") - val cir3 = CircuitTarget("Cir3") - val modA = cir.module("A") + val cir = CircuitTarget("Top") + val cir2 = CircuitTarget("Pot") + val cir3 = CircuitTarget("Cir3") + val modA = cir.module("A") val modA2 = cir2.module("A") val modB = cir.module("B") val foo = modA.ref("foo") @@ -26,69 +26,69 @@ class RenameMapSpec extends FirrtlFlatSpec { val middle = cir.module("Middle") val middle2 = cir.module("Middle2") - behavior of "RenameMap" + behavior.of("RenameMap") it should "return None if it does not rename something" in { val renames = RenameMap() - renames.get(modA) should be (None) - renames.get(foo) should be (None) + renames.get(modA) should be(None) + renames.get(foo) should be(None) } it should "return a Seq of renamed things if it does rename something" in { val renames = RenameMap() renames.record(foo, bar) - renames.get(foo) should be (Some(Seq(bar))) + renames.get(foo) should be(Some(Seq(bar))) } it should "allow something to be renamed to multiple things" in { val renames = RenameMap() renames.record(foo, bar) renames.record(foo, fizz) - renames.get(foo) should be (Some(Seq(bar, fizz))) + renames.get(foo) should be(Some(Seq(bar, fizz))) } it should "allow something to be renamed to nothing (ie. deleted)" in { val renames = RenameMap() renames.record(foo, Seq()) - renames.get(foo) should be (Some(Seq())) + renames.get(foo) should be(Some(Seq())) } it should "return None if something is renamed to itself" in { val renames = RenameMap() renames.record(foo, foo) - renames.get(foo) should be (None) + renames.get(foo) should be(None) } it should "allow targets to change module" in { val renames = RenameMap() renames.record(foo, fooB) - renames.get(foo) should be (Some(Seq(fooB))) + renames.get(foo) should be(Some(Seq(fooB))) } it should "rename targets if their module is renamed" in { val renames = RenameMap() renames.record(modA, modB) - renames.get(foo) should be (Some(Seq(fooB))) - renames.get(bar) should be (Some(Seq(barB))) + renames.get(foo) should be(Some(Seq(fooB))) + renames.get(bar) should be(Some(Seq(barB))) } it should "not rename already renamed targets if the module of the target is renamed" in { val renames = RenameMap() renames.record(modA, modB) renames.record(foo, bar) - renames.get(foo) should be (Some(Seq(bar))) + renames.get(foo) should be(Some(Seq(bar))) } it should "rename modules if their circuit is renamed" in { val renames = RenameMap() renames.record(cir, cir2) - renames.get(modA) should be (Some(Seq(modA2))) + renames.get(modA) should be(Some(Seq(modA2))) } it should "rename targets if their circuit is renamed" in { val renames = RenameMap() renames.record(cir, cir2) - renames.get(foo) should be (Some(Seq(foo2))) + renames.get(foo) should be(Some(Seq(foo2))) } val TopCircuit = cir @@ -105,44 +105,44 @@ class RenameMapSpec extends FirrtlFlatSpec { it should "rename targets if modules in the path are renamed" in { val renames = RenameMap() renames.record(Middle, Middle2) - renames.get(Top_m) should be (Some(Seq(Top.instOf("m", "Middle2")))) + renames.get(Top_m) should be(Some(Seq(Top.instOf("m", "Middle2")))) } it should "rename only the instance if instance and module in the path are renamed" in { val renames = RenameMap() renames.record(Middle, Middle2) renames.record(Top.instOf("m", "Middle"), Top.instOf("m2", "Middle")) - renames.get(Top_m) should be (Some(Seq(Top.instOf("m2", "Middle")))) + renames.get(Top_m) should be(Some(Seq(Top.instOf("m2", "Middle")))) } it should "rename targets if instance in the path are renamed" in { val renames = RenameMap() renames.record(Top.instOf("m", "Middle"), Top.instOf("m2", "Middle")) - renames.get(Top_m) should be (Some(Seq(Top.instOf("m2", "Middle")))) + renames.get(Top_m) should be(Some(Seq(Top.instOf("m2", "Middle")))) } it should "rename targets if instance and ofmodule in the path are renamed" in { val renames = RenameMap() val Top_m2 = Top.instOf("m2", "Middle2") renames.record(Top_m, Top_m2) - renames.get(Top_m) should be (Some(Seq(Top_m2))) + renames.get(Top_m) should be(Some(Seq(Top_m2))) } it should "properly do nothing if no remaps" in { val renames = RenameMap() - renames.get(Top_m_l_a) should be (None) + renames.get(Top_m_l_a) should be(None) } it should "properly rename if leaf is inlined" in { val renames = RenameMap() renames.record(Middle_l_a, Middle_la) - renames.get(Top_m_l_a) should be (Some(Seq(Top_m_la))) + renames.get(Top_m_l_a) should be(Some(Seq(Top_m_la))) } it should "properly rename if middle is inlined" in { val renames = RenameMap() renames.record(Top_m_l, Top.instOf("m_l", "Leaf")) - renames.get(Top_m_l_a) should be (Some(Seq(Top.instOf("m_l", "Leaf").ref("a")))) + renames.get(Top_m_l_a) should be(Some(Seq(Top.instOf("m_l", "Leaf").ref("a")))) } it should "properly rename if leaf and middle are inlined" in { @@ -151,18 +151,20 @@ class RenameMapSpec extends FirrtlFlatSpec { renames.record(Top_m_l_a, inlined) renames.record(Top_m_l, Nil) renames.record(Top_m, Nil) - renames.get(Top_m_l_a) should be (Some(Seq(inlined))) + renames.get(Top_m_l_a) should be(Some(Seq(inlined))) } it should "quickly rename a target with a long path" in { (0 until 50 by 10).foreach { endIdx => val renames = RenameMap() renames.record(TopCircuit.module("Y0"), TopCircuit.module("X0")) - val deepTarget = (0 until endIdx).foldLeft(Top: IsModule) { (t, idx) => - t.instOf("a", "A" + idx) - }.ref("ref") + val deepTarget = (0 until endIdx) + .foldLeft(Top: IsModule) { (t, idx) => + t.instOf("a", "A" + idx) + } + .ref("ref") val (millis, rename) = firrtl.Utils.time(renames.get(deepTarget)) - //rename should be(None) + //rename should be(None) } } @@ -171,7 +173,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle2 = cir.module("Middle2") renames.record(Middle, Middle2) renames.record(Middle.ref("l"), Middle.ref("lx")) - renames.get(Middle.ref("l")) should be (Some(Seq(Middle.ref("lx")))) + renames.get(Middle.ref("l")) should be(Some(Seq(Middle.ref("lx")))) } it should "rename with fields" in { @@ -181,7 +183,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle_i_f = Middle.ref("i").field("f") val renames = RenameMap() renames.record(Middle_o, Middle_i) - renames.get(Middle_o_f) should be (Some(Seq(Middle_i_f))) + renames.get(Middle_o_f) should be(Some(Seq(Middle_i_f))) } it should "rename instances with same ofModule" in { @@ -189,7 +191,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle_i = Middle.instOf("i", "O") val renames = RenameMap() renames.record(Middle_o, Middle_i) - renames.get(Middle.instOf("o", "O")) should be (Some(Seq(Middle.instOf("i", "O")))) + renames.get(Middle.instOf("o", "O")) should be(Some(Seq(Middle.instOf("i", "O")))) } it should "not treat references as instances targets" in { @@ -197,14 +199,14 @@ class RenameMapSpec extends FirrtlFlatSpec { val Middle_i = Middle.ref("i") val renames = RenameMap() renames.record(Middle_o, Middle_i) - renames.get(Middle.instOf("o", "O")) should be (None) + renames.get(Middle.instOf("o", "O")) should be(None) } it should "be able to rename weird stuff" in { // Renaming `from` to each of the `tos` at the same time should be ok case class BadRename(from: CompleteTarget, tos: Seq[CompleteTarget]) val badRenames = - Seq(//BadRename(foo, Seq(cir)), + Seq( //BadRename(foo, Seq(cir)), //BadRename(foo, Seq(modB)), //BadRename(modA, Seq(fooB)), //BadRename(modA, Seq(cir)), @@ -217,17 +219,17 @@ class RenameMapSpec extends FirrtlFlatSpec { val fromN = from val tosN = tos.mkString(", ") //it should s"error if a $fromN is renamed to $tosN" in { - val renames = RenameMap() - for (to <- tos) { - (from, to) match { - case (f: CircuitTarget, t: CircuitTarget) => renames.record(f, t) - case (f: IsMember, t: IsMember) => renames.record(f, t) - case _ => sys.error("Unexpected!") - } + val renames = RenameMap() + for (to <- tos) { + (from, to) match { + case (f: CircuitTarget, t: CircuitTarget) => renames.record(f, t) + case (f: IsMember, t: IsMember) => renames.record(f, t) + case _ => sys.error("Unexpected!") } - //a [FIRRTLException] shouldBe thrownBy { - renames.get(from) - //} + } + //a [FIRRTLException] shouldBe thrownBy { + renames.get(from) + //} //} } } @@ -247,8 +249,8 @@ class RenameMapSpec extends FirrtlFlatSpec { val top = CircuitTarget("Top") renames.record(top.module("A"), top.module("B")) renames.record(top.module("B"), top.module("A")) - renames.get(top.module("A")) should be (Some(Seq(top.module("B")))) - renames.get(top.module("B")) should be (Some(Seq(top.module("A")))) + renames.get(top.module("A")) should be(Some(Seq(top.module("B")))) + renames.get(top.module("B")) should be(Some(Seq(top.module("A")))) } it should "error if a reference is renamed to a module and vice versa" in { @@ -256,10 +258,10 @@ class RenameMapSpec extends FirrtlFlatSpec { val top = CircuitTarget("Top") renames.record(top.module("A").ref("ref"), top.module("B")) renames.record(top.module("C"), top.module("D").ref("ref")) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("C")) } - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("A").ref("ref").field("field")) } renames.get(top.module("A").instOf("ref", "R")) should be(None) @@ -270,7 +272,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val top = CircuitTarget("Top") renames.record(top.module("C"), top.module("D").ref("x")) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("A").instOf("c", "C")) } } @@ -281,7 +283,7 @@ class RenameMapSpec extends FirrtlFlatSpec { renames.record(top.module("E").instOf("f", "F"), top.module("E").ref("g")) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(top.module("E").instOf("f", "F").ref("g")) } } @@ -403,7 +405,7 @@ class RenameMapSpec extends FirrtlFlatSpec { .ref("ref") .field("f1") .field("f2") - val to2 = modA + val to2 = modA .instOf("b", "B") .instOf("c", "C") .ref("ref") @@ -417,7 +419,7 @@ class RenameMapSpec extends FirrtlFlatSpec { .instOf("c", "C") .ref("ref") .field("f1") - val to3 = modB + val to3 = modB .instOf("c", "C") .ref("ref") .field("f11") @@ -426,7 +428,7 @@ class RenameMapSpec extends FirrtlFlatSpec { // to: ~Top|C>refref // renamed last because it has no path val from4 = modC.ref("ref") - val to4 = modC.ref("refref") + val to4 = modC.ref("refref") val renames1 = RenameMap() renames1.record(from1, to1) @@ -435,14 +437,17 @@ class RenameMapSpec extends FirrtlFlatSpec { renames1.record(from4, to4) renames1.get(from1) should be { - Some(Seq(modA - .instOf("b", "B") - .instOf("c", "C") - .ref("ref") - .field("f1") - .field("f2") - .field("f33") - )) + Some( + Seq( + modA + .instOf("b", "B") + .instOf("c", "C") + .ref("ref") + .field("f1") + .field("f2") + .field("f33") + ) + ) } val renames2 = RenameMap() @@ -451,14 +456,17 @@ class RenameMapSpec extends FirrtlFlatSpec { renames2.record(from4, to4) renames2.get(from1) should be { - Some(Seq(modA - .instOf("b", "B") - .instOf("c", "C") - .ref("ref") - .field("f1") - .field("f22") - .field("f3") - )) + Some( + Seq( + modA + .instOf("b", "B") + .instOf("c", "C") + .ref("ref") + .field("f1") + .field("f22") + .field("f3") + ) + ) } val renames3 = RenameMap() @@ -466,14 +474,17 @@ class RenameMapSpec extends FirrtlFlatSpec { renames3.record(from4, to4) renames3.get(from1) should be { - Some(Seq(modA - .instOf("b", "B") - .instOf("c", "C") - .ref("ref") - .field("f11") - .field("f2") - .field("f3") - )) + Some( + Seq( + modA + .instOf("b", "B") + .instOf("c", "C") + .ref("ref") + .field("f11") + .field("f2") + .field("f3") + ) + ) } } @@ -498,8 +509,18 @@ class RenameMapSpec extends FirrtlFlatSpec { val to = cir.module("D").instOf("e", "E").instOf("f", "F").ref("foo").field("foo") renames.record(from, to) renames.get(cir.module("A").instOf("b", "B").instOf("c", "C").ref("foo").field("bar")) should be { - Some(Seq(cir.module("A").instOf("b", "B").instOf("c", "D") - .instOf("e", "E").instOf("f", "F").ref("foo").field("foo"))) + Some( + Seq( + cir + .module("A") + .instOf("b", "B") + .instOf("c", "D") + .instOf("e", "E") + .instOf("f", "F") + .ref("foo") + .field("foo") + ) + ) } } @@ -509,7 +530,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val from = top.instOf("a", "A") val to = top.ref("b") renames.record(from, to) - a [IllegalRenameException] shouldBe thrownBy { + a[IllegalRenameException] shouldBe thrownBy { renames.get(from) } } @@ -520,7 +541,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val from = top.ref("a") val to = top.ref("b") renames.record(from, to) - renames.get(top.instOf("a", "Foo")) should be (None) + renames.get(top.instOf("a", "Foo")) should be(None) } it should "correctly chain renames together" in { @@ -651,8 +672,8 @@ class RenameMapSpec extends FirrtlFlatSpec { val dupMod1 = top.module("A1") val dupMod2 = top.module("A2") - val relPath1 = dupMod1.addHierarchy("Foo", "a")//top.module("Foo").instOf("a", "A1") - val relPath2 = dupMod2.addHierarchy("Foo", "a")//top.module("Foo").instOf("a", "A2") + val relPath1 = dupMod1.addHierarchy("Foo", "a") //top.module("Foo").instOf("a", "A1") + val relPath2 = dupMod2.addHierarchy("Foo", "a") //top.module("Foo").instOf("a", "A2") val absPath1 = relPath1.addHierarchy("Top", "foo") val absPath2 = relPath2.addHierarchy("Top", "foo") @@ -766,7 +787,7 @@ class RenameMapSpec extends FirrtlFlatSpec { r.record(foo, foo) r.get(foo) should not be (empty) - r.get(foo).get should contain allOf (foo, bar) + (r.get(foo).get should contain).allOf(foo, bar) } it should "not record the same rename multiple times" in { @@ -807,7 +828,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val r = RenameMap() r.delete(Mod) - r.get(foo) should be (Some(Nil)) + r.get(foo) should be(Some(Nil)) } it should "rename an instance if it has been renamed" in { @@ -818,8 +839,8 @@ class RenameMapSpec extends FirrtlFlatSpec { val i = top.instOf("i", "child") val i_ = top.instOf("i_", "child") r.record(i, i_) - r.get(i) should be (Some(Seq(i_))) - r.get(i.ref("a")) should be (Some(Seq(i_.ref("a")))) + r.get(i) should be(Some(Seq(i_))) + r.get(i.ref("a")) should be(Some(Seq(i_.ref("a")))) } it should "rename references to an instance's ports if the ports of the module have been renamed" in { @@ -830,7 +851,7 @@ class RenameMapSpec extends FirrtlFlatSpec { val r = RenameMap() r.record(child.ref("a"), Seq(child.ref("a_0"), child.ref("a_1"))) val i = top.instOf("i", "child") - r.get(i.ref("a")) should be (Some(Seq(i.ref("a_0"), i.ref("a_1")))) + r.get(i.ref("a")) should be(Some(Seq(i.ref("a_0"), i.ref("a_1")))) } it should "rename references to renamed instance's ports if the ports of the module have been renamed" in { @@ -848,6 +869,6 @@ class RenameMapSpec extends FirrtlFlatSpec { // The port and instance renames must be *explicitly* chained! val r = portRenames.andThen(instanceRenames) - r.get(i.ref("a")) should be (Some(Seq(i_.ref("a_0"), i_.ref("a_1")))) + r.get(i.ref("a")) should be(Some(Seq(i_.ref("a_0"), i_.ref("a_1")))) } } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index cd2fdb05..17f4dcfd 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -25,7 +25,8 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) + def transforms = + Seq(new ConstantPropagation, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) } ) @@ -35,7 +36,12 @@ class ReplSeqMemSpec extends SimpleTransformSpec { // Verify that this does not throw an exception val fromConf = MemConf.fromString(text) // Verify the mems in the conf are the same as the expected ones - require(Set(fromConf: _*) == mems, "Parsed conf set:\n {\n " + fromConf.mkString(" ") + " }\n must be the same as reference conf set: \n {\n " + mems.toSeq.mkString(" ") + " }\n") + require( + Set(fromConf: _*) == mems, + "Parsed conf set:\n {\n " + fromConf.mkString( + " " + ) + " }\n must be the same as reference conf set: \n {\n " + mems.toSeq.mkString(" ") + " }\n" + ) } "ReplSeqMem" should "generate blackbox wrappers for mems of bundle type" in { @@ -63,7 +69,7 @@ circuit Top : MemConf("entries_info_ext", 24, 30, Map(WritePort -> 1, ReadPort -> 1), None) ) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -88,7 +94,7 @@ circuit Top : """.stripMargin val mems = Set(MemConf("mem_ext", 32, 64, Map(MaskedWritePort -> 1), Some(64))) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -116,7 +122,7 @@ circuit CustomMemory : """.stripMargin val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -144,7 +150,7 @@ circuit CustomMemory : """.stripMargin val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl parse(res.getEmittedCircuit.value) @@ -153,8 +159,8 @@ circuit CustomMemory : (new java.io.File(confLoc)).delete() } - "ReplSeqMem Utility -- getConnectOrigin" should - "determine connect origin across nodes/PrimOps even if ConstProp isn't performed" in { + "ReplSeqMem Utility -- getConnectOrigin" should + "determine connect origin across nodes/PrimOps even if ConstProp isn't performed" in { def checkConnectOrigin(hurdle: String, origin: String) = { val input = s""" circuit Top : @@ -172,7 +178,7 @@ circuit Top : val circuit = InferTypes.run(ToWorkingIR.run(parse(input))) val m = circuit.modules.head.asInstanceOf[ir.Module] val connects = AnalysisUtils.getConnects(m) - val calculatedOrigin = AnalysisUtils.getOrigin(connects, "f").serialize + val calculatedOrigin = AnalysisUtils.getOrigin(connects, "f").serialize require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !") } @@ -195,7 +201,7 @@ circuit Top : "validif(a, b)" -> "b" ) - tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) } + tests.foreach { case (hurdle, origin) => checkConnectOrigin(hurdle, origin) } } @@ -226,16 +232,17 @@ circuit CustomMemory : ) val confLoc = "ReplSeqMemTests.confTEMP" val annos = Seq( - ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("CustomMemory",CircuitName("CustomMemory"))))) + ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), + NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("CustomMemory", CircuitName("CustomMemory")))) + ) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } - numExtMods should be (2) + numExtMods should be(2) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -272,16 +279,17 @@ circuit CustomMemory : ) val confLoc = "ReplSeqMemTests.confTEMP" val annos = Seq( - ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - NoDedupMemAnnotation(ComponentName("mem_1", ModuleName("CustomMemory",CircuitName("CustomMemory"))))) + ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), + NoDedupMemAnnotation(ComponentName("mem_1", ModuleName("CustomMemory", CircuitName("CustomMemory")))) + ) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } - numExtMods should be (2) + numExtMods should be(2) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -329,20 +337,21 @@ circuit CustomMemory : ) val confLoc = "ReplSeqMemTests.confTEMP" val annos = Seq( - ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("ChildMemory",CircuitName("CustomMemory"))))) + ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), + NoDedupMemAnnotation(ComponentName("mem_0", ModuleName("ChildMemory", CircuitName("CustomMemory")))) + ) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } // Note that there are 3 identical SeqMems in this test // If the NoDedupMemAnnotation were ignored, we'd end up with just 1 ExtModule // If the NoDedupMemAnnotation were handled incorrectly as it was prior to this test, there // would be 3 ExtModules - numExtMods should be (2) + numExtMods should be(2) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -371,12 +380,12 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl val circuit = parse(res.getEmittedCircuit.value) val numExtMods = circuit.modules.count { - case e: ExtModule => true + case e: ExtModule => true case _ => false } require(numExtMods == 1) @@ -400,9 +409,9 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_ext", 1024, 16, Map(WritePort -> 1, ReadPort -> 1), None)) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) - res.getEmittedCircuit.value shouldNot include ("mask") + res.getEmittedCircuit.value shouldNot include("mask") // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -428,11 +437,11 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedWritePort -> 1, ReadPort -> 1), Some(8))) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask - res should containLine ("mem.W0_mask_0 <= validif(io_en, io_mask_0)") - res should containLine ("mem.W0_mask_1 <= validif(io_en, io_mask_1)") + res should containLine("mem.W0_mask_0 <= validif(io_en, io_mask_0)") + res should containLine("mem.W0_mask_1 <= validif(io_en, io_mask_1)") // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -462,12 +471,11 @@ circuit CustomMemory : """ val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedReadWritePort -> 1), Some(8))) val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - InferReadWriteAnnotation) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), InferReadWriteAnnotation) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask - res should containLine ("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)") - res should containLine ("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)") + res should containLine("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)") + res should containLine("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)") // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() @@ -487,15 +495,14 @@ circuit NoMemsHere : """ val mems = Set.empty[MemConf] val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), - InferReadWriteAnnotation) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:" + confLoc), InferReadWriteAnnotation) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check the emitted conf checkMemConf(confLoc, mems) (new java.io.File(confLoc)).delete() } - "ReplSeqMem" should "throw an exception when encountering masks with variable granularity" in { + "ReplSeqMem" should "throw an exception when encountering masks with variable granularity" in { val input = """ circuit Top : module Top : @@ -518,10 +525,9 @@ circuit Top : """.stripMargin intercept[ReplaceMemMacros.UnsupportedBlackboxMemoryException] { val confLoc = "ReplSeqMemTests.confTEMP" - val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc)) + val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:" + confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) } } } - diff --git a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala index fcf36876..588b7c39 100644 --- a/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala +++ b/src/test/scala/firrtlTests/ReplaceAccessesSpec.scala @@ -7,17 +7,14 @@ import firrtl.passes._ import firrtl.testutils._ class ReplaceAccessesSpec extends FirrtlFlatSpec { - val transforms = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - ReplaceAccesses) + val transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ReplaceAccesses) protected def exec(input: String) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit.serialize + transforms + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit + .serialize } } @@ -40,7 +37,7 @@ class ReplaceAccessesMultiDim extends ReplaceAccessesSpec { reset => (UInt<1>(0), r_vec) out <= r_vec[2][1] """ - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } "ReplacesAccesses" should "NOT generate out-of-bounds indices" in { @@ -61,6 +58,6 @@ class ReplaceAccessesMultiDim extends ReplaceAccessesSpec { reset => (UInt<1>(0), r_vec) out <= r_vec[1][UInt<3>(8)] """ - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } } diff --git a/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala b/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala index 05a5fe29..ea01ca00 100644 --- a/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala +++ b/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala @@ -11,50 +11,46 @@ class ReplaceTruncatingArithmeticSpec extends FirrtlFlatSpec { (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) def compileBody(body: String) = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } "Truncting addition" should "be inferred and emitted in Verilog" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |z <= tail(add(x, y), 1)""".stripMargin - ) - result should containLine (s"assign z = x + y;") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |z <= tail(add(x, y), 1)""".stripMargin) + result should containLine(s"assign z = x + y;") } it should "be inferred and emitted in Verilog even with an intermediate node" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |node n = add(x, y) - |z <= tail(n, 1)""".stripMargin - ) - result should containLine (s"assign z = x + y;") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |node n = add(x, y) + |z <= tail(n, 1)""".stripMargin) + result should containLine(s"assign z = x + y;") } "Truncting subtraction" should "be inferred and emitted in Verilog" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<8> - |z <= tail(sub(x, y), 1)""".stripMargin - ) - result should containLine (s"assign z = x - y;") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |z <= tail(sub(x, y), 1)""".stripMargin) + result should containLine(s"assign z = x - y;") } "Tailing more than 1" should "not result in a truncating operator" in { val result = compileBody(s""" - |input x : UInt<8> - |input y : UInt<8> - |output z : UInt<7> - |node n = sub(x, y) - |z <= tail(n, 2)""".stripMargin - ) - result should containLine (s"wire [8:0] n = x - y;") - result should containLine (s"assign z = n[6:0];") + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<7> + |node n = sub(x, y) + |z <= tail(n, 2)""".stripMargin) + result should containLine(s"wire [8:0] n = x - y;") + result should containLine(s"assign z = n[6:0];") } } diff --git a/src/test/scala/firrtlTests/SimplifyMemsSpec.scala b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala index ec947ecf..c7d04d46 100644 --- a/src/test/scala/firrtlTests/SimplifyMemsSpec.scala +++ b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala @@ -12,73 +12,73 @@ class SimplifyMemsSpec extends ConstantPropagationSpec { "SimplifyMems" should "lower aggregate memories" in { val input = - """circuit Test : - | module Test : - | input clock : Clock - | input wen : UInt<1> - | input wdata : { a : UInt<8>, b : UInt<8> } - | output rdata : { a : UInt<8>, b : UInt<8> } - | mem m : - | data-type => { a : UInt<8>, b : UInt<8>} - | depth => 32 - | read-latency => 1 - | write-latency => 1 - | reader => read - | writer => write - | m.read.clk <= clock - | m.read.en <= UInt<1>(1) - | m.read.addr is invalid - | rdata <= m.read.data - | m.write.clk <= clock - | m.write.en <= wen - | m.write.mask.a <= UInt<1>(1) - | m.write.mask.b <= UInt<1>(1) - | m.write.addr is invalid - | m.write.data <= wdata + """circuit Test : + | module Test : + | input clock : Clock + | input wen : UInt<1> + | input wdata : { a : UInt<8>, b : UInt<8> } + | output rdata : { a : UInt<8>, b : UInt<8> } + | mem m : + | data-type => { a : UInt<8>, b : UInt<8>} + | depth => 32 + | read-latency => 1 + | write-latency => 1 + | reader => read + | writer => write + | m.read.clk <= clock + | m.read.en <= UInt<1>(1) + | m.read.addr is invalid + | rdata <= m.read.data + | m.write.clk <= clock + | m.write.en <= wen + | m.write.mask.a <= UInt<1>(1) + | m.write.mask.b <= UInt<1>(1) + | m.write.addr is invalid + | m.write.data <= wdata """.stripMargin val check = - """circuit Test : - | module Test : - | input clock : Clock - | input wen : UInt<1> - | input wdata : { a : UInt<8>, b : UInt<8>} - | output rdata : { a : UInt<8>, b : UInt<8>} - | - | wire m : { flip read : { addr : UInt<5>, en : UInt<1>, clk : Clock, flip data : { a : UInt<8>, b : UInt<8>}}, flip write : { addr : UInt<5>, en : UInt<1>, clk : Clock, data : { a : UInt<8>, b : UInt<8>}, mask : { a : UInt<1>, b : UInt<1>}}} - | mem m_flattened : - | data-type => UInt<16> - | depth => 32 - | read-latency => 1 - | write-latency => 1 - | reader => read - | writer => write - | read-under-write => undefined - | m_flattened.read.addr <= m.read.addr - | m_flattened.read.en <= m.read.en - | m_flattened.read.clk <= m.read.clk - | m.read.data.b <= asUInt(bits(m_flattened.read.data, 7, 0)) - | m.read.data.a <= asUInt(bits(m_flattened.read.data, 15, 8)) - | m_flattened.write.addr <= m.write.addr - | m_flattened.write.en <= m.write.en - | m_flattened.write.clk <= m.write.clk - | m_flattened.write.data <= cat(asUInt(m.write.data.a), asUInt(m.write.data.b)) - | m_flattened.write.mask <= UInt<1>("h1") - | rdata.a <= m.read.data.a - | rdata.b <= m.read.data.b - | m.read.addr is invalid - | m.read.en <= UInt<1>("h1") - | m.read.clk <= clock - | m.write.addr is invalid - | m.write.en <= wen - | m.write.clk <= clock - | m.write.data.a <= wdata.a - | m.write.data.b <= wdata.b - | m.write.mask.a <= UInt<1>("h1") - | m.write.mask.b <= UInt<1>("h1") + """circuit Test : + | module Test : + | input clock : Clock + | input wen : UInt<1> + | input wdata : { a : UInt<8>, b : UInt<8>} + | output rdata : { a : UInt<8>, b : UInt<8>} + | + | wire m : { flip read : { addr : UInt<5>, en : UInt<1>, clk : Clock, flip data : { a : UInt<8>, b : UInt<8>}}, flip write : { addr : UInt<5>, en : UInt<1>, clk : Clock, data : { a : UInt<8>, b : UInt<8>}, mask : { a : UInt<1>, b : UInt<1>}}} + | mem m_flattened : + | data-type => UInt<16> + | depth => 32 + | read-latency => 1 + | write-latency => 1 + | reader => read + | writer => write + | read-under-write => undefined + | m_flattened.read.addr <= m.read.addr + | m_flattened.read.en <= m.read.en + | m_flattened.read.clk <= m.read.clk + | m.read.data.b <= asUInt(bits(m_flattened.read.data, 7, 0)) + | m.read.data.a <= asUInt(bits(m_flattened.read.data, 15, 8)) + | m_flattened.write.addr <= m.write.addr + | m_flattened.write.en <= m.write.en + | m_flattened.write.clk <= m.write.clk + | m_flattened.write.data <= cat(asUInt(m.write.data.a), asUInt(m.write.data.b)) + | m_flattened.write.mask <= UInt<1>("h1") + | rdata.a <= m.read.data.a + | rdata.b <= m.read.data.b + | m.read.addr is invalid + | m.read.en <= UInt<1>("h1") + | m.read.clk <= clock + | m.write.addr is invalid + | m.write.en <= wen + | m.write.clk <= clock + | m.write.data.a <= wdata.a + | m.write.data.b <= wdata.b + | m.write.mask.a <= UInt<1>("h1") + | m.write.mask.b <= UInt<1>("h1") """.stripMargin - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } } diff --git a/src/test/scala/firrtlTests/StringSpec.scala b/src/test/scala/firrtlTests/StringSpec.scala index 30535466..fc2fa486 100644 --- a/src/test/scala/firrtlTests/StringSpec.scala +++ b/src/test/scala/firrtlTests/StringSpec.scala @@ -21,7 +21,7 @@ class PrintfSpec extends FirrtlPropSpec { copyResourceToFile(cppHarnessResourceName, harness) verilogToCpp(prefix, testDir, Seq(), harness) #&& - cppToExe(prefix, testDir) ! loggingProcessLogger + cppToExe(prefix, testDir) ! loggingProcessLogger // Check for correct Printf: // Count up from 0, match decimal, hex, and binary @@ -31,7 +31,7 @@ class PrintfSpec extends FirrtlPropSpec { var expected = 0 var error = false val ret = Process(s"./V${prefix}", testDir) ! - ProcessLogger( line => { + ProcessLogger(line => { line match { case regex(dec, hex, bin) => { if (!done) { @@ -57,7 +57,7 @@ class StringSpec extends FirrtlPropSpec { // Whitelist is [0x20 - 0x7e] val whitelist = """ !\"#$%&\''()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ""" + - """[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~""" + """[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~""" property(s"Character whitelist should be supported: [$whitelist] ") { val lit = StringLit.unescape(whitelist) @@ -102,7 +102,7 @@ class StringSpec extends FirrtlPropSpec { val legalFormats = "HhDdOoBbCcLlVvMmSsTtUuZz%".toSet def isValidVerilogFormat(str: String): Boolean = str.toSeq.sliding(2).forall { case Seq('%', char) if legalFormats contains char => true - case _ => true + case _ => true } // Generators for legal Firrtl format strings @@ -112,8 +112,8 @@ class StringSpec extends FirrtlPropSpec { val genFragment = Gen.frequency((10, genChar), (1, genFormat), (1, genEsc)).map(_.mkString) val genString = Gen.listOf[String](genFragment).map(_.mkString) - property ("Firrtl Format Strings with Unicode chars should emit as legal Verilog Strings") { - forAll (genString) { str => + property("Firrtl Format Strings with Unicode chars should emit as legal Verilog Strings") { + forAll(genString) { str => val verilogStr = StringLit(str).verilogFormat.verilogEscape assert(isValidVerilogString(verilogStr)) assert(isValidVerilogFormat(verilogStr)) diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala index 074da256..19ae75fc 100644 --- a/src/test/scala/firrtlTests/UniquifySpec.scala +++ b/src/test/scala/firrtlTests/UniquifySpec.scala @@ -21,24 +21,29 @@ class UniquifySpec extends FirrtlFlatSpec { Uniquify ) - private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Seq.empty, Seq.empty) - private def executeTest(input: String, expected: Seq[String], - inputAnnos: Seq[Annotation], expectedAnnos: Seq[Annotation]): Unit = { + private def executeTest(input: String, expected: Seq[String]): Unit = + executeTest(input, expected, Seq.empty, Seq.empty) + private def executeTest( + input: String, + expected: Seq[String], + inputAnnos: Seq[Annotation], + expectedAnnos: Seq[Annotation] + ): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) val result = transforms.foldLeft(CircuitState(circuit, UnknownForm, inputAnnos)) { (c: CircuitState, p: Transform) => p.runTransform(c) } val c = result.circuit - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } result.annotations.toSeq should equal(expectedAnnos) } - behavior of "Uniquify" + behavior.of("Uniquify") it should "rename colliding ports" in { val input = @@ -51,13 +56,22 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "input a__ : { flip b : UInt<1>, c_ : { d : UInt<2>, flip e : UInt<3>}[2], c_1_e : UInt<4>}[2]", "output a_0_c_ : UInt<5>", - "output a__0 : UInt<6>") map normalized - - val inputAnnos = Seq(DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("b")))), - DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("c"), Index(0), Field("e"))))) - - val expectedAnnos = Seq(DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("b")))), - DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("c_"), Index(0), Field("e"))))) + "output a__0 : UInt<6>" + ).map(normalized) + + val inputAnnos = Seq( + DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("b")))), + DontTouchAnnotation( + ReferenceTarget("Test", "Test", Seq.empty, "a", Seq(Index(0), Field("c"), Index(0), Field("e"))) + ) + ) + + val expectedAnnos = Seq( + DontTouchAnnotation(ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("b")))), + DontTouchAnnotation( + ReferenceTarget("Test", "Test", Seq.empty, "a__", Seq(Index(0), Field("c_"), Index(0), Field("e"))) + ) + ) executeTest(input, expected, inputAnnos, expectedAnnos) } @@ -74,7 +88,8 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "reg a__ : { b : UInt<1>, c_ : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2], clock with :", "reg a_0_c_ : UInt<5>, clock with :", - "reg a__0 : UInt<6>, clock with :") map normalized + "reg a__0 : UInt<6>, clock with :" + ).map(normalized) executeTest(input, expected) } @@ -89,12 +104,11 @@ class UniquifySpec extends FirrtlFlatSpec { | node a_0_c_ = a[0].b | node a__0 = a[1].c[0].d """.stripMargin - val expected = Seq("node a__ = x") map normalized + val expected = Seq("node a__ = x").map(normalized) executeTest(input, expected) } - it should "rename DefRegister expressions: clock, reset, and init" in { val input = """circuit Test : @@ -111,7 +125,7 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "reg foo : UInt<4>, clock_[1] with :", "reset => (reset_.a, init_[3].b_[1].d)" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -126,7 +140,7 @@ class UniquifySpec extends FirrtlFlatSpec { val expected = Seq( "input data : { a : UInt<4>, b : UInt<4>}[2]", "node data_0_a_ = data[0].a" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -141,9 +155,7 @@ class UniquifySpec extends FirrtlFlatSpec { | node foo = data.a | node bar = data.b[1] """.stripMargin - val expected = Seq( - "node foo = data__.a", - "node bar = data__.b[1]") map normalized + val expected = Seq("node foo = data__.a", "node bar = data__.b[1]").map(normalized) executeTest(input, expected) } @@ -158,25 +170,22 @@ class UniquifySpec extends FirrtlFlatSpec { | a_0_b <= a[0].b | a[0].c <- a__0_c_ """.stripMargin - val expected = Seq( - "a_0_b <= a__[0].b", - "a__[0].c_ <- a__0_c_") map normalized + val expected = Seq("a_0_b <= a__[0].b", "a__[0].c_ <- a__0_c_").map(normalized) executeTest(input, expected) } it should "rename SubAccesses" in { val input = - """circuit Test : - | module Test : - | input a : { b : UInt<1>, c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2] - | output a_0_b : UInt<2> - | input i : UInt<1>[2] - | output i_0 : UInt<1> - | a_0_b <= a.c[i[1]].d + """circuit Test : + | module Test : + | input a : { b : UInt<1>, c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2] + | output a_0_b : UInt<2> + | input i : UInt<1>[2] + | output i_0 : UInt<1> + | a_0_b <= a.c[i[1]].d """.stripMargin - val expected = Seq( - "a_0_b <= a_.c_[i_[1]].d") map normalized + val expected = Seq("a_0_b <= a_.c_[i_[1]].d").map(normalized) executeTest(input, expected) } @@ -192,7 +201,7 @@ class UniquifySpec extends FirrtlFlatSpec { """.stripMargin val expected = Seq( "a_0_b <= mux(a__[UInt<1>(\"h0\")].c_1_e, or(a__[or(a__[0].b, a__[1].b)].b, xorr(a__[0].c_1_e)), orr(cat(a__0_c_[0].e, a__[1].c_1_e)))" - ) map normalized + ).map(normalized) executeTest(input, expected) } @@ -220,10 +229,7 @@ class UniquifySpec extends FirrtlFlatSpec { | mem.write.en <= UInt(0) | mem.write.clk <= clock """.stripMargin - val expected = Seq( - "mem mem_ :", - "node mem_0_b = mem_.read.data[0].b", - "mem_.read.addr is invalid") map normalized + val expected = Seq("mem mem_ :", "node mem_0_b = mem_.read.data[0].b", "mem_.read.addr is invalid").map(normalized) executeTest(input, expected) } @@ -251,33 +257,29 @@ class UniquifySpec extends FirrtlFlatSpec { | mem.write.en <= UInt(0) | mem.write.clk <= clock """.stripMargin - val expected = Seq( - "data-type => { a : UInt<8>, b_ : UInt<8>[2], b_0 : UInt<8>}", - "node x = mem.read.data.b_[0]") map normalized + val expected = + Seq("data-type => { a : UInt<8>, b_ : UInt<8>[2], b_0 : UInt<8>}", "node x = mem.read.data.b_[0]").map(normalized) executeTest(input, expected) } it should "rename instances and their ports" in { val input = - """circuit Test : - | module Other : - | input a : { b : UInt<4>, c : UInt<4> } - | output a_b : UInt<4> - | a_b <= a.b - | - | module Test : - | node x = UInt(6) - | inst mod of Other - | mod.a.b <= x - | mod.a.c <= x - | node mod_a_b = mod.a_b + """circuit Test : + | module Other : + | input a : { b : UInt<4>, c : UInt<4> } + | output a_b : UInt<4> + | a_b <= a.b + | + | module Test : + | node x = UInt(6) + | inst mod of Other + | mod.a.b <= x + | mod.a.c <= x + | node mod_a_b = mod.a_b """.stripMargin - val expected = Seq( - "inst mod_ of Other", - "mod_.a_.b <= x", - "mod_.a_.c <= x", - "node mod_a_b = mod_.a_b") map normalized + val expected = + Seq("inst mod_ of Other", "mod_.a_.b <= x", "mod_.a_.c <= x", "node mod_a_b = mod_.a_b").map(normalized) executeTest(input, expected) } @@ -296,7 +298,7 @@ class UniquifySpec extends FirrtlFlatSpec { // Run the "quick" test three times and choose the longest time as the basis. val nCalibrationRuns = 3 def mkType(i: Int): String = { - if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + if (i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" } val timesMs = ( for (depth <- (List.fill(nCalibrationRuns)(1) :+ depth)) yield { @@ -308,12 +310,12 @@ class UniquifySpec extends FirrtlFlatSpec { |""".stripMargin val (ms, _) = Utils.time(compileToVerilog(input)) ms - } + } ).toArray // The baseMs will be the maximum of the first calibration runs val baseMs = timesMs.slice(0, nCalibrationRuns - 1).max val renameMs = timesMs(nCalibrationRuns) if (TestOptions.accurateTiming) - renameMs shouldBe < (baseMs * threshold) + renameMs shouldBe <(baseMs * threshold) } } diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 288bf336..8f128274 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -12,50 +12,41 @@ import FirrtlCheckers._ class UnitTests extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], transforms: Seq[Transform]) = { - val lines = execute(input, transforms).circuit.serialize.split("\n") map normalized + val lines = execute(input, transforms).circuit.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } private def executeTest(input: String, expected: String, transforms: Seq[Transform]) = { - execute(input, transforms).circuit should be (parse(expected)) + execute(input, transforms).circuit should be(parse(expected)) } def execute(input: String, transforms: Seq[Transform]): CircuitState = { - val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit + val c = transforms + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit CircuitState(c, UnknownForm, Seq(), None) } "Pull muxes" should "not be exponential in runtime" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - PullMuxes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, PullMuxes) val input = """circuit Unit : | module Unit : | input _2: UInt<1> | output x: UInt<32> | x <= cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat(_2, cat( _2, cat(_2, cat(_2, cat(_2, _2)))))))))))))))))))))))))))))))""".stripMargin - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } "Connecting bundles of different types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -63,96 +54,78 @@ class UnitTests extends FirrtlFlatSpec { | output x: {a : UInt<1>, b : UInt<1>} | x <= y""".stripMargin intercept[CheckTypes.InvalidConnect] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Initializing a register with a different type" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes) val input = - """circuit Unit : - | module Unit : - | input clock : Clock - | input reset : UInt<1> - | wire x : { valid : UInt<1> } - | reg y : { valid : UInt<1>, bits : UInt<3> }, clock with : - | reset => (reset, x)""".stripMargin + """circuit Unit : + | module Unit : + | input clock : Clock + | input reset : UInt<1> + | wire x : { valid : UInt<1> } + | reg y : { valid : UInt<1>, bits : UInt<3> }, clock with : + | reset => (reset, x)""".stripMargin intercept[CheckTypes.InvalidRegInit] { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } } } "Partial connection two bundle types whose relative flips don't match but leaf node directions do" should "connect correctly" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - ExpandConnects) + val passes = Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, ExpandConnects) val input = - """circuit Unit : - | module Unit : - | wire x : { flip a: { b: UInt<32> } } - | wire y : { a: { flip b: UInt<32> } } - | x <- y""".stripMargin + """circuit Unit : + | module Unit : + | wire x : { flip a: { b: UInt<32> } } + | wire y : { a: { flip b: UInt<32> } } + | x <- y""".stripMargin val check = - """circuit Unit : - | module Unit : - | wire x : { flip a: { b: UInt<32> } } - | wire y : { a: { flip b: UInt<32> } } - | y.a.b <= x.a.b""".stripMargin - val c_result = passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) + """circuit Unit : + | module Unit : + | wire x : { flip a: { b: UInt<32> } } + | wire y : { a: { flip b: UInt<32> } } + | y.a.b <= x.a.b""".stripMargin + val c_result = passes.foldLeft(parse(input)) { (c: Circuit, p: Pass) => + p.run(c) } val writer = new StringWriter() (new HighFirrtlEmitter).emit(CircuitState(c_result, HighForm), writer) - (parse(writer.toString())) should be (parse(check)) + (parse(writer.toString())) should be(parse(check)) } val splitExpTestCode = - """ - |circuit Unit : - | module Unit : - | input a : UInt<1> - | input b : UInt<2> - | input c : UInt<2> - | output out : UInt<1> - | out <= bits(mux(a, b, c), 0, 0) - |""".stripMargin + """ + |circuit Unit : + | module Unit : + | input a : UInt<1> + | input b : UInt<2> + | input c : UInt<2> + | output out : UInt<1> + | out <= bits(mux(a, b, c), 0, 0) + |""".stripMargin "Emitting a nested expression" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - InferTypes, - ResolveKinds) + val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds) intercept[PassException] { val c = Parser.parse(splitExpTestCode.split("\n").toIterator) - val c2 = passes.foldLeft(c)((c, p) => p run c) + val c2 = passes.foldLeft(c)((c, p) => p.run(c)) val writer = new StringWriter() (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer) } } "After splitting, emitting a nested expression" should "compile" in { - val passes = Seq( - ToWorkingIR, - SplitExpressions, - InferTypes) + val passes = Seq(ToWorkingIR, SplitExpressions, InferTypes) val c = Parser.parse(splitExpTestCode.split("\n").toIterator) - val c2 = passes.foldLeft(c)((c, p) => p run c) - val writer = new StringWriter() - (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer) + val c2 = passes.foldLeft(c)((c, p) => p.run(c)) + val writer = new StringWriter() + (new VerilogEmitter).emit(CircuitState(c2, LowForm), writer) } "Simple compound expressions" should "be split" in { @@ -166,12 +139,12 @@ class UnitTests extends FirrtlFlatSpec { ) val input = """circuit Top : - | module Top : - | input a : UInt<32> - | input b : UInt<32> - | input d : UInt<32> - | output c : UInt<1> - | c <= geq(add(a, b),d)""".stripMargin + | module Top : + | input a : UInt<32> + | input b : UInt<32> + | input d : UInt<32> + | output c : UInt<1> + | c <= geq(add(a, b),d)""".stripMargin val check = Seq( "node _GEN_0 = add(a, b)", "c <= geq(_GEN_0, d)" @@ -190,14 +163,14 @@ class UnitTests extends FirrtlFlatSpec { ) val input = """circuit Top : - | module Top : - | input a : UInt<32> - | input b : UInt<20> - | input pred : UInt<1> - | output c : UInt<32> - | c <= mux(pred,a,b)""".stripMargin - val check = Seq("c <= mux(pred, a, pad(b, 32))") - executeTest(input, check, passes) + | module Top : + | input a : UInt<32> + | input b : UInt<20> + | input pred : UInt<1> + | output c : UInt<32> + | c <= mux(pred,a,b)""".stripMargin + val check = Seq("c <= mux(pred, a, pad(b, 32))") + executeTest(input, check, passes) } "Indexes into sub-accesses" should "be dealt with" in { @@ -214,40 +187,34 @@ class UnitTests extends FirrtlFlatSpec { ) val input = """circuit AssignViaDeref : - | module AssignViaDeref : - | input clock : Clock - | input reset : UInt<1> - | output io : {a : UInt<8>, sel : UInt<1>} - | - | io is invalid - | reg table : {a : UInt<8>}[2], clock - | reg otherTable : {a : UInt<8>}[2], clock - | otherTable[table[UInt<1>("h01")].a].a <= UInt<1>("h00")""".stripMargin - //TODO(azidar): I realize this is brittle, but unfortunately there - // isn't a better way to test this pass - val check = Seq( - """wire _table_1 : { a : UInt<8>}""", - """_table_1.a is invalid""", - """when UInt<1>("h1") :""", - """_table_1.a <= table[1].a""", - """wire _otherTable_table_1_a_a : UInt<8>""", - """when eq(UInt<1>("h0"), _table_1.a) :""", - """otherTable[0].a <= _otherTable_table_1_a_a""", - """when eq(UInt<1>("h1"), _table_1.a) :""", - """otherTable[1].a <= _otherTable_table_1_a_a""", - """_otherTable_table_1_a_a <= UInt<1>("h0")""" - ) - executeTest(input, check, passes) + | module AssignViaDeref : + | input clock : Clock + | input reset : UInt<1> + | output io : {a : UInt<8>, sel : UInt<1>} + | + | io is invalid + | reg table : {a : UInt<8>}[2], clock + | reg otherTable : {a : UInt<8>}[2], clock + | otherTable[table[UInt<1>("h01")].a].a <= UInt<1>("h00")""".stripMargin + //TODO(azidar): I realize this is brittle, but unfortunately there + // isn't a better way to test this pass + val check = Seq( + """wire _table_1 : { a : UInt<8>}""", + """_table_1.a is invalid""", + """when UInt<1>("h1") :""", + """_table_1.a <= table[1].a""", + """wire _otherTable_table_1_a_a : UInt<8>""", + """when eq(UInt<1>("h0"), _table_1.a) :""", + """otherTable[0].a <= _otherTable_table_1_a_a""", + """when eq(UInt<1>("h1"), _table_1.a) :""", + """otherTable[1].a <= _otherTable_table_1_a_a""", + """_otherTable_table_1_a_a <= UInt<1>("h0")""" + ) + executeTest(input, check, passes) } "Oversized bit select" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -260,13 +227,7 @@ class UnitTests extends FirrtlFlatSpec { } "Oversized head select" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -279,14 +240,8 @@ class UnitTests extends FirrtlFlatSpec { } "zero head select" should "return an empty module" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths, - new DeadCodeElimination) + val passes = + Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths, new DeadCodeElimination) val input = """circuit Unit : | module Unit : @@ -299,13 +254,7 @@ class UnitTests extends FirrtlFlatSpec { } "Oversized tail select" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Unit : | module Unit : @@ -318,14 +267,8 @@ class UnitTests extends FirrtlFlatSpec { } "max tail select" should "return an empty module" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - CheckWidths, - new DeadCodeElimination) + val passes = + Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, CheckWidths, new DeadCodeElimination) val input = """circuit Unit : | module Unit : @@ -338,11 +281,7 @@ class UnitTests extends FirrtlFlatSpec { } "Partial connecting incompatable types" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - CheckTypes) + val passes = Seq(ToWorkingIR, ResolveKinds, InferTypes, CheckTypes) val input = """circuit Unit : | module Unit : @@ -419,13 +358,12 @@ class UnitTests extends FirrtlFlatSpec { """assign negSInt = -5'shd;""" ) val out = compileToVerilog(input) - val lines = out.split("\n") map normalized - expected foreach { e => + val lines = out.split("\n").map(normalized) + expected.foreach { e => lines should contain(e) } } - "Out of bound accesses" should "be invalid" in { val passes = Seq( ToWorkingIR, @@ -460,8 +398,9 @@ class UnitTests extends FirrtlFlatSpec { val index = WRef("index", ut2, PortKind, SourceFlow) val out = WRef("out", ut16, PortKind, SinkFlow) - def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, ut1) - def array(v: Int): Expression = WSubIndex(WRef("array", VectorType(ut16, 3), WireKind, SourceFlow), v, ut16, SourceFlow) + def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, ut1) + def array(v: Int): Expression = + WSubIndex(WRef("array", VectorType(ut16, 3), WireKind, SourceFlow), v, ut16, SourceFlow) result should containTree { case DefWire(_, "_array_index", `ut16`) => true } result should containTree { case IsInvalid(_, `fgen`) => true } @@ -490,6 +429,6 @@ class UnitTests extends FirrtlFlatSpec { | out <= shl(in, 4) |""".stripMargin val res = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm)) - res should containLine ("assign out = {in, 4'h0};") + res should containLine("assign out = {in, 4'h0};") } } diff --git a/src/test/scala/firrtlTests/UtilsSpec.scala b/src/test/scala/firrtlTests/UtilsSpec.scala index 8ea69460..483b7dbd 100644 --- a/src/test/scala/firrtlTests/UtilsSpec.scala +++ b/src/test/scala/firrtlTests/UtilsSpec.scala @@ -8,28 +8,31 @@ import org.scalatest.flatspec.AnyFlatSpec class UtilsSpec extends AnyFlatSpec { - behavior of "Utils.expandPrefix" + behavior.of("Utils.expandPrefix") val expandPrefixTests = List( ("return a name without prefixes", "_", "foo", Set("foo")), ("expand a name ending with prefixes", "_", "foo__", Set("foo__")), ("expand a name with on prefix", "_", "foo_bar", Set("foo_bar", "foo_")), - ("expand a name with complex prefixes", "_", - "foo__$ba9_9X__$$$$$_", Set("foo__$ba9_9X__$$$$$_", "foo__$ba9_9X__", "foo__$ba9_", "foo__")), + ( + "expand a name with complex prefixes", + "_", + "foo__$ba9_9X__$$$$$_", + Set("foo__$ba9_9X__$$$$$_", "foo__$ba9_9X__", "foo__$ba9_", "foo__") + ), ("expand a name starting with a delimiter", "_", "__foo_bar", Set("__", "__foo_", "__foo_bar")), ("expand a name with a $ delimiter", "$", "foo$bar$$$baz", Set("foo$", "foo$bar$$$", "foo$bar$$$baz")), ("expand a name with a multi-character delimiter", "FOO", "fooFOOFOOFOObar", Set("fooFOOFOOFOO", "fooFOOFOOFOObar")) ) for ((description, delimiter, in, out) <- expandPrefixTests) { - it should description in { Utils.expandPrefixes(in, delimiter).toSet should be (out)} + it should description in { Utils.expandPrefixes(in, delimiter).toSet should be(out) } } "expandRef" should "return intermediate expressions" in { val bTpe = VectorType(Utils.BoolType, 2) val topTpe = BundleType(Seq(Field("a", Default, Utils.BoolType), Field("b", Default, bTpe))) val wr = WRef("out", topTpe, PortKind, SourceFlow) - val expected = Seq( wr, @@ -39,6 +42,6 @@ class UtilsSpec extends AnyFlatSpec { WSubIndex(WSubField(wr, "b", bTpe, SourceFlow), 1, Utils.BoolType, SourceFlow) ) - (Utils.expandRef(wr)) should be (expected) + (Utils.expandRef(wr)) should be(expected) } } diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index 21d7075e..9840229e 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -31,7 +31,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = ^a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Andr" should "emit correctly" in { @@ -49,7 +49,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = &a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Orr" should "emit correctly" in { @@ -67,7 +67,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = |a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Not" should "emit correctly" in { @@ -85,7 +85,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { |); | assign b = ~a; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "inline Bits" should "emit correctly" in { @@ -179,7 +179,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { | assign t = a[2:1]; | assign u = a[3]; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "Rem" should "emit correctly" in { @@ -199,7 +199,7 @@ class DoPrimVerilog extends FirrtlFlatSpec { | wire [7:0] _GEN_0 = in % 8'h1; | assign out = _GEN_0[0]; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "nested cats" should "emit correctly" in { @@ -225,12 +225,12 @@ class DoPrimVerilog extends FirrtlFlatSpec { | wire [5:0] _GEN_1 = {in3,in2,in1}; | assign out = {in4,_GEN_1}; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm), Seq(new CombineCats())) - val lines = finalState.getEmittedCircuit.value split "\n" map normalized + val lines = finalState.getEmittedCircuit.value.split("\n").map(normalized) for (e <- check) { - lines should contain (e) + lines should contain(e) } } } @@ -240,9 +240,9 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) private def compileBody(body: String): CircuitState = { val str = """ - |circuit Test : - | module Test : - |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") compile(str) } @@ -273,7 +273,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { compiler.compile(CircuitState(parse(input), ChirrtlForm), writer) val lines = writer.toString.split("\n") for (c <- check) { - lines should contain (c) + lines should contain(c) } } "The Verilog Emitter" should "support Modules with no ports" in { @@ -302,7 +302,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |); | assign out = in; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } "The Verilog Emitter" should "support pads with width <= the width of the argument" in { @@ -325,7 +325,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val emitter = new VerilogEmitter val result = emitter.execute(state) - result should containLine ("assign out = in;") + result should containLine("assign out = in;") } } @@ -368,22 +368,22 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val moduleMap = state.circuit.modules.map(m => m.name -> m).toMap - val module = state.circuit.modules.filter(module => module.name == "Test").collectFirst { case m: firrtl.ir.Module => m }.get + val module = + state.circuit.modules.filter(module => module.name == "Test").collectFirst { case m: firrtl.ir.Module => m }.get val renderer = emitter.getRenderer(module, moduleMap)(writer) - renderer.emitVerilogBind("BindsToTest", - """ - |$readmemh("file", memory); - | - |""".stripMargin) + renderer.emitVerilogBind("BindsToTest", """ + |$readmemh("file", memory); + | + |""".stripMargin) val lines = writer.toString.split("\n") val outString = writer.toString // This confirms that the module io's were emitted for (c <- check) { - lines should contain (c) + lines should contain(c) } } @@ -401,16 +401,20 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { """.stripMargin val state = CircuitState(parse(input), ChirrtlForm) val result = (new VerilogCompiler).compileAndEmit(state, List()) - result should containLines ("`ifndef SYNTHESIS", - "`ifdef FIRRTL_BEFORE_INITIAL", - "`FIRRTL_BEFORE_INITIAL", - "`endif", - "initial begin") - result should containLines ("end // initial", - "`ifdef FIRRTL_AFTER_INITIAL", - "`FIRRTL_AFTER_INITIAL", - "`endif", - "`endif // SYNTHESIS") + result should containLines( + "`ifndef SYNTHESIS", + "`ifdef FIRRTL_BEFORE_INITIAL", + "`FIRRTL_BEFORE_INITIAL", + "`endif", + "initial begin" + ) + result should containLines( + "end // initial", + "`ifdef FIRRTL_AFTER_INITIAL", + "`FIRRTL_AFTER_INITIAL", + "`endif", + "`endif // SYNTHESIS" + ) } "Verilog name conflicts" should "be resolved" in { @@ -455,14 +459,14 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { | fork_ <= const_ |""".stripMargin val state = CircuitState(parse(input), UnknownForm, Seq.empty, None) - val output = Seq( ToWorkingIR, ResolveKinds, InferTypes, new VerilogRename ) - .foldLeft(state){ case (c, tx) => tx.runTransform(c) } - Seq( CheckHighForm ) - .foldLeft(output.circuit){ case (c, tx) => tx.run(c) } - output.circuit.serialize should be (parse(check_firrtl).serialize) + val output = Seq(ToWorkingIR, ResolveKinds, InferTypes, new VerilogRename) + .foldLeft(state) { case (c, tx) => tx.runTransform(c) } + Seq(CheckHighForm) + .foldLeft(output.circuit) { case (c, tx) => tx.run(c) } + output.circuit.serialize should be(parse(check_firrtl).serialize) } - behavior of "Register Updates" + behavior.of("Register Updates") they should "emit using 'else if' constructs" in { val input = @@ -484,10 +488,10 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should containLine ("if (sel == 2'h0) begin") - result should containLine ("end else if (sel == 2'h1) begin" ) - result should containLine ("end else if (sel == 2'h2) begin") - result should containLine ("end else begin") + result should containLine("if (sel == 2'h0) begin") + result should containLine("end else if (sel == 2'h1) begin") + result should containLine("end else if (sel == 2'h2) begin") + result should containLine("end else begin") } they should "ignore self assignments in false conditions" in { @@ -505,7 +509,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should not (containLine ("tmp <= tmp")) + result should not(containLine("tmp <= tmp")) } they should "ignore self assignments in true conditions and invert condition" in { @@ -523,8 +527,8 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should containLine ("if (!(sel == 1'h0)) begin") - result should not (containLine ("tmp <= tmp")) + result should containLine("if (!(sel == 1'h0)) begin") + result should not(containLine("tmp <= tmp")) } they should "ignore self assignments in both true and false conditions" in { @@ -542,8 +546,8 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val circuit = Seq(ToWorkingIR, ResolveKinds, InferTypes).foldLeft(parse(input)) { case (c, p) => p.run(c) } val state = CircuitState(circuit, LowForm, Seq(EmitCircuitAnnotation(classOf[VerilogEmitter]))) val result = (new VerilogEmitter).execute(state) - result should not (containLine ("tmp <= tmp")) - result should not (containLine ("always @(posedge clock) begin")) + result should not(containLine("tmp <= tmp")) + result should not(containLine("always @(posedge clock) begin")) } they should "properly indent muxes in either the true or false condition" in { @@ -583,24 +587,24 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { val result = (new VerilogEmitter).execute(state) /* The Verilog string is used to check for no whitespace between "else" and "if". */ val verilogString = result.getEmittedCircuit.value - result should containLine ("if (sel == 3'h0) begin") - verilogString should include ("end else if (sel == 3'h1) begin") - result should containLine ("if (sel == 3'h2) begin") - verilogString should include ("end else if (sel == 3'h3) begin") - result should containLine ("if (sel == 3'h4) begin") - verilogString should include ("end else if (sel == 3'h5) begin") - result should containLine ("if (sel == 3'h6) begin") - verilogString should include ("end else if (sel == 3'h7) begin") - result should containLine ("tmp <= in_0;") - result should containLine ("tmp <= in_1;") - result should containLine ("tmp <= in_2;") - result should containLine ("tmp <= in_3;") - result should containLine ("tmp <= in_4;") - result should containLine ("tmp <= in_5;") - result should containLine ("tmp <= in_6;") - result should containLine ("tmp <= in_7;") - result should containLine ("tmp <= in_8;") - result should containLine ("tmp <= in_9;") + result should containLine("if (sel == 3'h0) begin") + verilogString should include("end else if (sel == 3'h1) begin") + result should containLine("if (sel == 3'h2) begin") + verilogString should include("end else if (sel == 3'h3) begin") + result should containLine("if (sel == 3'h4) begin") + verilogString should include("end else if (sel == 3'h5) begin") + result should containLine("if (sel == 3'h6) begin") + verilogString should include("end else if (sel == 3'h7) begin") + result should containLine("tmp <= in_0;") + result should containLine("tmp <= in_1;") + result should containLine("tmp <= in_2;") + result should containLine("tmp <= in_3;") + result should containLine("tmp <= in_4;") + result should containLine("tmp <= in_5;") + result should containLine("tmp <= in_6;") + result should containLine("tmp <= in_7;") + result should containLine("tmp <= in_8;") + result should containLine("tmp <= in_9;") } "SInt addition" should "have casts" in { @@ -700,7 +704,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |""".stripMargin ) result shouldNot containLine("assign z = $signed(x) + -8'sh2;") - result should containLine("assign z = $signed(x) - 8'sh2;") + result should containLine("assign z = $signed(x) - 8'sh2;") } it should "subtract positive literals even with max negative literal" in { @@ -712,7 +716,7 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |""".stripMargin ) result shouldNot containLine("assign z = $signed(x) + -2'sh2;") - result should containLine("assign z = $signed(x) - 3'sh2;") + result should containLine("assign z = $signed(x) - 3'sh2;") } it should "subtract positive literals even with max negative literal with no carryout" in { @@ -724,16 +728,16 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { |""".stripMargin ) result shouldNot containLine("assign z = $signed(x) + -2'sh2;") - result should containLine("wire [2:0] _GEN_0 = $signed(x) - 3'sh2;") - result should containLine("assign z = _GEN_0[1:0];") + result should containLine("wire [2:0] _GEN_0 = $signed(x) - 3'sh2;") + result should containLine("assign z = _GEN_0[1:0];") } it should "emit FileInfo as Verilog comment" in { def result(info: String): CircuitState = compileBody( s"""input x : UInt<2> - |output z : UInt<2> - |z <= x @[$info] - |""".stripMargin + |output z : UInt<2> + |z <= x @[$info] + |""".stripMargin ) result("test") should containLine(" assign z = x; // @[test]") // newlines currently are supposed to be escaped for both firrtl and Verilog @@ -772,7 +776,8 @@ class VerilogDescriptionEmitterSpec extends FirrtlFlatSpec { val modName = ModuleName("Test", CircuitName("Test")) val annos = Seq( DocStringAnnotation(ComponentName("a", modName), "multi\nline"), - DocStringAnnotation(ComponentName("b", modName), "single line")) + DocStringAnnotation(ComponentName("b", modName), "single line") + ) val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos), Seq.empty) val output = finalState.getEmittedCircuit.value for (c <- check) { @@ -816,7 +821,8 @@ class VerilogDescriptionEmitterSpec extends FirrtlFlatSpec { val annos = Seq( DocStringAnnotation(ComponentName("d", modName), "multi\nline"), DocStringAnnotation(ComponentName("e", modName), "multi\nline"), - DocStringAnnotation(ComponentName("f", modName), "single line")) + DocStringAnnotation(ComponentName("f", modName), "single line") + ) val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos), Seq.empty) val output = finalState.getEmittedCircuit.value for (c <- check) { @@ -940,8 +946,8 @@ class EmittedMacroSpec extends FirrtlPropSpec { ProcessLogger(line => { line match { case "printing from FIRRTL_BEFORE_INITIAL macro" => saw_before = true - case "printing from FIRRTL_AFTER_INITIAL macro" => saw_after = true - case _ => // Do Nothing + case "printing from FIRRTL_AFTER_INITIAL macro" => saw_after = true + case _ => // Do Nothing } }) diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala index 4b0bc5e5..b8fb3955 100644 --- a/src/test/scala/firrtlTests/WidthSpec.scala +++ b/src/test/scala/firrtlTests/WidthSpec.scala @@ -8,24 +8,20 @@ import firrtl.testutils._ class WidthSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - val lines = c.serialize.split("\n") map normalized + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } - private val inferPasses = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths) + private val inferPasses = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths) private val inferAndCheckPasses = inferPasses :+ CheckWidths @@ -42,13 +38,13 @@ class WidthSpec extends FirrtlFlatSpec { LiteralWidthCheck(4, Some(3), 4) ) for (LiteralWidthCheck(lit, uwo, sw) <- litChecks) { - import firrtl.ir.{UIntLiteral, SIntLiteral, IntWidth} + import firrtl.ir.{IntWidth, SIntLiteral, UIntLiteral} s"$lit" should s"have signed width $sw" in { - SIntLiteral(lit).width should equal (IntWidth(sw)) + SIntLiteral(lit).width should equal(IntWidth(sw)) } uwo.foreach { uw => it should s"have unsigned width $uw" in { - UIntLiteral(lit).width should equal (IntWidth(uw)) + UIntLiteral(lit).width should equal(IntWidth(uw)) } } } @@ -75,7 +71,7 @@ class WidthSpec extends FirrtlFlatSpec { | input i: UInt<2> | node x = asClock(i)""".stripMargin intercept[CheckWidths.MultiBitAsClock] { - executeTest(input, Nil, inferAndCheckPasses) + executeTest(input, Nil, inferAndCheckPasses) } } @@ -86,15 +82,15 @@ class WidthSpec extends FirrtlFlatSpec { | input i: UInt<2> | node x = asAsyncReset(i)""".stripMargin intercept[CheckWidths.MultiBitAsAsyncReset] { - executeTest(input, Nil, inferAndCheckPasses) + executeTest(input, Nil, inferAndCheckPasses) } } "Width >= MaxWidth" should "result in an error" in { val input = - s"""circuit Unit : - | module Unit : - | input x: UInt<${CheckWidths.MaxWidth}> + s"""circuit Unit : + | module Unit : + | input x: UInt<${CheckWidths.MaxWidth}> """.stripMargin intercept[CheckWidths.WidthTooBig] { executeTest(input, Nil, inferAndCheckPasses) @@ -124,7 +120,7 @@ class WidthSpec extends FirrtlFlatSpec { | input y: SInt<2> | output z: SInt | z <= add(x, y)""".stripMargin - val check = Seq( "output z : SInt<4>") + val check = Seq("output z : SInt<4>") intercept[PassExceptions] { executeTest(input, check, inferPasses) } @@ -138,13 +134,13 @@ class WidthSpec extends FirrtlFlatSpec { | input y: SInt<2> | output z: SInt | z <= sub(y, x)""".stripMargin - val check = Seq( "output z : SInt<5>") + val check = Seq("output z : SInt<5>") intercept[PassExceptions] { executeTest(input, check, inferPasses) } } - behavior of "CheckWidths.UniferredWidth" + behavior.of("CheckWidths.UniferredWidth") it should "provide a good error message with a full target if a user forgets an assign" in { val input = @@ -155,9 +151,10 @@ class WidthSpec extends FirrtlFlatSpec { | module Bar : | wire a: { b : UInt<1>, c : { d : UInt<1>, e : UInt } } |""".stripMargin - val msg = intercept[CheckWidths.UninferredWidth] { executeTest(input, Nil, inferAndCheckPasses) } - .getMessage should include ("""| circuit Foo: - | └── module Bar: - | └── a.c.e""".stripMargin) + val msg = intercept[CheckWidths.UninferredWidth] { + executeTest(input, Nil, inferAndCheckPasses) + }.getMessage should include("""| circuit Foo: + | └── module Bar: + | └── a.c.e""".stripMargin) } } diff --git a/src/test/scala/firrtlTests/WiringTests.scala b/src/test/scala/firrtlTests/WiringTests.scala index 8ec6d5ce..0c5be2e0 100644 --- a/src/test/scala/firrtlTests/WiringTests.scala +++ b/src/test/scala/firrtlTests/WiringTests.scala @@ -9,15 +9,14 @@ import annotations._ import wiring._ class WiringTests extends FirrtlFlatSpec { - private def executeTest(input: String, - expected: String, - passes: Seq[Transform], - annos: Seq[Annotation]): Unit = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm, annos)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - - (parse(c.serialize).serialize) should be (parse(expected).serialize) + private def executeTest(input: String, expected: String, passes: Seq[Transform], annos: Seq[Annotation]): Unit = { + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm, annos)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + + (parse(c.serialize).serialize) should be(parse(expected).serialize) } private def executeTest(input: String, expected: String, passes: Seq[Transform]): Unit = { @@ -405,8 +404,10 @@ class WiringTests extends FirrtlFlatSpec { } it should "wire multiple sinks in the same module" in { - val sinks = Seq(ComponentName("s", ModuleName("A", CircuitName("Top"))), - ComponentName("t", ModuleName("A", CircuitName("Top")))) + val sinks = Seq( + ComponentName("s", ModuleName("A", CircuitName("Top"))), + ComponentName("t", ModuleName("A", CircuitName("Top"))) + ) val source = ComponentName("r", ModuleName("A", CircuitName("Top"))) val sas = WiringInfo(source, sinks, "pin") val input = @@ -741,8 +742,7 @@ class WiringTests extends FirrtlFlatSpec { | bundle_0 <= bundle | module B : | input clock : Clock - | input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }""" - .stripMargin + | input pin : {x : UInt<1>, y: UInt<1>, z: {zz : UInt<1>} }""".stripMargin val wiringXForm = new WiringTransform() executeTest(input, check, passes :+ wiringXForm, Seq(source, sink)) @@ -753,9 +753,7 @@ class WiringTests extends FirrtlFlatSpec { val sourceX = ComponentName("r.x", ModuleName("A", CircuitName("Top"))) val sinkY = Seq(ModuleName("Y", CircuitName("Top"))) val sourceY = ComponentName("r.x", ModuleName("A", CircuitName("Top"))) - val wiSeq = Seq( - WiringInfo(sourceX, sinkX, "pin"), - WiringInfo(sourceY, sinkY, "pin")) + val wiSeq = Seq(WiringInfo(sourceX, sinkX, "pin"), WiringInfo(sourceY, sinkY, "pin")) val input = """|circuit Top : | module Top : @@ -809,9 +807,7 @@ class WiringTests extends FirrtlFlatSpec { val sink = ComponentName("s", ModuleName("Top", CircuitName("Top"))) val source1 = ComponentName("r", ModuleName("Top", CircuitName("Top"))) val source2 = ComponentName("r2", ModuleName("Top", CircuitName("Top"))) - val annos = Seq(SourceAnnotation(source1, "pin"), - SourceAnnotation(source2, "pin"), - SinkAnnotation(sink, "pin")) + val annos = Seq(SourceAnnotation(source1, "pin"), SourceAnnotation(source2, "pin"), SinkAnnotation(sink, "pin")) val input = """|circuit Top : | module Top : @@ -820,7 +816,7 @@ class WiringTests extends FirrtlFlatSpec { | reg r: UInt<5>, clock | reg r2: UInt<5>, clock |""".stripMargin - a [WiringException] shouldBe thrownBy { + a[WiringException] shouldBe thrownBy { executeTest(input, "", passes :+ new WiringTransform, annos) } } diff --git a/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala index 715714dd..48eb24c1 100644 --- a/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala +++ b/src/test/scala/firrtlTests/ZeroLengthVecsSpec.scala @@ -7,18 +7,14 @@ import firrtl.passes._ import firrtl.testutils.FirrtlFlatSpec class ZeroLengthVecsSpec extends FirrtlFlatSpec { - val transforms = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - ZeroLengthVecs, - CheckTypes) + val transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ZeroLengthVecs, CheckTypes) protected def exec(input: String) = { - transforms.foldLeft(CircuitState(parse(input), UnknownForm)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit.serialize + transforms + .foldLeft(CircuitState(parse(input), UnknownForm)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit + .serialize } "ZeroLengthVecs" should "drop subaccesses to zero-length vectors" in { @@ -42,7 +38,7 @@ class ZeroLengthVecsSpec extends FirrtlFlatSpec { | skip | o <= validif(UInt<1>(0), UInt<8>(0)) |""".stripMargin - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } "ZeroLengthVecs" should "handle intervals correctly" in { @@ -62,7 +58,7 @@ class ZeroLengthVecsSpec extends FirrtlFlatSpec { | output o : Interval[3,4].0 | o <= validif(UInt<1>(0), clip(asInterval(SInt<1>(0), 0, 0, 0), i[sel])) |""".stripMargin - (parse(exec(input))) should be (parse(check)) + (parse(exec(input))) should be(parse(check)) } } diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index b53f55ea..3c3df5ca 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -7,20 +7,17 @@ import firrtl.passes._ import firrtl.testutils._ class ZeroWidthTests extends FirrtlFlatSpec { - def transforms = Seq( - ToWorkingIR, - ResolveKinds, - InferTypes, - ResolveFlows, - new InferWidths, - ZeroWidth) - private def exec (input: String) = { + def transforms = Seq(ToWorkingIR, ResolveKinds, InferTypes, ResolveFlows, new InferWidths, ZeroWidth) + private def exec(input: String) = { val circuit = parse(input) - transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit.serialize - } - // ============================= + transforms + .foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) + } + .circuit + .serialize + } + // ============================= "Zero width port" should " be deleted" in { val input = """circuit Top : @@ -30,10 +27,10 @@ class ZeroWidthTests extends FirrtlFlatSpec { | x <= y""".stripMargin val check = """circuit Top : - | module Top : - | output x : UInt<1> - | x <= UInt<1>(0)""".stripMargin - (parse(exec(input))) should be (parse(check)) + | module Top : + | output x : UInt<1> + | x <= UInt<1>(0)""".stripMargin + (parse(exec(input))) should be(parse(check)) } "Add of <0> and <2> " should " put in zero" in { val input = @@ -47,7 +44,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : UInt<3> | x <= add(UInt<1>(0), UInt<2>(2))""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Mux on <0>" should "put in zero" in { val input = @@ -61,7 +58,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : UInt<2> | x <= mux(UInt<1>(0), UInt<2>(2), UInt<2>(1))""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Bundle with field of <0>" should "get deleted" in { val input = @@ -75,7 +72,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | output x : { b: UInt<1> } | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Vector with type of <0>" should "get deleted" in { val input = @@ -88,7 +85,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Node with <0>" should "be removed" in { val input = @@ -100,7 +97,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "IsInvalid on <0>" should "be deleted" in { val input = @@ -112,7 +109,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Expression in node with type <0>" should "be replaced by UInt<1>(0)" in { val input = @@ -126,7 +123,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | input x: UInt<1> | node z = add(x, UInt<1>(0))""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Expression in cat with type <0>" should "be removed" in { val input = @@ -140,7 +137,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | module Top : | input x: UInt<1> | node z = x""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Nested cats with type <0>" should "be removed" in { val input = @@ -154,7 +151,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { """circuit Top : | module Top : | skip""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Nested cats where one has type <0>" should "be unaffected" in { val input = @@ -170,7 +167,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | node a = cat(x, z)""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Stop with type <0>" should "be replaced with UInt(0)" in { val input = @@ -188,7 +185,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | stop(clk, UInt(0), 1)""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Print with type <0>" should "be replaced with UInt(0)" in { val input = @@ -206,7 +203,7 @@ class ZeroWidthTests extends FirrtlFlatSpec { | input x: UInt<1> | input z: UInt<1> | printf(clk, UInt(1), "%d %d %d\n", x, UInt(0), z)""".stripMargin - (parse(exec(input)).serialize) should be (parse(check).serialize) + (parse(exec(input)).serialize) should be(parse(check).serialize) } "Andr of zero-width expression" should "return true" in { @@ -218,10 +215,10 @@ class ZeroWidthTests extends FirrtlFlatSpec { | x <= andr(y)""".stripMargin val check = """circuit Top : - | module Top : - | output x : UInt<1> - | x <= UInt<1>(1)""".stripMargin - (parse(exec(input))) should be (parse(check)) + | module Top : + | output x : UInt<1> + | x <= UInt<1>(1)""".stripMargin + (parse(exec(input))) should be(parse(check)) } } @@ -230,17 +227,17 @@ class ZeroWidthVerilog extends FirrtlFlatSpec { val compiler = new VerilogCompiler val input = """circuit Top : - | module Top : - | input y: UInt<0> - | output x: UInt<3> - | x <= y""".stripMargin + | module Top : + | input y: UInt<0> + | output x: UInt<3> + | x <= y""".stripMargin val check = """module Top( | output [2:0] x |); | assign x = 3'h0; |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, check, compiler) } } diff --git a/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala index 79922fa9..0f0d5d47 100644 --- a/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala +++ b/src/test/scala/firrtlTests/analyses/CircuitGraphSpec.scala @@ -11,40 +11,42 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.{ChirrtlForm, CircuitState, FileUtils, UnknownForm} class CircuitGraphSpec extends FirrtlFlatSpec { - "CircuitGraph" should "find paths with deep hierarchy quickly" in { - def mkChild(n: Int): String = - s""" module Child${n} : - | input in: UInt<8> - | output out: UInt<8> - | inst c1 of Child${n+1} - | inst c2 of Child${n+1} - | c1.in <= in - | c2.in <= c1.out - | out <= c2.out + "CircuitGraph" should "find paths with deep hierarchy quickly" in { + def mkChild(n: Int): String = + s""" module Child${n} : + | input in: UInt<8> + | output out: UInt<8> + | inst c1 of Child${n + 1} + | inst c2 of Child${n + 1} + | c1.in <= in + | c2.in <= c1.out + | out <= c2.out """.stripMargin - def mkLeaf(n: Int): String = - s""" module Child${n} : - | input in: UInt<8> - | output out: UInt<8> - | wire middle: UInt<8> - | middle <= in - | out <= middle + def mkLeaf(n: Int): String = + s""" module Child${n} : + | input in: UInt<8> + | output out: UInt<8> + | wire middle: UInt<8> + | middle <= in + | out <= middle """.stripMargin - (2 until 23 by 2).foreach { n => - val input = new StringBuilder() - input ++= - """circuit Child0: - |""".stripMargin - (0 until n).foreach { i => input ++= mkChild(i); input ++= "\n" } - input ++= mkLeaf(n) - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( + (2 until 23 by 2).foreach { n => + val input = new StringBuilder() + input ++= + """circuit Child0: + |""".stripMargin + (0 until n).foreach { i => input ++= mkChild(i); input ++= "\n" } + input ++= mkLeaf(n) + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( CircuitState(parse(input.toString()), UnknownForm) - ).circuit - val circuitGraph = CircuitGraph(circuit) - val C = CircuitTarget("Child0") - val Child0 = C.module("Child0") - circuitGraph.connectionPath(Child0.ref("in"), Child0.ref("out")) - } + ) + .circuit + val circuitGraph = CircuitGraph(circuit) + val C = CircuitTarget("Child0") + val Child0 = C.module("Child0") + circuitGraph.connectionPath(Child0.ref("in"), Child0.ref("out")) } + } } diff --git a/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala index 06f59a3c..e08b7efc 100644 --- a/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala +++ b/src/test/scala/firrtlTests/analyses/ConnectionGraphSpec.scala @@ -14,9 +14,11 @@ class ConnectionGraphSpec extends FirrtlFlatSpec { "ConnectionGraph" should "build connection graph for rocket-chip" in { ConnectionGraph( - new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(FileUtils.getTextResource("/regress/RocketCore.fir")), UnknownForm) - ).circuit + new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(FileUtils.getTextResource("/regress/RocketCore.fir")), UnknownForm) + ) + .circuit ) } @@ -44,9 +46,11 @@ class ConnectionGraphSpec extends FirrtlFlatSpec { | out <= in |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit "ConnectionGraph" should "work with pathsInDAG" in { val Test = ModuleTarget("Test", "Test") diff --git a/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala index 50ee75ac..b1e9fd73 100644 --- a/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala +++ b/src/test/scala/firrtlTests/analyses/IRLookupSpec.scala @@ -10,7 +10,6 @@ import firrtl.passes.ExpandWhensAndCheck import firrtl.stage.{Forms, TransformManager} import firrtl.testutils.FirrtlFlatSpec - class IRLookupSpec extends FirrtlFlatSpec { "IRLookup" should "return declarations" in { @@ -38,9 +37,11 @@ class IRLookupSpec extends FirrtlFlatSpec { | out <= UInt(1) |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val irLookup = IRLookup(circuit) val Test = ModuleTarget("Test", "Test") val uint8 = UIntType(IntWidth(8)) @@ -49,7 +50,10 @@ class IRLookupSpec extends FirrtlFlatSpec { irLookup.declaration(Test.ref("clk")) shouldBe Port(NoInfo, "clk", Input, ClockType) irLookup.declaration(Test.ref("reset")) shouldBe Port(NoInfo, "reset", Input, UIntType(IntWidth(1))) - val out = Port(NoInfo, "out", Output, + val out = Port( + NoInfo, + "out", + Output, BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2)))) ) irLookup.declaration(Test.ref("out")) shouldBe out @@ -73,7 +77,8 @@ class IRLookupSpec extends FirrtlFlatSpec { irLookup.declaration(Test.ref("y")) shouldBe DefWire(NoInfo, "y", uint8) irLookup.declaration(Test.ref("@and#0")) shouldBe - DoPrim(PrimOps.And, + DoPrim( + PrimOps.And, Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))), Nil, uint8 @@ -84,12 +89,14 @@ class IRLookupSpec extends FirrtlFlatSpec { irLookup.declaration(Test.ref("child").field("out")) shouldBe inst irLookup.declaration(Test.instOf("child", "Child").ref("out")) shouldBe Port(NoInfo, "out", Output, uint8) - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("missing")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Missing").ref("out")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("missing", "Child").ref("out")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("missing")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.ref("out").field("c")) } - intercept[IllegalArgumentException]{ irLookup.declaration(Test.instOf("child", "Child").ref("out").field("missing")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.instOf("child", "Child").ref("missing")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.instOf("child", "Missing").ref("out")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.instOf("missing", "Child").ref("out")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.ref("missing")) } + intercept[IllegalArgumentException] { irLookup.declaration(Test.ref("out").field("c")) } + intercept[IllegalArgumentException] { + irLookup.declaration(Test.instOf("child", "Child").ref("out").field("missing")) + } } "IRLookup" should "return mem declarations" in { @@ -152,9 +159,11 @@ class IRLookupSpec extends FirrtlFlatSpec { val Readwriter = Mem.field("rw") val allSignals = readerTargets(Reader) ++ writerTargets(Writer) ++ readwriterTargets(Readwriter) - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val irLookup = IRLookup(circuit) val uint8 = UIntType(IntWidth(8)) val mem = DefMemory(NoInfo, "m", uint8, 2, 1, 0, Seq("r"), Seq("w"), Seq("rw")) @@ -188,9 +197,11 @@ class IRLookupSpec extends FirrtlFlatSpec { | out <= UInt(1) |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val irLookup = IRLookup(circuit) val Test = ModuleTarget("Test", "Test") val uint8 = UIntType(IntWidth(8)) @@ -209,7 +220,8 @@ class IRLookupSpec extends FirrtlFlatSpec { val out = Test.ref("out") val outExpr = - WRef("out", + WRef( + "out", BundleType(Seq(Field("a", Default, uint8), Field("b", Default, VectorType(uint8, 2)))), PortKind, SinkFlow @@ -237,8 +249,10 @@ class IRLookupSpec extends FirrtlFlatSpec { check(Test.ref("y"), WRef("y", uint8, WireKind, DuplexFlow)) - check(Test.ref("@and#0"), - DoPrim(PrimOps.And, + check( + Test.ref("@and#0"), + DoPrim( + PrimOps.And, Seq(WRef("y", uint8, WireKind, SourceFlow), DoPrim(AsUInt, Seq(SIntLiteral(-1)), Nil, UIntType(IntWidth(1)))), Nil, uint8 @@ -247,33 +261,34 @@ class IRLookupSpec extends FirrtlFlatSpec { val child = WRef("child", BundleType(Seq(Field("out", Default, uint8))), InstanceKind, SourceFlow) check(Test.ref("child"), child) - check(Test.ref("child").field("out"), - WSubField(child, "out", uint8, SourceFlow) - ) + check(Test.ref("child").field("out"), WSubField(child, "out", uint8, SourceFlow)) } "IRLookup" should "cache expressions" in { def mkType(i: Int): String = { - if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" + if (i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}" } val depth = 500 val input = s"""circuit Test: - | module Test : - | input in: ${mkType(depth)} - | output out: ${mkType(depth)} - | out <= in - |""".stripMargin + | module Test : + | input in: ${mkType(depth)} + | output out: ${mkType(depth)} + | out <= in + |""".stripMargin - val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])).runTransform( - CircuitState(parse(input), UnknownForm) - ).circuit + val circuit = new firrtl.stage.transforms.Compiler(Seq(Dependency[ExpandWhensAndCheck])) + .runTransform( + CircuitState(parse(input), UnknownForm) + ) + .circuit val Test = ModuleTarget("Test", "Test") val irLookup = IRLookup(circuit) def mkReferences(parent: ReferenceTarget, i: Int): Seq[ReferenceTarget] = { - if(i == 0) Seq(parent) else { + if (i == 0) Seq(parent) + else { val newParent = parent.field("x") newParent +: mkReferences(newParent, i - 1) } diff --git a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala index a0d444b3..e134f6e5 100644 --- a/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala +++ b/src/test/scala/firrtlTests/analyses/InstanceGraphTests.scala @@ -9,10 +9,10 @@ import firrtl.testutils._ class InstanceGraphTests extends FirrtlFlatSpec { private def getEdgeSet(graph: DiGraph[String]): collection.Map[String, collection.Set[String]] = { - (graph.getVertices map {v => (v, graph.getEdges(v))}).toMap + (graph.getVertices.map { v => (v, graph.getEdges(v)) }).toMap } - behavior of "InstanceGraph" + behavior.of("InstanceGraph") it should "recognize a simple hierarchy" in { val input = """ @@ -33,7 +33,13 @@ circuit Top : """ val circuit = ToWorkingIR.run(parse(input)) val graph = new InstanceGraph(circuit).graph.transformNodes(_.module) - getEdgeSet(graph) shouldBe Map("Top" -> Set("Child1", "Child2"), "Child1" -> Set("Child1a", "Child1b"), "Child2" -> Set(), "Child1a" -> Set(), "Child1b" -> Set()) + getEdgeSet(graph) shouldBe Map( + "Top" -> Set("Child1", "Child2"), + "Child1" -> Set("Child1a", "Child1b"), + "Child2" -> Set(), + "Child1a" -> Set(), + "Child1b" -> Set() + ) } it should "find hierarchical instances correctly in disconnected hierarchies" in { @@ -97,12 +103,20 @@ circuit Top : """ val circuit = ToWorkingIR.run(parse(input)) val graph = new InstanceGraph(circuit).graph.transformNodes(_.module) - getEdgeSet(graph) shouldBe Map("Top" -> Set("Child1"), "Top2" -> Set("Child2", "Child3"), "Child2" -> Set("Child2a", "Child2b"), "Child1" -> Set(), "Child2a" -> Set(), "Child2b" -> Set(), "Child3" -> Set()) + getEdgeSet(graph) shouldBe Map( + "Top" -> Set("Child1"), + "Top2" -> Set("Child2", "Child3"), + "Child2" -> Set("Child2a", "Child2b"), + "Child1" -> Set(), + "Child2a" -> Set(), + "Child2b" -> Set(), + "Child3" -> Set() + ) } it should "not drop duplicate nodes when they collide as a result of transformNodes" in { val input = -"""circuit Top : + """circuit Top : module Buzz : skip module Fizz : @@ -134,70 +148,70 @@ circuit Top : // experience non-determinism it should "preserve Module declaration order" in { val input = """ - |circuit Top : - | module Top : - | inst c1 of Child1 - | inst c2 of Child2 - | module Child1 : - | inst a of Child1a - | inst b of Child1b - | skip - | module Child1a : - | skip - | module Child1b : - | skip - | module Child2 : - | skip - |""".stripMargin + |circuit Top : + | module Top : + | inst c1 of Child1 + | inst c2 of Child2 + | module Child1 : + | inst a of Child1a + | inst b of Child1b + | skip + | module Child1a : + | skip + | module Child1b : + | skip + | module Child2 : + | skip + |""".stripMargin val circuit = ToWorkingIR.run(parse(input)) val instGraph = new InstanceGraph(circuit) val childMap = instGraph.getChildrenInstances - childMap.keys.toSeq should equal (Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) + childMap.keys.toSeq should equal(Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) } // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism it should "preserve Instance declaration order" in { val input = """ - |circuit Top : - | module Top : - | inst a of Child - | inst b of Child - | inst c of Child - | inst d of Child - | inst e of Child - | inst f of Child - | module Child : - | skip - |""".stripMargin + |circuit Top : + | module Top : + | inst a of Child + | inst b of Child + | inst c of Child + | inst d of Child + | inst e of Child + | inst f of Child + | module Child : + | skip + |""".stripMargin val circuit = ToWorkingIR.run(parse(input)) val instGraph = new InstanceGraph(circuit) val childMap = instGraph.getChildrenInstances val insts = childMap("Top").toSeq.map(_.name) - insts should equal (Seq("a", "b", "c", "d", "e", "f")) + insts should equal(Seq("a", "b", "c", "d", "e", "f")) } // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism it should "have defined fullHierarchy order" in { val input = """ - |circuit Top : - | module Top : - | inst a of Child - | inst b of Child - | inst c of Child - | inst d of Child - | inst e of Child - | module Child : - | skip - |""".stripMargin + |circuit Top : + | module Top : + | inst a of Child + | inst b of Child + | inst c of Child + | inst d of Child + | inst e of Child + | module Child : + | skip + |""".stripMargin val circuit = ToWorkingIR.run(parse(input)) val instGraph = new InstanceGraph(circuit) val hier = instGraph.fullHierarchy - hier.keys.toSeq.map(_.name) should equal (Seq("Top", "a", "b", "c", "d", "e")) + hier.keys.toSeq.map(_.name) should equal(Seq("Top", "a", "b", "c", "d", "e")) } - behavior of "InstanceGraph.staticInstanceCount" + behavior.of("InstanceGraph.staticInstanceCount") it should "report that there is one instance of the top module" in { val input = @@ -207,7 +221,7 @@ circuit Top : |""".stripMargin val iGraph = new InstanceGraph(ToWorkingIR.run(parse(input))) val expectedCounts = Map(OfModule("Foo") -> 1) - iGraph.staticInstanceCount should be (expectedCounts) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report correct number of instances for a sample circuit" in { @@ -225,10 +239,8 @@ circuit Top : | inst bar2 of Bar |""".stripMargin val iGraph = new InstanceGraph(ToWorkingIR.run(parse(input))) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 2, - OfModule("Baz") -> 3) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 2, OfModule("Baz") -> 3) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report zero instances for dead modules" in { @@ -240,12 +252,11 @@ circuit Top : | skip |""".stripMargin val iGraph = new InstanceGraph(ToWorkingIR.run(parse(input))) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 0) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 0) + iGraph.staticInstanceCount should be(expectedCounts) } - behavior of "Reachable/Unreachable helper methods" + behavior.of("Reachable/Unreachable helper methods") they should "report correct reachable/unreachable counts" in { val input = diff --git a/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala b/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala index ec403259..1e486fe4 100644 --- a/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala +++ b/src/test/scala/firrtlTests/analyses/InstanceKeyGraphSpec.scala @@ -9,10 +9,10 @@ import firrtl.graph.DiGraph import firrtl.testutils.FirrtlFlatSpec class InstanceKeyGraphSpec extends FirrtlFlatSpec { - behavior of "InstanceKeyGraph.graph" + behavior.of("InstanceKeyGraph.graph") private def getEdgeSet(graph: DiGraph[String]): collection.Map[String, collection.Set[String]] = { - (graph.getVertices map {v => (v, graph.getEdges(v))}).toMap + (graph.getVertices.map { v => (v, graph.getEdges(v)) }).toMap } it should "recognize a simple hierarchy" in { @@ -37,7 +37,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { getEdgeSet(graph) shouldBe Map( "Top" -> Set("Child1", "Child2"), "Child1" -> Set("Child1a", "Child1b"), - "Child2" -> Set(), "Child1a" -> Set(), "Child1b" -> Set()) + "Child2" -> Set(), + "Child1a" -> Set(), + "Child1b" -> Set() + ) } it should "recognize disconnected hierarchies" in { @@ -69,7 +72,11 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { "Top" -> Set("Child1"), "Top2" -> Set("Child2", "Child3"), "Child2" -> Set("Child2a", "Child2b"), - "Child1" -> Set(), "Child2a" -> Set(), "Child2b" -> Set(), "Child3" -> Set()) + "Child1" -> Set(), + "Child2a" -> Set(), + "Child2b" -> Set(), + "Child3" -> Set() + ) } it should "not drop duplicate nodes when they collide as a result of transformNodes" in { @@ -101,8 +108,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { g2.getEdges("Fizz") shouldBe Set("Foo", "Bar") } - - behavior of "InstanceKeyGraph.getChildInstances" + behavior.of("InstanceKeyGraph.getChildInstances") // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism @@ -126,7 +132,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val circuit = parse(input) val instGraph = InstanceKeyGraph(circuit) val childMap = instGraph.getChildInstances - childMap.map(_._1) should equal (Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) + childMap.map(_._1) should equal(Seq("Top", "Child1", "Child1a", "Child1b", "Child2")) } // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to @@ -148,10 +154,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val instGraph = InstanceKeyGraph(circuit) val childMap = instGraph.getChildInstances.toMap val insts = childMap("Top").map(_.name) - insts should equal (Seq("a", "b", "c", "d", "e", "f")) + insts should equal(Seq("a", "b", "c", "d", "e", "f")) } - behavior of "InstanceKeyGraph.moduleOrder" + behavior.of("InstanceKeyGraph.moduleOrder") it should "compute a correct and deterministic module order" in { val input = """ @@ -180,10 +186,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val instGraph = InstanceKeyGraph(circuit) val order = instGraph.moduleOrder.map(_.name) // Where it has freedom, the instance declaration order will be reversed. - order should equal (Seq("Top", "Child3", "Child4", "Child2", "Child1", "Child1b", "Child1a")) + order should equal(Seq("Top", "Child3", "Child4", "Child2", "Child1", "Child1b", "Child1a")) } - behavior of "InstanceKeyGraph.findInstancesInHierarchy" + behavior.of("InstanceKeyGraph.findInstancesInHierarchy") it should "find hierarchical instances correctly in disconnected hierarchies" in { val input = @@ -221,7 +227,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { iGraph.findInstancesInHierarchy("Child3") shouldBe Nil } - behavior of "InstanceKeyGraph.staticInstanceCount" + behavior.of("InstanceKeyGraph.staticInstanceCount") it should "report that there is one instance of the top module" in { val input = @@ -231,7 +237,7 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { |""".stripMargin val iGraph = InstanceKeyGraph(parse(input)) val expectedCounts = Map(OfModule("Foo") -> 1) - iGraph.staticInstanceCount should be (expectedCounts) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report correct number of instances for a sample circuit" in { @@ -249,10 +255,8 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { | inst bar2 of Bar |""".stripMargin val iGraph = InstanceKeyGraph(parse(input)) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 2, - OfModule("Baz") -> 3) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 2, OfModule("Baz") -> 3) + iGraph.staticInstanceCount should be(expectedCounts) } it should "report zero instances for dead modules" in { @@ -264,12 +268,11 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { | skip |""".stripMargin val iGraph = InstanceKeyGraph(parse(input)) - val expectedCounts = Map(OfModule("Foo") -> 1, - OfModule("Bar") -> 0) - iGraph.staticInstanceCount should be (expectedCounts) + val expectedCounts = Map(OfModule("Foo") -> 1, OfModule("Bar") -> 0) + iGraph.staticInstanceCount should be(expectedCounts) } - behavior of "InstanceKeyGraph.getChildInstanceMap" + behavior.of("InstanceKeyGraph.getChildInstanceMap") it should "preserve Module declaration order" in { val input = """ @@ -302,15 +305,17 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { assert(childMap(OfModule("Child1b")).isEmpty) assert(childMap(OfModule("Child2")).isEmpty) - val topInstances = childMap(OfModule("Top")).map { case (k,v) => k.value -> v.value}.toSeq - assert(topInstances == - Seq("c1" -> "Child1", "c2" -> "Child2", "c3" -> "Child1", "c4" -> "Child1", "c5" -> "Child1")) + val topInstances = childMap(OfModule("Top")).map { case (k, v) => k.value -> v.value }.toSeq + assert( + topInstances == + Seq("c1" -> "Child1", "c2" -> "Child2", "c3" -> "Child1", "c4" -> "Child1", "c5" -> "Child1") + ) - val child1Instance = childMap(OfModule("Child1")).map { case (k,v) => k.value -> v.value}.toSeq + val child1Instance = childMap(OfModule("Child1")).map { case (k, v) => k.value -> v.value }.toSeq assert(child1Instance == Seq("a" -> "Child1a", "b" -> "Child1b")) } - behavior of "InstanceKeyGraph.fullHierarchy" + behavior.of("InstanceKeyGraph.fullHierarchy") // Note that due to optimized implementations of Map1-4, at least 5 entries are needed to // experience non-determinism @@ -329,10 +334,10 @@ class InstanceKeyGraphSpec extends FirrtlFlatSpec { val instGraph = InstanceKeyGraph(parse(input)) val hier = instGraph.fullHierarchy - hier.keys.toSeq.map(_.name) should equal (Seq("Top", "a", "b", "c", "d", "e")) + hier.keys.toSeq.map(_.name) should equal(Seq("Top", "a", "b", "c", "d", "e")) } - behavior of "Reachable/Unreachable helper methods" + behavior.of("Reachable/Unreachable helper methods") they should "report correct reachable/unreachable counts" in { val input = diff --git a/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala b/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala index 58cb3d11..67408bb7 100644 --- a/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/CleanupNamedTargetsSpec.scala @@ -11,7 +11,8 @@ import firrtl.annotations.{ MultiTargetAnnotation, ReferenceTarget, SingleTargetAnnotation, - Target} + Target +} import firrtl.annotations.transforms.CleanupNamedTargets import org.scalatest.flatspec.AnyFlatSpec @@ -56,7 +57,7 @@ class CleanupNamedTargetsSpec extends AnyFlatSpec with Matchers { } - behavior of "CleanupNamedTargets" + behavior.of("CleanupNamedTargets") it should "convert a SingleTargetAnnotation[ReferenceTarget] of an instance to an InstanceTarget" in new F { val annotations: AnnotationSeq = Seq(SingleReferenceAnnotation(barTarget)) @@ -71,10 +72,10 @@ class CleanupNamedTargetsSpec extends AnyFlatSpec with Matchers { val renames = transform.transform(circuitState(annotations)).renames.get - renames.get(barTarget) should be (Some(Seq(foo.instOf("bar", "Bar")))) + renames.get(barTarget) should be(Some(Seq(foo.instOf("bar", "Bar")))) info("and not touch a true ReferenceAnnotation") - renames.get(bazTarget) should be (None) + renames.get(bazTarget) should be(None) } diff --git a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala index 73f36cf0..bb833f0b 100644 --- a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala @@ -6,7 +6,7 @@ import firrtl._ import firrtl.annotations._ import firrtl.annotations.analysis.DuplicationHelper import firrtl.annotations.transforms.{NoSuchTargetException} -import firrtl.transforms.{DontTouchAnnotation, DedupedResult} +import firrtl.transforms.{DedupedResult, DontTouchAnnotation} import firrtl.testutils.{FirrtlMatchers, FirrtlPropSpec} object EliminateTargetPathsSpec { @@ -15,7 +15,7 @@ object EliminateTargetPathsSpec { override def duplicate(n: Target): Annotation = DummyAnnotation(n) } class DummyTransform() extends Transform with ResolvedAnnotationPaths { - override def inputForm: CircuitForm = LowForm + override def inputForm: CircuitForm = LowForm override def outputForm: CircuitForm = LowForm override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DummyAnnotation]) @@ -72,40 +72,47 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { property("Hierarchical tokens should be expanded properly") { val dupMap = DuplicationHelper(inputState.circuit.modules.map(_.name).toSet) - // Only a few instance references dupMap.expandHierarchy(Top_m1_l1_a) dupMap.expandHierarchy(Top_m2_l1_a) dupMap.expandHierarchy(Middle_l1_a) - dupMap.makePathless(Top_m1_l1_a).foreach {Set(TopCircuit.module("Leaf___Top_m1_l1").ref("a")) should contain (_)} - dupMap.makePathless(Top_m2_l1_a).foreach {Set(TopCircuit.module("Leaf___Top_m2_l1").ref("a")) should contain (_)} - dupMap.makePathless(Top_m1_l2_a).foreach {Set(Leaf_a) should contain (_)} - dupMap.makePathless(Top_m2_l2_a).foreach {Set(Leaf_a) should contain (_)} - dupMap.makePathless(Middle_l1_a).foreach {Set( - TopCircuit.module("Leaf___Top_m1_l1").ref("a"), - TopCircuit.module("Leaf___Top_m2_l1").ref("a"), - TopCircuit.module("Leaf___Middle_l1").ref("a") - ) should contain (_) } - dupMap.makePathless(Middle_l2_a).foreach {Set(Leaf_a) should contain (_)} - dupMap.makePathless(Leaf_a).foreach {Set( - TopCircuit.module("Leaf___Top_m1_l1").ref("a"), - TopCircuit.module("Leaf___Top_m2_l1").ref("a"), - TopCircuit.module("Leaf___Middle_l1").ref("a"), - Leaf_a - ) should contain (_)} - dupMap.makePathless(Top).foreach {Set(Top) should contain (_)} - dupMap.makePathless(Middle).foreach {Set( - TopCircuit.module("Middle___Top_m1"), - TopCircuit.module("Middle___Top_m2"), - Middle - ) should contain (_)} - dupMap.makePathless(Leaf).foreach {Set( - TopCircuit.module("Leaf___Top_m1_l1"), - TopCircuit.module("Leaf___Top_m2_l1"), - TopCircuit.module("Leaf___Middle_l1"), - Leaf - ) should contain (_) } + dupMap.makePathless(Top_m1_l1_a).foreach { Set(TopCircuit.module("Leaf___Top_m1_l1").ref("a")) should contain(_) } + dupMap.makePathless(Top_m2_l1_a).foreach { Set(TopCircuit.module("Leaf___Top_m2_l1").ref("a")) should contain(_) } + dupMap.makePathless(Top_m1_l2_a).foreach { Set(Leaf_a) should contain(_) } + dupMap.makePathless(Top_m2_l2_a).foreach { Set(Leaf_a) should contain(_) } + dupMap.makePathless(Middle_l1_a).foreach { + Set( + TopCircuit.module("Leaf___Top_m1_l1").ref("a"), + TopCircuit.module("Leaf___Top_m2_l1").ref("a"), + TopCircuit.module("Leaf___Middle_l1").ref("a") + ) should contain(_) + } + dupMap.makePathless(Middle_l2_a).foreach { Set(Leaf_a) should contain(_) } + dupMap.makePathless(Leaf_a).foreach { + Set( + TopCircuit.module("Leaf___Top_m1_l1").ref("a"), + TopCircuit.module("Leaf___Top_m2_l1").ref("a"), + TopCircuit.module("Leaf___Middle_l1").ref("a"), + Leaf_a + ) should contain(_) + } + dupMap.makePathless(Top).foreach { Set(Top) should contain(_) } + dupMap.makePathless(Middle).foreach { + Set( + TopCircuit.module("Middle___Top_m1"), + TopCircuit.module("Middle___Top_m2"), + Middle + ) should contain(_) + } + dupMap.makePathless(Leaf).foreach { + Set( + TopCircuit.module("Leaf___Top_m1_l1"), + TopCircuit.module("Leaf___Top_m2_l1"), + TopCircuit.module("Leaf___Middle_l1"), + Leaf + ) should contain(_) + } } property("Hierarchical donttouch should be resolved properly") { @@ -159,10 +166,10 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { | m2.i <= m1.o | """.stripMargin - canonicalize(outputState.circuit).serialize should be (canonicalize(parse(check)).serialize) + canonicalize(outputState.circuit).serialize should be(canonicalize(parse(check)).serialize) outputState.annotations.collect { case x: DontTouchAnnotation => x.target - } should be (Seq(Top.circuitTarget.module("Leaf___Top_m1_l1").ref("a"))) + } should be(Seq(Top.circuitTarget.module("Leaf___Top_m1_l1").ref("a"))) } property("No name conflicts between old and new modules") { @@ -199,7 +206,7 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } } @@ -239,7 +246,7 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } checks.foreach { line => outputLines should not contain (" module Middle :") @@ -267,19 +274,19 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { | m2.i <= m1.o | o <= m2.o """.stripMargin - val e1 = the [CustomTransformException] thrownBy { + val e1 = the[CustomTransformException] thrownBy { val Top_m1 = Top.instOf("m1", "MiddleX") val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m1))) new LowFirrtlCompiler().compile(inputState, customTransforms) } - e1.cause shouldBe a [NoSuchTargetException] + e1.cause shouldBe a[NoSuchTargetException] - val e2 = the [CustomTransformException] thrownBy { + val e2 = the[CustomTransformException] thrownBy { val Top_m2 = Top.instOf("x2", "Middle") val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m2))) new LowFirrtlCompiler().compile(inputState, customTransforms) } - e2.cause shouldBe a [NoSuchTargetException] + e2.cause shouldBe a[NoSuchTargetException] } property("No name conflicts between two new modules") { @@ -320,11 +327,12 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { | module Leaf____Middle__l :""".stripMargin.split("\n") val Middle_l1 = CircuitTarget("Top").module("Middle").instOf("_l", "Leaf") val Middle_l2 = CircuitTarget("Top").module("Middle_").instOf("l", "Leaf") - val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Middle_l1), DummyAnnotation(Middle_l2))) + val inputState = + CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Middle_l1), DummyAnnotation(Middle_l2))) val outputState = new LowFirrtlCompiler().compile(inputState, customTransforms) val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } } @@ -362,12 +370,12 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputState = new VerilogCompiler().compile(inputState, customTransforms) val outputLines = outputState.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } } property("It should remove ResolvePaths annotations") { - val input = + val input = """|circuit Foo: | module Bar: | skip @@ -378,7 +386,7 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { CircuitState(passes.ToWorkingIR.run(Parser.parse(input)), UnknownForm, Nil) .resolvePaths(Seq(CircuitTarget("Foo").module("Foo").instOf("bar", "Bar"))) .annotations - .collect{ case a: firrtl.annotations.transforms.ResolvePaths => a } should be (empty) + .collect { case a: firrtl.annotations.transforms.ResolvePaths => a } should be(empty) } property("It should rename module annotations") { @@ -404,16 +412,14 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val parsedCheck = Parser.parse(check) info(output.circuit.serialize) - (output.circuit.serialize) should be (parsedCheck.serialize) + (output.circuit.serialize) should be(parsedCheck.serialize) val newBar_x = CircuitTarget("Foo").module("Bar___Foo_bar").ref("x") - output - .annotations - .filter{ - case _: DeletedAnnotation => false - case _ => true - } should contain allOf (DontTouchAnnotation(newBar_x), DontTouchAnnotation(Bar_x)) + (output.annotations.filter { + case _: DeletedAnnotation => false + case _ => true + } should contain).allOf(DontTouchAnnotation(newBar_x), DontTouchAnnotation(Bar_x)) } property("It should not rename lone instances") { @@ -440,10 +446,10 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { info(output.circuit.serialize) - output.circuit.serialize should be (inputCircuit.serialize) - output.annotations.collect { + output.circuit.serialize should be(inputCircuit.serialize) + (output.annotations.collect { case a: DontTouchAnnotation => a - } should contain allOf ( + } should contain).allOf( DontTouchAnnotation(ModuleTarget("Foo", "Foo").ref("foo")), DontTouchAnnotation(ModuleTarget("Foo", "Bar").ref("foo")), DontTouchAnnotation(ModuleTarget("Foo", "Baz").ref("foo")) @@ -481,12 +487,12 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputLines = output.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } - output.annotations.collect { + (output.annotations.collect { case a: DontTouchAnnotation => a - } should contain allOf ( + } should contain).allOf( DontTouchAnnotation(ModuleTarget("FooBar", "Bar___Foo_bar").ref("baz")), DontTouchAnnotation(ModuleTarget("FooBar", "Bar___Foo_barBar").ref("baz")) ) @@ -527,11 +533,11 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val outputLines = output.circuit.serialize.split("\n") checks.foreach { line => - outputLines should contain (line) + outputLines should contain(line) } - output.annotations.collect { + (output.annotations.collect { case a: DontTouchAnnotation => a - } should contain allOf ( + } should contain).allOf( DontTouchAnnotation(ModuleTarget("Top", "Baz_0").ref("foo")), DontTouchAnnotation(ModuleTarget("Top", "Baz_1").ref("foo")) ) @@ -563,11 +569,13 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { info(output.circuit.serialize) - output.annotations.collect { case a: DontTouchAnnotation => a } should be (Seq( - DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_asdf").ref("foo")), - DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_lkj").ref("foo")), - DontTouchAnnotation(baz.ref("foo")) - )) + output.annotations.collect { case a: DontTouchAnnotation => a } should be( + Seq( + DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_asdf").ref("foo")), + DontTouchAnnotation(ModuleTarget("Top", "Baz___Bar_lkj").ref("foo")), + DontTouchAnnotation(baz.ref("foo")) + ) + ) } property("It should properly rename modules with multiple instances") { @@ -600,6 +608,6 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { val checkDontTouches = (1 to 4).map { i => DummyAnnotation(ModuleTarget("Top", s"Core___System_core_$i")) } - output.annotations.collect { case a: DummyAnnotation => a } should be (checkDontTouches) + output.annotations.collect { case a: DummyAnnotation => a } should be(checkDontTouches) } } diff --git a/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala b/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala index 2c817c23..54a94edb 100644 --- a/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/JsonProtocolSpec.scala @@ -6,20 +6,20 @@ import firrtl._ import firrtl.annotations.{JsonProtocol, NoTargetAnnotation} import firrtl.ir._ import firrtl.options.Dependency -import _root_.logger.{Logger, LogLevel, LogLevelAnnotation} +import _root_.logger.{LogLevel, LogLevelAnnotation, Logger} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should._ case class AnAnnotation( - info: Info, - cir: Circuit, - mod: DefModule, - port: Port, - statement: Statement, - expr: Expression, - tpe: Type, - groundType: GroundType -) extends NoTargetAnnotation + info: Info, + cir: Circuit, + mod: DefModule, + port: Port, + statement: Statement, + expr: Expression, + tpe: Type, + groundType: GroundType) + extends NoTargetAnnotation class AnnoInjector extends Transform with DependencyAPIMigration { override def optionalPrerequisiteOf = Dependency[ChirrtlEmitter] :: Nil @@ -51,16 +51,16 @@ class JsonProtocolSpec extends AnyFlatSpec with Matchers { val inputAnnos = Seq(AnAnnotation(cir.info, cir, mod, port, stmt, expr, tpe, groundType)) val annosString = JsonProtocol.serialize(inputAnnos) val outputAnnos = JsonProtocol.deserialize(annosString) - inputAnnos should be (outputAnnos) + inputAnnos should be(outputAnnos) } "Annotation serialization during logging" should "not throw an exception" in { val compiler = new firrtl.stage.transforms.Compiler(Seq(Dependency[AnnoInjector])) val circuit = Parser.parse(""" - |circuit test : - | module test : - | output out : UInt<1> - | out <= UInt(0) + |circuit test : + | module test : + | output out : UInt<1> + | out <= UInt(0) """.stripMargin) Logger.makeScope(LogLevelAnnotation(LogLevel.Trace) :: Nil) { compiler.execute(CircuitState(circuit, Nil)) diff --git a/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala b/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala index ccf930ba..ac4f2b63 100644 --- a/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/MorphismSpec.scala @@ -16,15 +16,15 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } case class AnAnnotation( - target: Option[CompleteTarget], - from: Option[AnAnnotation] = None, - cause: Option[String] = None - ) extends Annotation { + target: Option[CompleteTarget], + from: Option[AnAnnotation] = None, + cause: Option[String] = None) + extends Annotation { override def update(renames: RenameMap): Seq[AnAnnotation] = { if (target.isDefined) { renames.get(target.get) match { - case None => Seq(this) - case Some(Seq()) => Seq(AnAnnotation(None, Some(this))) + case None => Seq(this) + case Some(Seq()) => Seq(AnAnnotation(None, Some(this))) case Some(targets) => //TODO: Add cause of renaming, requires FIRRTL change to RenameMap targets.map { t => AnAnnotation(Some(t), Some(this)) } @@ -60,7 +60,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { val annotationsx = a.annotations.filter { case a: DeletedAnnotation => false case AnAnnotation(None, _, _) => false - case _: DupedResult => false + case _: DupedResult => false case _: DedupedResult => false case _ => true } @@ -296,8 +296,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { ) } - - behavior of "EliminateTargetPaths" + behavior.of("EliminateTargetPaths") // NOTE: equivalience is defined structurally in this case trait RightInverseEliminateTargetsFixture extends RightInverseFixture with DefaultExample { @@ -393,24 +392,29 @@ class MorphismSpec extends AnyFlatSpec with Matchers { | inst qux of Baz___Top_qux""".stripMargin override val annotations: AnnotationSeq = Seq( AnAnnotation(CircuitTarget("Top").module("Baz").instOf("foo", "Foo")), - ResolvePaths(Seq( - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") - )) + ResolvePaths( + Seq( + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") + ) + ) ) - override val finalAnnotations: Option[AnnotationSeq] = Some(Seq( - AnAnnotation(CircuitTarget("Top").module("Foo___Top_qux_foo")), - AnAnnotation(CircuitTarget("Top").module("Foo___Top_baz_foo")) - )) + override val finalAnnotations: Option[AnnotationSeq] = Some( + Seq( + AnAnnotation(CircuitTarget("Top").module("Foo___Top_qux_foo")), + AnAnnotation(CircuitTarget("Top").module("Foo___Top_baz_foo")) + ) + ) test() } it should "be idempotent with per-module annotations" in new IdempotencyEliminateTargetsFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allModuleInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -418,6 +422,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with per-instance annotations" in new IdempotencyEliminateTargetsFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allAbsoluteInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -425,13 +430,14 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with relative module annotations" in new IdempotencyEliminateTargetsFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allRelative2LevelInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) test() } - behavior of "DedupModules" + behavior.of("DedupModules") trait RightInverseDedupModulesFixture extends RightInverseFixture with DefaultExample { override val f: Seq[Transform] = Seq(new firrtl.annotations.transforms.EliminateTargetPaths) @@ -498,24 +504,29 @@ class MorphismSpec extends AnyFlatSpec with Matchers { | inst qux of Baz""".stripMargin override val annotations: AnnotationSeq = Seq( AnAnnotation(CircuitTarget("Top").module("Baz").instOf("foo", "Foo")), - ResolvePaths(Seq( - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), - CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") - )) + ResolvePaths( + Seq( + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("bar", "Bar"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foox", "Foo"), + CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("bar", "Bar") + ) + ) ) - override val finalAnnotations: Option[AnnotationSeq] = Some(Seq( - AnAnnotation(CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo")), - AnAnnotation(CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo")) - )) + override val finalAnnotations: Option[AnnotationSeq] = Some( + Seq( + AnAnnotation(CircuitTarget("Top").module("Top").instOf("baz", "Baz").instOf("foo", "Foo")), + AnAnnotation(CircuitTarget("Top").module("Top").instOf("qux", "Baz").instOf("foo", "Foo")) + ) + ) test() } it should "be idempotent with per-module annotations" in new IdempotencyDedupModulesFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allModuleInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -523,6 +534,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with per-instance annotations" in new IdempotencyDedupModulesFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allAbsoluteInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) @@ -530,6 +542,7 @@ class MorphismSpec extends AnyFlatSpec with Matchers { } it should "be idempotent with relative module annotations" in new IdempotencyDedupModulesFixture { + /** An endomorphism */ override val annotations: AnnotationSeq = allRelative2LevelInstances.map(AnAnnotation.apply) :+ ResolvePaths(allAbsoluteInstances) diff --git a/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala index cbcd72e9..cc875ea1 100644 --- a/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/TargetDirAnnotationSpec.scala @@ -11,7 +11,6 @@ import firrtl.annotations.{Annotation, NoTargetAnnotation} case object FoundTargetDirTransformRanAnnotation extends NoTargetAnnotation case object FoundTargetDirTransformFoundTargetDirAnnotation extends NoTargetAnnotation - /** Looks for [[TargetDirAnnotation]] */ class FindTargetDirTransform extends Transform { def inputForm = HighForm @@ -19,14 +18,15 @@ class FindTargetDirTransform extends Transform { def execute(state: CircuitState): CircuitState = { val a: Option[Annotation] = state.annotations.collectFirst { - case TargetDirAnnotation("a/b/c") => FoundTargetDirTransformFoundTargetDirAnnotation } + case TargetDirAnnotation("a/b/c") => FoundTargetDirTransformFoundTargetDirAnnotation + } state.copy(annotations = state.annotations ++ a ++ Some(FoundTargetDirTransformRanAnnotation)) } } class TargetDirAnnotationSpec extends FirrtlFlatSpec { - behavior of "The target directory" + behavior.of("The target directory") val input = """circuit Top : @@ -41,37 +41,35 @@ class TargetDirAnnotationSpec extends FirrtlFlatSpec { val findTargetDir = new FindTargetDirTransform // looks for the annotation val optionsManager = new ExecutionOptionsManager("TargetDir") with HasFirrtlOptions { - commonOptions = commonOptions.copy(targetDirName = targetDir, - topName = "Top") - firrtlOptions = firrtlOptions.copy(compilerName = "high", - firrtlSource = Some(input), - customTransforms = Seq(findTargetDir)) + commonOptions = commonOptions.copy(targetDirName = targetDir, topName = "Top") + firrtlOptions = + firrtlOptions.copy(compilerName = "high", firrtlSource = Some(input), customTransforms = Seq(findTargetDir)) } val annotations: Seq[Annotation] = Driver.execute(optionsManager) match { case a: FirrtlExecutionSuccess => a.circuitState.annotations case _ => fail } - annotations should contain (FoundTargetDirTransformRanAnnotation) - annotations should contain (FoundTargetDirTransformFoundTargetDirAnnotation) + annotations should contain(FoundTargetDirTransformRanAnnotation) + annotations should contain(FoundTargetDirTransformFoundTargetDirAnnotation) // Delete created directory val dir = new java.io.File(targetDir) - dir.exists should be (true) - FileUtils.deleteDirectoryHierarchy("a") should be (true) + dir.exists should be(true) + FileUtils.deleteDirectoryHierarchy("a") should be(true) } it should "NOT be available as an annotation when using a raw compiler" in { val findTargetDir = new FindTargetDirTransform // looks for the annotation val compiler = new VerilogCompiler - val circuit = Parser.parse(input split "\n") + val circuit = Parser.parse(input.split("\n")) val annotations: Seq[Annotation] = compiler .compileAndEmit(CircuitState(circuit, HighForm), Seq(findTargetDir)) .annotations // Check that FindTargetDirTransform does not find the annotation - annotations should contain (FoundTargetDirTransformRanAnnotation) + annotations should contain(FoundTargetDirTransformRanAnnotation) annotations should not contain (FoundTargetDirTransformFoundTargetDirAnnotation) } } diff --git a/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala b/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala index 641eeb99..48f27faa 100644 --- a/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/TargetSpec.scala @@ -24,8 +24,9 @@ class TargetSpec extends FirrtlPropSpec { (top.ref("r").index(1).field("hi").clock, "~Circuit|Top>r[1].hi@clock"), (GenericTarget(None, None, Vector(Ref("r"))), "~???|???>r") ) - targets.foreach { case (t, str) => - assert(t.serialize == str, s"$t does not properly serialize") + targets.foreach { + case (t, str) => + assert(t.serialize == str, s"$t does not properly serialize") } } property("Should convert to/from Named") { @@ -38,7 +39,7 @@ class TargetSpec extends FirrtlPropSpec { check(Target(Some("Top"), Some("Top"), r2)) } property("Should enable creating from API") { - val top = ModuleTarget("Top","Top") + val top = ModuleTarget("Top", "Top") val x_reg0_data = top.instOf("x", "X").ref("reg0").field("data") top.instOf("x", "x") top.ref("y") @@ -47,8 +48,14 @@ class TargetSpec extends FirrtlPropSpec { val circuit = CircuitTarget("Circuit") val top = circuit.module("Top") val targets: Seq[Target] = - Seq(circuit, top, top.instOf("i", "I"), top.ref("r"), - top.ref("r").index(1).field("hi").clock, GenericTarget(None, None, Vector(Ref("r")))) + Seq( + circuit, + top, + top.instOf("i", "I"), + top.ref("r"), + top.ref("r").index(1).field("hi").clock, + GenericTarget(None, None, Vector(Ref("r"))) + ) targets.foreach { t => assert(Target.deserialize(t.serialize) == t, s"$t does not properly serialize/deserialize") } @@ -58,25 +65,20 @@ class TargetSpec extends FirrtlPropSpec { val top = circuit.module("B") val targets = Seq( (circuit, "circuit A:"), - (top, - """|circuit A: - |└── module B:""".stripMargin), - (top.instOf("c", "C"), - """|circuit A: - |└── module B: - | └── inst c of C:""".stripMargin), - (top.ref("r"), - """|circuit A: - |└── module B: - | └── r""".stripMargin), - (top.ref("r").index(1).field("hi").clock, - """|circuit A: - |└── module B: - | └── r[1].hi@clock""".stripMargin), - (GenericTarget(None, None, Vector(Ref("r"))), - """|circuit ???: - |└── module ???: - | └── r""".stripMargin) + (top, """|circuit A: + |└── module B:""".stripMargin), + (top.instOf("c", "C"), """|circuit A: + |└── module B: + | └── inst c of C:""".stripMargin), + (top.ref("r"), """|circuit A: + |└── module B: + | └── r""".stripMargin), + (top.ref("r").index(1).field("hi").clock, """|circuit A: + |└── module B: + | └── r[1].hi@clock""".stripMargin), + (GenericTarget(None, None, Vector(Ref("r"))), """|circuit ???: + |└── module ???: + | └── r""".stripMargin) ) targets.foreach { case (t, str) => assert(t.prettyPrint() == str, s"$t didn't properly prettyPrint") } } diff --git a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala index 8b26c80c..68db6873 100644 --- a/src/test/scala/firrtlTests/constraint/InequalitySpec.scala +++ b/src/test/scala/firrtlTests/constraint/InequalitySpec.scala @@ -7,101 +7,109 @@ import org.scalatest.matchers.should.Matchers class InequalitySpec extends AnyFlatSpec with Matchers { - behavior of "Constraints" + behavior.of("Constraints") "IsConstraints" should "reduce properly" in { - IsMin(Closed(0), Closed(1)) should be (Closed(0)) - IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) - IsMax(Closed(-1), Closed(1)) should be (Closed(1)) - IsNeg(IsMul(Closed(-1), Closed(-2))) should be (Closed(-2)) + IsMin(Closed(0), Closed(1)) should be(Closed(0)) + IsMin(Closed(-1), Closed(1)) should be(Closed(-1)) + IsMax(Closed(-1), Closed(1)) should be(Closed(1)) + IsNeg(IsMul(Closed(-1), Closed(-2))) should be(Closed(-2)) val x = IsMin(IsMul(Closed(1), VarCon("a")), Closed(2)) - x.children.toSet should be (IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet) + x.children.toSet should be(IsMin(Closed(2), IsMul(Closed(1), VarCon("a"))).children.toSet) } "IsAdd" should "reduce properly" in { // All constants - IsAdd(Closed(-1), Closed(1)) should be (Closed(0)) + IsAdd(Closed(-1), Closed(1)) should be(Closed(0)) // Pull Out IsMax - IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1)))) - IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsAdd(Closed(1), IsMax(Closed(1), VarCon("a"))) should be(IsMax(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMax(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) ) // Pull Out IsMin - IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1)))) - IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsAdd(Closed(1), IsMin(Closed(1), VarCon("a"))) should be(IsMin(Closed(2), IsAdd(VarCon("a"), Closed(1)))) + IsAdd(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMin(Seq(Closed(2), IsAdd(VarCon("a"), Closed(1)), IsAdd(VarCon("b"), Closed(1)))) ) // Add Zero - IsAdd(Closed(0), VarCon("a")) should be (VarCon("a")) + IsAdd(Closed(0), VarCon("a")) should be(VarCon("a")) // One argument - IsAdd(Seq(VarCon("a"))) should be (VarCon("a")) + IsAdd(Seq(VarCon("a"))) should be(VarCon("a")) } "IsMax" should "reduce properly" in { // All constants - IsMax(Closed(-1), Closed(1)) should be (Closed(1)) + IsMax(Closed(-1), Closed(1)) should be(Closed(1)) // Flatten nested IsMax - IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(1), VarCon("a"))) - IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMax(Closed(1), IsMax(Closed(1), VarCon("a"))) should be(IsMax(Closed(1), VarCon("a"))) + IsMax(Closed(1), IsMax(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMax(Seq(Closed(1), VarCon("a"), VarCon("b"))) ) // Eliminate IsMins if possible - IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be (Closed(2)) - IsMax(Seq( - Closed(2), - IsMin(Closed(1), VarCon("a")), - IsMin(Closed(3), VarCon("b")) - )) should be ( - IsMax(Seq( + IsMax(Closed(2), IsMin(Closed(1), VarCon("a"))) should be(Closed(2)) + IsMax( + Seq( Closed(2), + IsMin(Closed(1), VarCon("a")), IsMin(Closed(3), VarCon("b")) - )) + ) + ) should be( + IsMax( + Seq( + Closed(2), + IsMin(Closed(3), VarCon("b")) + ) + ) ) // One argument - IsMax(Seq(VarCon("a"))) should be (VarCon("a")) - IsMax(Seq(Closed(0))) should be (Closed(0)) - IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be (IsMin(VarCon("a"), Closed(0))) + IsMax(Seq(VarCon("a"))) should be(VarCon("a")) + IsMax(Seq(Closed(0))) should be(Closed(0)) + IsMax(Seq(IsMin(VarCon("a"), Closed(0)))) should be(IsMin(VarCon("a"), Closed(0))) } "IsMin" should "reduce properly" in { // All constants - IsMin(Closed(-1), Closed(1)) should be (Closed(-1)) + IsMin(Closed(-1), Closed(1)) should be(Closed(-1)) // Flatten nested IsMin - IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(1), VarCon("a"))) - IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be ( + IsMin(Closed(1), IsMin(Closed(1), VarCon("a"))) should be(IsMin(Closed(1), VarCon("a"))) + IsMin(Closed(1), IsMin(Seq(Closed(1), VarCon("a"), VarCon("b")))) should be( IsMin(Seq(Closed(1), VarCon("a"), VarCon("b"))) ) // Eliminate IsMaxs if possible - IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be (Closed(1)) - IsMin(Seq( - Closed(2), - IsMax(Closed(1), VarCon("a")), - IsMax(Closed(3), VarCon("b")) - )) should be ( - IsMin(Seq( + IsMin(Closed(1), IsMax(Closed(2), VarCon("a"))) should be(Closed(1)) + IsMin( + Seq( Closed(2), - IsMax(Closed(1), VarCon("a")) - )) + IsMax(Closed(1), VarCon("a")), + IsMax(Closed(3), VarCon("b")) + ) + ) should be( + IsMin( + Seq( + Closed(2), + IsMax(Closed(1), VarCon("a")) + ) + ) ) // One argument - IsMin(Seq(VarCon("a"))) should be (VarCon("a")) - IsMin(Seq(Closed(0))) should be (Closed(0)) - IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be (IsMax(VarCon("a"), Closed(0))) + IsMin(Seq(VarCon("a"))) should be(VarCon("a")) + IsMin(Seq(Closed(0))) should be(Closed(0)) + IsMin(Seq(IsMax(VarCon("a"), Closed(0)))) should be(IsMax(VarCon("a"), Closed(0))) } "IsMul" should "reduce properly" in { // All constants - IsMul(Closed(2), Closed(3)) should be (Closed(6)) + IsMul(Closed(2), Closed(3)) should be(Closed(6)) // Pull out max, if positive stays max IsMul(Closed(2), IsMax(Closed(3), VarCon("a"))) should be( @@ -124,75 +132,74 @@ class InequalitySpec extends AnyFlatSpec with Matchers { ) // Times zero - IsMul(Closed(0), VarCon("x")) should be (Closed(0)) + IsMul(Closed(0), VarCon("x")) should be(Closed(0)) // Times 1 - IsMul(Closed(1), VarCon("x")) should be (VarCon("x")) + IsMul(Closed(1), VarCon("x")) should be(VarCon("x")) // One argument - IsMul(Seq(Closed(0))) should be (Closed(0)) - IsMul(Seq(VarCon("a"))) should be (VarCon("a")) + IsMul(Seq(Closed(0))) should be(Closed(0)) + IsMul(Seq(VarCon("a"))) should be(VarCon("a")) // No optimizations val isMax = IsMax(VarCon("x"), VarCon("y")) val isMin = IsMin(VarCon("x"), VarCon("y")) val a = VarCon("a") - IsMul(a, isMax).children should be (Vector(a, isMax)) //non-known multiply - IsMul(a, isMin).children should be (Vector(a, isMin)) //non-known multiply - IsMul(Seq(Closed(2), isMin, isMin)).children should be (Vector(Closed(2), isMin, isMin)) //>1 min - IsMul(Seq(Closed(2), isMax, isMax)).children should be (Vector(Closed(2), isMax, isMax)) //>1 max - IsMul(Seq(Closed(2), isMin, isMax)).children should be (Vector(Closed(2), isMin, isMax)) //mixed min/max + IsMul(a, isMax).children should be(Vector(a, isMax)) //non-known multiply + IsMul(a, isMin).children should be(Vector(a, isMin)) //non-known multiply + IsMul(Seq(Closed(2), isMin, isMin)).children should be(Vector(Closed(2), isMin, isMin)) //>1 min + IsMul(Seq(Closed(2), isMax, isMax)).children should be(Vector(Closed(2), isMax, isMax)) //>1 max + IsMul(Seq(Closed(2), isMin, isMax)).children should be(Vector(Closed(2), isMin, isMax)) //mixed min/max } "IsNeg" should "reduce properly" in { // All constants - IsNeg(Closed(1)) should be (Closed(-1)) + IsNeg(Closed(1)) should be(Closed(-1)) // Pull out max - IsNeg(IsMax(Closed(1), VarCon("a"))) should be (IsMin(Closed(-1), IsNeg(VarCon("a")))) + IsNeg(IsMax(Closed(1), VarCon("a"))) should be(IsMin(Closed(-1), IsNeg(VarCon("a")))) // Pull out min - IsNeg(IsMin(Closed(1), VarCon("a"))) should be (IsMax(Closed(-1), IsNeg(VarCon("a")))) + IsNeg(IsMin(Closed(1), VarCon("a"))) should be(IsMax(Closed(-1), IsNeg(VarCon("a")))) // Pull out add - IsNeg(IsAdd(Closed(1), VarCon("a"))) should be (IsAdd(Closed(-1), IsNeg(VarCon("a")))) + IsNeg(IsAdd(Closed(1), VarCon("a"))) should be(IsAdd(Closed(-1), IsNeg(VarCon("a")))) // Pull out mul - IsNeg(IsMul(Closed(2), VarCon("a"))) should be (IsMul(Closed(-2), VarCon("a"))) + IsNeg(IsMul(Closed(2), VarCon("a"))) should be(IsMul(Closed(-2), VarCon("a"))) // No optimizations // (pow), (floor?) - IsNeg(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) - IsNeg(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + IsNeg(IsPow(VarCon("x"))).children should be(Vector(IsPow(VarCon("x")))) + IsNeg(IsFloor(VarCon("x"))).children should be(Vector(IsFloor(VarCon("x")))) } "IsPow" should "reduce properly" in { // All constants - IsPow(Closed(1)) should be (Closed(2)) + IsPow(Closed(1)) should be(Closed(2)) // Pull out max - IsPow(IsMax(Closed(1), VarCon("a"))) should be (IsMax(Closed(2), IsPow(VarCon("a")))) + IsPow(IsMax(Closed(1), VarCon("a"))) should be(IsMax(Closed(2), IsPow(VarCon("a")))) // Pull out min - IsPow(IsMin(Closed(1), VarCon("a"))) should be (IsMin(Closed(2), IsPow(VarCon("a")))) + IsPow(IsMin(Closed(1), VarCon("a"))) should be(IsMin(Closed(2), IsPow(VarCon("a")))) // Pull out add - IsPow(IsAdd(Closed(1), VarCon("a"))) should be (IsMul(Closed(2), IsPow(VarCon("a")))) + IsPow(IsAdd(Closed(1), VarCon("a"))) should be(IsMul(Closed(2), IsPow(VarCon("a")))) // No optimizations // (mul), (pow), (floor?) - IsPow(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) - IsPow(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) - IsPow(IsFloor(VarCon("x"))).children should be (Vector(IsFloor(VarCon("x")))) + IsPow(IsMul(Closed(2), VarCon("x"))).children should be(Vector(IsMul(Closed(2), VarCon("x")))) + IsPow(IsPow(VarCon("x"))).children should be(Vector(IsPow(VarCon("x")))) + IsPow(IsFloor(VarCon("x"))).children should be(Vector(IsFloor(VarCon("x")))) } "IsFloor" should "reduce properly" in { // All constants - IsFloor(Closed(1.9)) should be (Closed(1)) - IsFloor(Closed(-1.9)) should be (Closed(-2)) + IsFloor(Closed(1.9)) should be(Closed(1)) + IsFloor(Closed(-1.9)) should be(Closed(-2)) // Pull out max - IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be (IsMax(Closed(1), IsFloor(VarCon("a")))) + IsFloor(IsMax(Closed(1.9), VarCon("a"))) should be(IsMax(Closed(1), IsFloor(VarCon("a")))) // Pull out min - IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be (IsMin(Closed(1), IsFloor(VarCon("a")))) + IsFloor(IsMin(Closed(1.9), VarCon("a"))) should be(IsMin(Closed(1), IsFloor(VarCon("a")))) // Cancel with another floor - IsFloor(IsFloor(VarCon("a"))) should be (IsFloor(VarCon("a"))) + IsFloor(IsFloor(VarCon("a"))) should be(IsFloor(VarCon("a"))) // No optimizations // (add), (mul), (pow) - IsFloor(IsMul(Closed(2), VarCon("x"))).children should be (Vector(IsMul(Closed(2), VarCon("x")))) - IsFloor(IsPow(VarCon("x"))).children should be (Vector(IsPow(VarCon("x")))) - IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be (Vector(IsAdd(Closed(1), VarCon("x")))) + IsFloor(IsMul(Closed(2), VarCon("x"))).children should be(Vector(IsMul(Closed(2), VarCon("x")))) + IsFloor(IsPow(VarCon("x"))).children should be(Vector(IsPow(VarCon("x")))) + IsFloor(IsAdd(Closed(1), VarCon("x"))).children should be(Vector(IsAdd(Closed(1), VarCon("x")))) } } - diff --git a/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala b/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala index 7d250664..0a50e53e 100644 --- a/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala +++ b/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala @@ -31,18 +31,18 @@ object ExecutionTestHelper { // Generate test step counter, create ExecutionTestHelper that represents initial test state val cnt = DefRegister(NoInfo, DUTRules.counter.name, counterType, DUTRules.clock, DUTRules.reset, Utils.zero) - val inc = Connect(NoInfo, DUTRules.counter, DoPrim(PrimOps.Add, Seq(DUTRules.counter, UIntLiteral(1)), Nil, UnknownType)) + val inc = + Connect(NoInfo, DUTRules.counter, DoPrim(PrimOps.Add, Seq(DUTRules.counter, UIntLiteral(1)), Nil, UnknownType)) ExecutionTestHelper(c, Seq(cnt, inc), Map.empty[Expression, Expression], Nil, Nil) } } case class ExecutionTestHelper( - dut: Circuit, - setup: Seq[Statement], - pokeRegs: Map[Expression, Expression], + dut: Circuit, + setup: Seq[Statement], + pokeRegs: Map[Expression, Expression], completedSteps: Seq[Conditionally], - activeStep: Seq[Statement] -) { + activeStep: Seq[Statement]) { def step(n: Int): ExecutionTestHelper = { require(n > 0, "Step length must be positive") @@ -52,9 +52,7 @@ case class ExecutionTestHelper( def poke(expString: String, value: Literal): ExecutionTestHelper = { val pokeExp = ParseExpression(expString) val pokeable = ensurePokeable(pokeExp) - pokeable.addStatements( - Connect(NoInfo, pokeExp, value), - Connect(NoInfo, pokeable.pokeRegs(pokeExp), value)) + pokeable.addStatements(Connect(NoInfo, pokeExp, value), Connect(NoInfo, pokeable.pokeRegs(pokeExp), value)) } def invalidate(expString: String): ExecutionTestHelper = { @@ -85,7 +83,7 @@ case class ExecutionTestHelper( } private def top: Module = { - dut.modules.collectFirst({ case m: Module if m.name == dut.main => m }).get + dut.modules.collectFirst({ case m: Module if m.name == dut.main => m }).get } private[execution] def emit: Circuit = { diff --git a/src/test/scala/firrtlTests/execution/ParserHelpers.scala b/src/test/scala/firrtlTests/execution/ParserHelpers.scala index 3472c19c..1f74d634 100644 --- a/src/test/scala/firrtlTests/execution/ParserHelpers.scala +++ b/src/test/scala/firrtlTests/execution/ParserHelpers.scala @@ -14,10 +14,10 @@ object ParseStatement { val indent = " " val indented = stmtStr.split("\n").mkString(indent, s"\n${indent}", "") s"""circuit ${DUTRules.dutName} : - | module ${DUTRules.dutName} : - | input clock : Clock - | input reset : UInt<1> - |${indented}""".stripMargin + | module ${DUTRules.dutName} : + | input clock : Clock + | input reset : UInt<1> + |${indented}""".stripMargin } private def parse(stmtStr: String): Circuit = { diff --git a/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala b/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala index 2654f476..911f7485 100644 --- a/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala +++ b/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala @@ -20,19 +20,19 @@ trait TestExecution { /** * A class that makes it easier to write execution-driven tests. - * + * * By combining a DUT body (supplied as a string without an enclosing * module or circuit) with a sequence of test operations, an * executable, self-contained Verilog testbench may be automatically * created and checked. - * + * * @note It is necessary to mix in a trait extending TestExecution * @note The DUT has two implicit ports, "clock" and "reset" * @note Execution of the command sequences begins after reset is deasserted - * + * * @see [[firrtlTests.execution.TestExecution]] * @see [[firrtlTests.execution.VerilogExecution]] - * + * * @example {{{ * class AndTester extends SimpleExecutionTest with VerilogExecution { * val body = "reg r : UInt<32>, clock with: (reset => (reset, UInt<32>(0)))" @@ -64,9 +64,9 @@ abstract class SimpleExecutionTest extends FirrtlPropSpec { def commands: Seq[SimpleTestCommand] private def interpretCommand(eth: ExecutionTestHelper, cmd: SimpleTestCommand) = cmd match { - case Step(n) => eth.step(n) - case Invalidate(expStr) => eth.invalidate(expStr) - case Poke(expStr, value) => eth.poke(expStr, UIntLiteral(value)) + case Step(n) => eth.step(n) + case Invalidate(expStr) => eth.invalidate(expStr) + case Poke(expStr, value) => eth.poke(expStr, UIntLiteral(value)) case Expect(expStr, value) => eth.expect(expStr, UIntLiteral(value)) } diff --git a/src/test/scala/firrtlTests/execution/VerilogExecution.scala b/src/test/scala/firrtlTests/execution/VerilogExecution.scala index 89f27609..913cfc71 100644 --- a/src/test/scala/firrtlTests/execution/VerilogExecution.scala +++ b/src/test/scala/firrtlTests/execution/VerilogExecution.scala @@ -30,7 +30,7 @@ trait VerilogExecution extends TestExecution { // Make and run Verilog simulation verilogToCpp(c.main, testDir, Nil, harness) #&& - cppToExe(c.main, testDir) ! loggingProcessLogger + cppToExe(c.main, testDir) ! loggingProcessLogger assert(executeExpectingSuccess(c.main, testDir)) } } diff --git a/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala b/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala index 4299ac7f..7ab80387 100644 --- a/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala +++ b/src/test/scala/firrtlTests/features/LetterCaseTransformSpec.scala @@ -15,7 +15,7 @@ import org.scalatest.matchers.should.Matchers class LetterCaseTransformSpec extends AnyFlatSpec with Matchers { case class TrackingAnnotation(val target: IsMember) extends SingleTargetAnnotation[IsMember] { - override def duplicate(a: IsMember) = this.copy(target=a) + override def duplicate(a: IsMember) = this.copy(target = a) } class CircuitFixture { @@ -66,72 +66,94 @@ class LetterCaseTransformSpec extends AnyFlatSpec with Matchers { private val Foo = CircuitTarget("Foo") private val Bar = Foo.module("Bar") - val annotations = Seq(TrackingAnnotation(Foo.module("Foo").ref("MeM").field("wRITE")field("en")), - ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar)), Dependency[LowerCaseNames]), - ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar.ref("OuT"))), Dependency[UpperCaseNames])) + val annotations = Seq( + TrackingAnnotation(Foo.module("Foo").ref("MeM").field("wRITE").field("en")), + ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar)), Dependency[LowerCaseNames]), + ManipulateNamesBlocklistAnnotation(Seq(Seq(Bar.ref("OuT"))), Dependency[UpperCaseNames]) + ) val state = CircuitState(Parser.parse(input), annotations) } - behavior of "LowerCaseNames" + behavior.of("LowerCaseNames") it should "change all names to lowercase" in new CircuitFixture { val tm = new firrtl.stage.transforms.Compiler(Seq(firrtl.options.Dependency[LowerCaseNames])) val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "foo") => true }, - { case ir.Module(_, "foo", - Seq(ir.Port(_, "clk", _, _), ir.Port(_, "rst_p", _, _), ir.Port(_, "addr", _, _)), _) => true }, - /* Module "Bar" should be skipped via a ManipulateNamesBlocklistAnnotation */ - { case ir.Module(_, "Bar", Seq(ir.Port(_, "out", _, _)), _) => true }, + { + case ir + .Module(_, "foo", Seq(ir.Port(_, "clk", _, _), ir.Port(_, "rst_p", _, _), ir.Port(_, "addr", _, _)), _) => + true + }, + /* Module "Bar" should be skipped via a ManipulateNamesBlocklistAnnotation */ { + case ir.Module(_, "Bar", Seq(ir.Port(_, "out", _, _)), _) => true + }, { case ir.Module(_, "baz_0", Seq(ir.Port(_, "out", _, _)), _) => true }, { case ir.Module(_, "baz", Seq(ir.Port(_, "out", _, _)), _) => true }, - /* External module "Ext" is not renamed */ - { case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true }, + /* External module "Ext" is not renamed */ { + case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true + }, { case ir.DefNode(_, "bar", _) => true }, { case ir.DefRegister(_, "baz", _, WRef("clk", _, _, _), WRef("rst_p", _, _, _), WRef("bar", _, _, _)) => true }, { case ir.DefWire(_, "qux", _) => true }, { case ir.Connect(_, WRef("qux", _, _, _), _) => true }, { case ir.DefNode(_, "quuxquux", _) => true }, { case ir.DefMemory(_, "mem", _, _, _, _, Seq("read"), Seq("write"), Seq("rw"), _) => true }, - /* Ports of memories should be ignored, but these are already lower case */ - { case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "read", _, _), "addr", _, _)) => true }, + /* Ports of memories should be ignored, but these are already lower case */ { + case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "read", _, _), "addr", _, _)) => true + }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "write", _, _), "addr", _, _)) => true }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("mem", _, _, _), "rw", _, _), "addr", _, _)) => true }, /* Module "Bar" was skipped via a ManipulateNamesBlocklistAnnotation. The instance "SuB1" is renamed to "sub1_0" * because node "sub1" already exists. This differs from the upper case test. - */ - { case WDefInstance(_, "sub1_0", "Bar", _) => true }, + */ { case WDefInstance(_, "sub1_0", "Bar", _) => true }, { case WDefInstance(_, "sub2", "baz_0", _) => true }, { case WDefInstance(_, "sub3", "baz", _) => true }, - /* External module instance names are renamed */ - { case WDefInstance(_, "sub4", "Ext", _) => true }, + /* External module instance names are renamed */ { case WDefInstance(_, "sub4", "Ext", _) => true }, { case ir.DefNode(_, "sub1", _) => true }, { case ir.DefNode(_, "corge_corge", WSubField(WRef("sub1_0", _, _, _), "out", _, _)) => true }, - { case ir.DefNode(_, "quuzquuz", - ir.DoPrim(_,Seq(WSubField(WRef("sub2", _, _, _), "out", _, _), - WSubField(WRef("sub3", _, _, _), "out", _, _)), _, _)) => true }, - /* References to external module ports are not renamed, e.g., OuT */ - { case ir.DefNode(_, "graultgrault", - ir.DoPrim(_, Seq(WSubField(WRef("sub4", _, _, _), "OuT", _, _)), _, _)) => true } + { + case ir.DefNode( + _, + "quuzquuz", + ir.DoPrim( + _, + Seq(WSubField(WRef("sub2", _, _, _), "out", _, _), WSubField(WRef("sub3", _, _, _), "out", _, _)), + _, + _ + ) + ) => + true + }, + /* References to external module ports are not renamed, e.g., OuT */ { + case ir.DefNode(_, "graultgrault", ir.DoPrim(_, Seq(WSubField(WRef("sub4", _, _, _), "OuT", _, _)), _, _)) => + true + } ) - expected.foreach( statex should containTree (_) ) + expected.foreach(statex should containTree(_)) } - behavior of "UpperCaseNames" + behavior.of("UpperCaseNames") it should "change all names to uppercase" in new CircuitFixture { val tm = new firrtl.stage.transforms.Compiler(Seq(firrtl.options.Dependency[UpperCaseNames])) val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "FOO") => true }, - { case ir.Module(_, "FOO", - Seq(ir.Port(_, "CLK", _, _), ir.Port(_, "RST_P", _, _), ir.Port(_, "ADDR", _, _)), _) => true }, - /* "Bar>OuT" should be skipped via a ManipulateNamesBlocklistAnnotation */ - { case ir.Module(_, "BAR", Seq(ir.Port(_, "OuT", _, _)), _) => true }, + { + case ir + .Module(_, "FOO", Seq(ir.Port(_, "CLK", _, _), ir.Port(_, "RST_P", _, _), ir.Port(_, "ADDR", _, _)), _) => + true + }, + /* "Bar>OuT" should be skipped via a ManipulateNamesBlocklistAnnotation */ { + case ir.Module(_, "BAR", Seq(ir.Port(_, "OuT", _, _)), _) => true + }, { case ir.Module(_, "BAZ", Seq(ir.Port(_, "OUT", _, _)), _) => true }, { case ir.Module(_, "BAZ_0", Seq(ir.Port(_, "OUT", _, _)), _) => true }, - /* External module "Ext" is not renamed */ - { case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true }, + /* External module "Ext" is not renamed */ { + case ir.ExtModule(_, "Ext", Seq(ir.Port(_, "OuT", _, _)), _, _) => true + }, { case ir.DefNode(_, "BAR", _) => true }, { case ir.DefRegister(_, "BAZ", _, WRef("CLK", _, _, _), WRef("RST_P", _, _, _), WRef("BAR", _, _, _)) => true }, { case ir.DefWire(_, "QUX", _) => true }, @@ -140,28 +162,42 @@ class LetterCaseTransformSpec extends AnyFlatSpec with Matchers { { case ir.DefMemory(_, "MEM", _, _, _, _, Seq("READ"), Seq("WRITE"), Seq("RW"), _) => true }, /* Ports of memories should be ignored while readers/writers are renamed, e.g., "Read" is converted to upper case * while "addr" is not touched. - */ - { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "READ", _, _), "addr", _, _)) => true }, + */ { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "READ", _, _), "addr", _, _)) => true }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "WRITE", _, _), "addr", _, _)) => true }, { case ir.IsInvalid(_, WSubField(WSubField(WRef("MEM", _, _, _), "RW", _, _), "addr", _, _)) => true }, { case WDefInstance(_, "SUB1", "BAR", _) => true }, - /* Instance "SuB2" and "SuB3" switch their modules from the lower case test due to namespace behavior. */ - { case WDefInstance(_, "SUB2", "BAZ", _) => true }, + /* Instance "SuB2" and "SuB3" switch their modules from the lower case test due to namespace behavior. */ { + case WDefInstance(_, "SUB2", "BAZ", _) => true + }, { case WDefInstance(_, "SUB3", "BAZ_0", _) => true }, - /* External module "Ext" was skipped via a ManipulateBlocklistAnnotation */ - { case WDefInstance(_, "SUB4", "Ext", _) => true }, - /* Node "sub1" becomes "SUB1_0" because instance "SuB1" already got the "SUB1" name. */ - { case ir.DefNode(_, "SUB1_0", _) => true }, - /* Port "OuT" was skipped via a ManipulateNamesBlocklistAnnotation */ - { case ir.DefNode(_, "CORGE_CORGE", WSubField(WRef("SUB1", _, _, _), "OuT", _, _)) => true }, - { case ir.DefNode(_, "QUUZQUUZ", - ir.DoPrim(_,Seq(WSubField(WRef("SUB2", _, _, _), "OUT", _, _), - WSubField(WRef("SUB3", _, _, _), "OUT", _, _)), _, _)) => true }, - /* References to external module ports are not renamed, e.g., "OuT" */ - { case ir.DefNode(_, "GRAULTGRAULT", - ir.DoPrim(_, Seq(WSubField(WRef("SUB4", _, _, _), "OuT", _, _)), _, _)) => true } + /* External module "Ext" was skipped via a ManipulateBlocklistAnnotation */ { + case WDefInstance(_, "SUB4", "Ext", _) => true + }, + /* Node "sub1" becomes "SUB1_0" because instance "SuB1" already got the "SUB1" name. */ { + case ir.DefNode(_, "SUB1_0", _) => true + }, + /* Port "OuT" was skipped via a ManipulateNamesBlocklistAnnotation */ { + case ir.DefNode(_, "CORGE_CORGE", WSubField(WRef("SUB1", _, _, _), "OuT", _, _)) => true + }, + { + case ir.DefNode( + _, + "QUUZQUUZ", + ir.DoPrim( + _, + Seq(WSubField(WRef("SUB2", _, _, _), "OUT", _, _), WSubField(WRef("SUB3", _, _, _), "OUT", _, _)), + _, + _ + ) + ) => + true + }, + /* References to external module ports are not renamed, e.g., "OuT" */ { + case ir.DefNode(_, "GRAULTGRAULT", ir.DoPrim(_, Seq(WSubField(WRef("SUB4", _, _, _), "OuT", _, _)), _, _)) => + true + } ) - expected.foreach( statex should containTree (_) ) + expected.foreach(statex should containTree(_)) } } diff --git a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala index c4de1f46..a41ac90a 100644 --- a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala @@ -2,21 +2,21 @@ package firrtlTests.fixed -import firrtl.{CircuitState, ChirrtlForm, LowFirrtlCompiler} +import firrtl.{ChirrtlForm, CircuitState, LowFirrtlCompiler} import firrtl.testutils.FirrtlFlatSpec class FixedPointMathSpec extends FirrtlFlatSpec { - val SumPattern = """.*output sum.*<(\d+)>.*.*""".r - val ProductPattern = """.*output product.*<(\d+)>.*""".r + val SumPattern = """.*output sum.*<(\d+)>.*.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r val DifferencePattern = """.*output difference.*<(\d+)>.*""".r - val AssignPattern = """\s*(\w+) <= (\w+)\((.*)\)\s*""".r + val AssignPattern = """\s*(\w+) <= (\w+)\((.*)\)\s*""".r for { - bits1 <- 1 to 4 + bits1 <- 1 to 4 binaryPoint1 <- 1 to 4 - bits2 <- 1 to 4 + bits2 <- 1 to 4 binaryPoint2 <- 1 to 4 } { def config = s"($bits1,$binaryPoint1)($bits2,$binaryPoint2)" @@ -25,26 +25,26 @@ class FixedPointMathSpec extends FirrtlFlatSpec { val input = s"""circuit Unit : - | module Unit : - | input a : Fixed<$bits1><<$binaryPoint1>> - | input b : Fixed<$bits2><<$binaryPoint2>> - | output sum : Fixed - | output product : Fixed - | output difference : Fixed - | sum <= add(a, b) - | product <= mul(a, b) - | difference <= sub(a, b) - | """.stripMargin + | module Unit : + | input a : Fixed<$bits1><<$binaryPoint1>> + | input b : Fixed<$bits2><<$binaryPoint2>> + | output sum : Fixed + | output product : Fixed + | output difference : Fixed + | sum <= add(a, b) + | product <= mul(a, b) + | difference <= sub(a, b) + | """.stripMargin val lowerer = new LowFirrtlCompiler val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) - val output = res.getEmittedCircuit.value split "\n" + val output = res.getEmittedCircuit.value.split("\n") def inferredAddWidth: Int = { val binaryDifference = binaryPoint1 - binaryPoint2 - val (newW1, newW2) = if(binaryDifference > 0) { + val (newW1, newW2) = if (binaryDifference > 0) { (bits1, bits2 + binaryDifference) } else { (bits1 + binaryDifference.abs, bits2) @@ -54,11 +54,11 @@ class FixedPointMathSpec extends FirrtlFlatSpec { for (line <- output) { line match { - case SumPattern(varWidth) => + case SumPattern(varWidth) => assert(varWidth.toInt === inferredAddWidth, s"$config sum sint bits wrong for $line") case ProductPattern(varWidth) => assert(varWidth.toInt === bits1 + bits2, s"$config product bits wrong for $line") - case DifferencePattern(varWidth) => + case DifferencePattern(varWidth) => assert(varWidth.toInt === inferredAddWidth, s"$config difference bits wrong for $line") case AssignPattern(varName, operation, args) => varName match { @@ -66,11 +66,15 @@ class FixedPointMathSpec extends FirrtlFlatSpec { assert(operation === "add", s"var sum should be result of an add in $line") if (binaryPoint1 > binaryPoint2) { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") - assert(args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), + s"$config second arg incorrect in $line" + ) } else if (binaryPoint1 < binaryPoint2) { - assert(args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), + s"$config second arg incorrect in $line" + ) assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") } else { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") @@ -84,11 +88,15 @@ class FixedPointMathSpec extends FirrtlFlatSpec { assert(operation === "sub", s"var difference should be result of an sub in $line") if (binaryPoint1 > binaryPoint2) { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") - assert(args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), + s"$config second arg incorrect in $line" + ) } else if (binaryPoint1 < binaryPoint2) { - assert(args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), - s"$config second arg incorrect in $line") + assert( + args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), + s"$config second arg incorrect in $line" + ) assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") } else { assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") @@ -102,4 +110,3 @@ class FixedPointMathSpec extends FirrtlFlatSpec { } } } - diff --git a/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala b/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala index db107cb3..68125bc0 100644 --- a/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedSerializationSpec.scala @@ -7,7 +7,7 @@ import firrtl.ir import org.scalatest.flatspec.AnyFlatSpec class FixedSerializationSpec extends AnyFlatSpec { - behavior of "FixedType" + behavior.of("FixedType") it should "serialize correctly" in { assert(ir.FixedType(ir.IntWidth(3), ir.IntWidth(2)).serialize == "Fixed<3><<2>>") @@ -16,7 +16,7 @@ class FixedSerializationSpec extends AnyFlatSpec { assert(ir.FixedType(ir.UnknownWidth, ir.UnknownWidth).serialize == "Fixed") } - behavior of "FixedLiteral" + behavior.of("FixedLiteral") it should "serialize correctly" in { assert(ir.FixedLiteral(1, ir.IntWidth(3), ir.IntWidth(2)).serialize == "Fixed<3><<2>>(\"h1\")") diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala index 1a7092bb..4d3dbe98 100644 --- a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -9,12 +9,14 @@ import firrtl.testutils._ class FixedTypeInferenceSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - val lines = c.serialize.split("\n") map normalized + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -29,7 +31,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -46,7 +49,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input c : Fixed<4><<3>> | output d : Fixed<13><<3>> | d <= add(a, add(b, c))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "infer add correctly" in { @@ -59,7 +62,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -76,7 +80,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input c : Fixed<4><<3>> | output d : Fixed<15><<3>> | d <= add(a, add(b, c))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be correctly shifted left" in { @@ -89,7 +93,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -102,7 +107,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<12><<2>> | d <= shl(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be correctly shifted right" in { @@ -115,7 +120,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -128,7 +134,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<8><<2>> | d <= shr(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "relatively move binary point left" in { @@ -141,7 +147,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -154,7 +161,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<12><<4>> | d <= incp(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "relatively move binary point right" in { @@ -167,7 +174,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -180,7 +188,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<8><<0>> | d <= decp(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "absolutely set binary point correctly" in { @@ -193,7 +201,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -206,7 +215,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : Fixed<10><<2>> | output d : Fixed<11><<3>> | d <= setp(a, 3)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "cat, head, tail, bits" in { @@ -219,7 +228,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -246,7 +256,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | head <= head(a, 3) | tail <= tail(a, 3) | bits <= bits(a, 6, 3)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be cast to" in { @@ -259,7 +269,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { ResolveFlows, CheckFlows, new InferWidths, - CheckWidths) + CheckWidths + ) val input = """circuit Unit : | module Unit : @@ -272,7 +283,7 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : Fixed<10><<2>> | d <= asFixedPoint(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "support binary point of zero" in { @@ -286,7 +297,8 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """ |circuit Unit : @@ -312,53 +324,53 @@ class FixedTypeInferenceSpec extends FirrtlFlatSpec { | io_out <= io_in | """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "work with mems" in { def input(memType: String): String = s""" - |circuit Unit : - | module Unit : - | input clock : Clock - | input in : Fixed<16><<8>> - | input ridx : UInt<3> - | output out : Fixed<16><<8>> - | input widx : UInt<3> - | $memType mem : Fixed<16><<8>>[8] - | infer mport min = mem[ridx], clock - | min <= in - | infer mport mout = mem[widx], clock - | out <= mout + |circuit Unit : + | module Unit : + | input clock : Clock + | input in : Fixed<16><<8>> + | input ridx : UInt<3> + | output out : Fixed<16><<8>> + | input widx : UInt<3> + | $memType mem : Fixed<16><<8>>[8] + | infer mport min = mem[ridx], clock + | min <= in + | infer mport mout = mem[widx], clock + | out <= mout """.stripMargin def check(readLatency: Int, moutEn: Int, minEn: Int): String = s""" - |circuit Unit : - | module Unit : - | input clock : Clock - | input in : SInt<16> - | input ridx : UInt<3> - | output out : SInt<16> - | input widx : UInt<3> - | - | mem mem : - | data-type => SInt<16> - | depth => 8 - | read-latency => $readLatency - | write-latency => 1 - | reader => mout - | writer => min - | read-under-write => undefined - | out <= mem.mout.data - | mem.mout.addr <= widx - | mem.mout.en <= UInt<1>("h$moutEn") - | mem.mout.clk <= clock - | mem.min.addr <= ridx - | mem.min.en <= UInt<1>("h$minEn") - | mem.min.clk <= clock - | mem.min.data <= in - | mem.min.mask <= UInt<1>("h1") + |circuit Unit : + | module Unit : + | input clock : Clock + | input in : SInt<16> + | input ridx : UInt<3> + | output out : SInt<16> + | input widx : UInt<3> + | + | mem mem : + | data-type => SInt<16> + | depth => 8 + | read-latency => $readLatency + | write-latency => 1 + | reader => mout + | writer => min + | read-under-write => undefined + | out <= mem.mout.data + | mem.mout.addr <= widx + | mem.mout.en <= UInt<1>("h$moutEn") + | mem.mout.clk <= clock + | mem.min.addr <= ridx + | mem.min.en <= UInt<1>("h$minEn") + | mem.min.clk <= clock + | mem.min.data <= in + | mem.min.mask <= UInt<1>("h1") """.stripMargin - executeTest(input("smem"), check(1, 0, 1).split("\n") map normalized, new LowFirrtlCompiler) - executeTest(input("cmem"), check(0, 1, 1).split("\n") map normalized, new LowFirrtlCompiler) + executeTest(input("smem"), check(1, 0, 1).split("\n").map(normalized), new LowFirrtlCompiler) + executeTest(input("cmem"), check(0, 1, 1).split("\n").map(normalized), new LowFirrtlCompiler) } } diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala index 9dc61927..d0218b11 100644 --- a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -9,12 +9,14 @@ import firrtl.testutils._ class RemoveFixedTypeSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) - }.circuit - val lines = c.serialize.split("\n") map normalized + val c = passes + .foldLeft(CircuitState(Parser.parse(input.split("\n").toIterator), UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + } + .circuit + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -30,7 +32,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -41,13 +44,13 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | d <= add(a, add(b, c))""".stripMargin val check = """circuit Unit : - | module Unit : - | input a : SInt<10> - | input b : SInt<10> - | input c : SInt<4> - | output d : SInt<15> - | d <= shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + | module Unit : + | input a : SInt<10> + | input b : SInt<10> + | input c : SInt<4> + | output d : SInt<15> + | d <= shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "be removed, even with a bulk connect" in { val passes = Seq( @@ -60,7 +63,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -71,13 +75,13 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | d <- add(a, add(b, c))""".stripMargin val check = """circuit Unit : - | module Unit : - | input a : SInt<10> - | input b : SInt<10> - | input c : SInt<4> - | output d : SInt<15> - | d <- shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + | module Unit : + | input a : SInt<10> + | input b : SInt<10> + | input c : SInt<4> + | output d : SInt<15> + | d <- shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "remove binary point shift correctly" in { @@ -91,7 +95,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -104,7 +109,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : SInt<12> | d <= shl(a, 2)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "remove binary point shift correctly in reverse" in { @@ -118,7 +123,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -131,7 +137,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : SInt<9> | d <= shr(a, 1)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed types" should "remove an absolutely set binary point correctly" in { @@ -145,7 +151,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """circuit Unit : | module Unit : @@ -158,7 +165,7 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | input a : SInt<10> | output d : SInt<11> | d <= shl(a, 1)""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Fixed point numbers" should "allow binary point to be set to zero at creation" in { @@ -197,7 +204,8 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { CheckFlows, new InferWidths, CheckWidths, - ConvertFixedToSInt) + ConvertFixedToSInt + ) val input = """ |circuit Unit : @@ -210,6 +218,6 @@ class RemoveFixedTypeSpec extends FirrtlFlatSpec { | module Unit : | node x = asSInt(asSInt(UInt<2>("h3"))) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } } diff --git a/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala b/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala index edfd31d3..e413a70d 100644 --- a/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala +++ b/src/test/scala/firrtlTests/formal/AssertSubmoduleAssumptionsSpec.scala @@ -1,4 +1,3 @@ - package firrtlTests.formal import firrtl.{CircuitState, Parser, Transform, UnknownForm} @@ -7,24 +6,25 @@ import firrtl.transforms.formal.AssertSubmoduleAssumptions import firrtl.stage.{Forms, TransformManager} class AssertSubmoduleAssumptionsSpec extends FirrtlFlatSpec { - behavior of "AssertSubmoduleAssumptions" + behavior.of("AssertSubmoduleAssumptions") - val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm) - .flattenedTransformOrder ++ Seq(new AssertSubmoduleAssumptions) + val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm).flattenedTransformOrder ++ Seq( + new AssertSubmoduleAssumptions + ) def run(input: String, check: Seq[String], debug: Boolean = false): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) - val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } - val lines = result.circuit.serialize.split("\n") map normalized + val lines = result.circuit.serialize.split("\n").map(normalized) if (debug) { println(lines.mkString("\n")) } for (ch <- check) { - lines should contain (ch) + lines should contain(ch) } } diff --git a/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala b/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala index c70a3ce4..847c211e 100644 --- a/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala +++ b/src/test/scala/firrtlTests/formal/ConvertAssertsSpec.scala @@ -8,15 +8,15 @@ import firrtl.transforms.formal.ConvertAsserts class ConvertAssertsSpec extends FirrtlFlatSpec { val preamble = - """circuit DUT: - | module DUT: - | input clock: Clock - | input reset: UInt<1> - | input x: UInt<8> - | output y: UInt<8> - | y <= x - | node ne5 = neq(x, UInt(5)) - |""".stripMargin + """circuit DUT: + | module DUT: + | input clock: Clock + | input reset: UInt<1> + | input x: UInt<8> + | output y: UInt<8> + | y <= x + | node ne5 = neq(x, UInt(5)) + |""".stripMargin "assert nodes" should "be converted to predicated prints and stops" in { val input = preamble + @@ -29,7 +29,7 @@ class ConvertAssertsSpec extends FirrtlFlatSpec { |""".stripMargin val outputCS = ConvertAsserts.execute(CircuitState(parse(input), Nil)) - (parse(outputCS.circuit.serialize)) should be (parse(ref)) + (parse(outputCS.circuit.serialize)) should be(parse(ref)) } "assert nodes with no message" should "omit printed messages" in { @@ -42,6 +42,6 @@ class ConvertAssertsSpec extends FirrtlFlatSpec { |""".stripMargin val outputCS = ConvertAsserts.execute(CircuitState(parse(input), Nil)) - (parse(outputCS.circuit.serialize)) should be (parse(ref)) + (parse(outputCS.circuit.serialize)) should be(parse(ref)) } } diff --git a/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala b/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala index 10e63ae4..40d810c5 100644 --- a/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala +++ b/src/test/scala/firrtlTests/formal/RemoveVerificationStatementsSpec.scala @@ -1,4 +1,3 @@ - package firrtlTests.formal import firrtl.{CircuitState, Parser, Transform, UnknownForm} @@ -7,17 +6,18 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.transforms.formal.RemoveVerificationStatements class RemoveVerificationStatementsSpec extends FirrtlFlatSpec { - behavior of "RemoveVerificationStatements" + behavior.of("RemoveVerificationStatements") - val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm) - .flattenedTransformOrder ++ Seq(new RemoveVerificationStatements) + val transforms = new TransformManager(Forms.HighForm, Forms.MinimalHighForm).flattenedTransformOrder ++ Seq( + new RemoveVerificationStatements + ) def run(input: String, antiCheck: Seq[String], debug: Boolean = false): Unit = { val circuit = Parser.parse(input.split("\n").toIterator) - val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { - (c: CircuitState, p: Transform) => p.runTransform(c) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { (c: CircuitState, p: Transform) => + p.runTransform(c) } - val lines = result.circuit.serialize.split("\n") map normalized + val lines = result.circuit.serialize.split("\n").map(normalized) if (debug) { println(lines.mkString("\n")) diff --git a/src/test/scala/firrtlTests/formal/VerificationSpec.scala b/src/test/scala/firrtlTests/formal/VerificationSpec.scala index 73d1404d..a8e28c13 100644 --- a/src/test/scala/firrtlTests/formal/VerificationSpec.scala +++ b/src/test/scala/firrtlTests/formal/VerificationSpec.scala @@ -2,14 +2,14 @@ package firrtlTests.formal -import firrtl.{CircuitState, SystemVerilogCompiler, ir} +import firrtl.{ir, CircuitState, SystemVerilogCompiler} import firrtl.testutils.FirrtlFlatSpec import logger.{LogLevel, Logger} import firrtl.options.Dependency import firrtl.stage.TransformManager class VerificationSpec extends FirrtlFlatSpec { - behavior of "Formal" + behavior.of("Formal") it should "generate SystemVerilog verification statements" in { val compiler = new SystemVerilogCompiler @@ -56,7 +56,7 @@ class VerificationSpec extends FirrtlFlatSpec { | end | end |endmodule - |""".stripMargin.split("\n") map normalized + |""".stripMargin.split("\n").map(normalized) executeTest(input, expected, compiler) } diff --git a/src/test/scala/firrtlTests/graph/DiGraphTests.scala b/src/test/scala/firrtlTests/graph/DiGraphTests.scala index 0f8c7193..0f5cf09c 100644 --- a/src/test/scala/firrtlTests/graph/DiGraphTests.scala +++ b/src/test/scala/firrtlTests/graph/DiGraphTests.scala @@ -7,32 +7,24 @@ import firrtl.testutils._ class DiGraphTests extends FirrtlFlatSpec { - val acyclicGraph = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d"), - "c" -> Set("d"), - "d" -> Set("e"), - "e" -> Set.empty[String])) - - val reversedAcyclicGraph = DiGraph(Map( - "a" -> Set.empty[String], - "b" -> Set("a"), - "c" -> Set("a"), - "d" -> Set("b", "c"), - "e" -> Set("d"))) - - val cyclicGraph = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d"), - "c" -> Set("d"), - "d" -> Set("a"))) - - val tupleGraph = DiGraph(Map( - ("a", 0) -> Set(("b", 2)), - ("a", 1) -> Set(("c", 3)), - ("b", 2) -> Set.empty[(String, Int)], - ("c", 3) -> Set.empty[(String, Int)] - )) + val acyclicGraph = DiGraph( + Map("a" -> Set("b", "c"), "b" -> Set("d"), "c" -> Set("d"), "d" -> Set("e"), "e" -> Set.empty[String]) + ) + + val reversedAcyclicGraph = DiGraph( + Map("a" -> Set.empty[String], "b" -> Set("a"), "c" -> Set("a"), "d" -> Set("b", "c"), "e" -> Set("d")) + ) + + val cyclicGraph = DiGraph(Map("a" -> Set("b", "c"), "b" -> Set("d"), "c" -> Set("d"), "d" -> Set("a"))) + + val tupleGraph = DiGraph( + Map( + ("a", 0) -> Set(("b", 2)), + ("a", 1) -> Set(("c", 3)), + ("b", 2) -> Set.empty[(String, Int)], + ("c", 3) -> Set.empty[(String, Int)] + ) + ) val degenerateGraph = DiGraph(Map("a" -> Set.empty[String])) @@ -45,109 +37,113 @@ class DiGraphTests extends FirrtlFlatSpec { } "Asking a DiGraph for a path that exists" should "work" in { - acyclicGraph.path("a","e") should not be empty + acyclicGraph.path("a", "e") should not be empty } "Asking a DiGraph for a path from one node to another with no path" should "error" in { - an [PathNotFoundException] should be thrownBy acyclicGraph.path("e","a") + an[PathNotFoundException] should be thrownBy acyclicGraph.path("e", "a") } "The first element in a linearized graph with a single root node" should "be the root" in { - acyclicGraph.linearize.head should equal ("a") + acyclicGraph.linearize.head should equal("a") } "A DiGraph with a cycle" should "error when linearized" in { - a [CyclicException] should be thrownBy cyclicGraph.linearize + a[CyclicException] should be thrownBy cyclicGraph.linearize } "CyclicExceptions" should "contain information about the cycle" in { - val c = the [CyclicException] thrownBy { + val c = the[CyclicException] thrownBy { cyclicGraph.linearize } - c.getMessage.contains("found at a") should be (true) - c.node.asInstanceOf[String] should be ("a") + c.getMessage.contains("found at a") should be(true) + c.node.asInstanceOf[String] should be("a") } "Reversing a graph" should "reverse all of the edges" in { - acyclicGraph.reverse.getEdgeMap should equal (reversedAcyclicGraph.getEdgeMap) + acyclicGraph.reverse.getEdgeMap should equal(reversedAcyclicGraph.getEdgeMap) } "Reversing a graph with no edges" should "equal the graph itself" in { - degenerateGraph.getEdgeMap should equal (degenerateGraph.reverse.getEdgeMap) + degenerateGraph.getEdgeMap should equal(degenerateGraph.reverse.getEdgeMap) } "transformNodes" should "combine vertices that collide, not drop them" in { - tupleGraph.transformNodes(_._1).getEdgeMap should contain ("a" -> Set("b", "c")) + tupleGraph.transformNodes(_._1).getEdgeMap should contain("a" -> Set("b", "c")) } "Graph summation" should "be order-wise equivalent to original" in { val first = acyclicGraph.subgraph(Set("a", "b", "c")) val second = acyclicGraph.subgraph(Set("b", "c", "d", "e")) - (first + second).getEdgeMap should equal (acyclicGraph.getEdgeMap) + (first + second).getEdgeMap should equal(acyclicGraph.getEdgeMap) } it should "be idempotent" in { val first = acyclicGraph.subgraph(Set("a", "b", "c")) val second = acyclicGraph.subgraph(Set("b", "c", "d", "e")) - (first + second + second + second).getEdgeMap should equal (acyclicGraph.getEdgeMap) + (first + second + second + second).getEdgeMap should equal(acyclicGraph.getEdgeMap) } "linearize" should "not cause a stack overflow on very large graphs" in { // Graph of 0 -> 1, 1 -> 2, etc. val N = 10000 - val edges = (1 to N).zipWithIndex.map({ case (n, idx) => idx -> Set(n)}).toMap + val edges = (1 to N).zipWithIndex.map({ case (n, idx) => idx -> Set(n) }).toMap val bigGraph = DiGraph(edges + (N -> Set.empty[Int])) - bigGraph.linearize should be (0 to N) + bigGraph.linearize should be(0 to N) } it should "work on multi-rooted graphs" in { val graph = DiGraph(Map("a" -> Set[String](), "b" -> Set[String]())) - graph.linearize.toSet should be (graph.getVertices) + graph.linearize.toSet should be(graph.getVertices) } "acyclic graph" should "be rendered" in { - val acyclicGraph2 = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d", "x", "z"), - "c" -> Set("d", "x"), - "d" -> Set("e", "k", "l"), - "x" -> Set("e"), - "z" -> Set("e", "j"), - "j" -> Set("k", "l", "c"), - "k" -> Set("l"), - "l" -> Set("e"), - "e" -> Set.empty[String] - )) + val acyclicGraph2 = DiGraph( + Map( + "a" -> Set("b", "c"), + "b" -> Set("d", "x", "z"), + "c" -> Set("d", "x"), + "d" -> Set("e", "k", "l"), + "x" -> Set("e"), + "z" -> Set("e", "j"), + "j" -> Set("k", "l", "c"), + "k" -> Set("l"), + "l" -> Set("e"), + "e" -> Set.empty[String] + ) + ) val render = new RenderDiGraph(acyclicGraph2) val dotLines = render.toDotRanked.split("\n") - dotLines.count(s => s.contains("rank=same")) should be (4) - dotLines.exists(s => s.contains(""""b" -> { "d" "x" "z" };""")) should be (true) - dotLines.exists(s => s.contains("""rankdir="LR";""")) should be (true) + dotLines.count(s => s.contains("rank=same")) should be(4) + dotLines.exists(s => s.contains(""""b" -> { "d" "x" "z" };""")) should be(true) + dotLines.exists(s => s.contains("""rankdir="LR";""")) should be(true) } "subgraphs containing cycles" should "be rendered with loop edges in red, can override orientation" in { - val cyclicGraph2 = DiGraph(Map( - "a" -> Set("b","c"), - "b" -> Set("d", "x", "z"), - "c" -> Set("d", "x"), - "d" -> Set("e", "k", "l"), - "x" -> Set("e"), - "z" -> Set("e", "j"), - "j" -> Set("k", "l", "c"), - "k" -> Set("l"), - "l" -> Set("e"), - "e" -> Set("c") - )) + val cyclicGraph2 = DiGraph( + Map( + "a" -> Set("b", "c"), + "b" -> Set("d", "x", "z"), + "c" -> Set("d", "x"), + "d" -> Set("e", "k", "l"), + "x" -> Set("e"), + "z" -> Set("e", "j"), + "j" -> Set("k", "l", "c"), + "k" -> Set("l"), + "l" -> Set("e"), + "e" -> Set("c") + ) + ) val render = new RenderDiGraph(cyclicGraph2, rankDir = "TB") val dotLines = render.showOnlyTheLoopAsDot.split("\n") - dotLines.count(s => s.contains("rank=same")) should be (4) - dotLines.count(s => s.contains("""[color=red,penwidth=3.0];""")) should be (3) - dotLines.exists(s => s.contains(""""d" -> "k";""")) should be (true) - dotLines.exists(s => s.contains("""rankdir="TB";""")) should be (true) + dotLines.count(s => s.contains("rank=same")) should be(4) + dotLines.count(s => s.contains("""[color=red,penwidth=3.0];""")) should be(3) + dotLines.exists(s => s.contains(""""d" -> "k";""")) should be(true) + dotLines.exists(s => s.contains("""rankdir="TB";""")) should be(true) } "reachableFrom" should "omit the queried node if no self-path exists" in { diff --git a/src/test/scala/firrtlTests/graph/EulerTourTests.scala b/src/test/scala/firrtlTests/graph/EulerTourTests.scala index f6deb721..703235af 100644 --- a/src/test/scala/firrtlTests/graph/EulerTourTests.scala +++ b/src/test/scala/firrtlTests/graph/EulerTourTests.scala @@ -11,26 +11,30 @@ class EulerTourTests extends FirrtlFlatSpec { val third_layer = Set("3a", "3b", "3c") val last_null = Set.empty[String] - val m = Map(top -> first_layer) ++ first_layer.map{ - case x => Map(x -> second_layer) }.flatten.toMap ++ second_layer.map{ - case x => Map(x -> third_layer) }.flatten.toMap ++ third_layer.map{ - case x => Map(x -> last_null) }.flatten.toMap + val m = Map(top -> first_layer) ++ first_layer.map { + case x => Map(x -> second_layer) + }.flatten.toMap ++ second_layer.map { + case x => Map(x -> third_layer) + }.flatten.toMap ++ third_layer.map { + case x => Map(x -> last_null) + }.flatten.toMap val graph = DiGraph(m) val instances = graph.pathsInDAG(top).values.flatten val tour = EulerTour(graph, top) it should "show equivalency of Berkman--Vishkin and naive RMQs" in { - instances.toSeq.combinations(2).toList.map { case Seq(a, b) => - tour.rmqNaive(a, b) should be (tour.rmqBV(a, b)) + instances.toSeq.combinations(2).toList.map { + case Seq(a, b) => + tour.rmqNaive(a, b) should be(tour.rmqBV(a, b)) } } it should "determine naive RMQs of itself correctly" in { - instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) } + instances.toSeq.map { case a => tour.rmqNaive(a, a) should be(a) } } it should "determine Berkman--Vishkin RMQs of itself correctly" in { - instances.toSeq.map { case a => tour.rmqNaive(a, a) should be (a) } + instances.toSeq.map { case a => tour.rmqNaive(a, a) should be(a) } } } diff --git a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala index 656e1f8c..74e6cabf 100644 --- a/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala +++ b/src/test/scala/firrtlTests/interval/IntervalMathSpec.scala @@ -10,14 +10,14 @@ import firrtl.constraint._ import firrtl.testutils.FirrtlFlatSpec class IntervalMathSpec extends FirrtlFlatSpec { - val SumPattern = """.*output sum.*<(\d+)>.*""".r - val ProductPattern = """.*output product.*<(\d+)>.*""".r - val DifferencePattern = """.*output difference.*<(\d+)>.*""".r - val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r - val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r - val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r - val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r - val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r + val SumPattern = """.*output sum.*<(\d+)>.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + val ComparisonPattern = """.*output (\w+).*UInt<(\d+)>.*""".r + val ShiftLeftPattern = """.*output shl.*<(\d+)>.*""".r + val ShiftRightPattern = """.*output shr.*<(\d+)>.*""".r + val DShiftLeftPattern = """.*output dshl.*<(\d+)>.*""".r + val DShiftRightPattern = """.*output dshr.*<(\d+)>.*""".r val ArithAssignPattern = """\s*(\w+) <= asSInt\(bits\((\w+)\((.*)\).*\)\)\s*""".r def getBound(bound: String, value: BigDecimal): IsKnown = bound match { case "[" => Closed(value) @@ -29,16 +29,16 @@ class IntervalMathSpec extends FirrtlFlatSpec { val prec = 0.5 for { - lb1 <- Seq("[", "(") - lv1 <- Range.BigDecimal(-1.0, 1.0, prec) - uv1 <- if(lb1 == "[") Range.BigDecimal(lv1, 1.0, prec) else Range.BigDecimal(lv1 + prec, 1.0, prec) - ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") - bp1 <- 0 to 1 - lb2 <- Seq("[", "(") - lv2 <- Range.BigDecimal(-1.0, 1.0, prec) - uv2 <- if(lb2 == "[") Range.BigDecimal(lv2, 1.0, prec) else Range.BigDecimal(lv2 + prec, 1.0, prec) - ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") - bp2 <- 0 to 1 + lb1 <- Seq("[", "(") + lv1 <- Range.BigDecimal(-1.0, 1.0, prec) + uv1 <- if (lb1 == "[") Range.BigDecimal(lv1, 1.0, prec) else Range.BigDecimal(lv1 + prec, 1.0, prec) + ub1 <- if (lv1 == uv1) Seq("]") else Seq("]", ")") + bp1 <- 0 to 1 + lb2 <- Seq("[", "(") + lv2 <- Range.BigDecimal(-1.0, 1.0, prec) + uv2 <- if (lb2 == "[") Range.BigDecimal(lv2, 1.0, prec) else Range.BigDecimal(lv2 + prec, 1.0, prec) + ub2 <- if (lv2 == uv2) Seq("]") else Seq("]", ")") + bp2 <- 0 to 1 } { val it1 = IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1.toInt)) val it2 = IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2.toInt)) @@ -47,103 +47,108 @@ class IntervalMathSpec extends FirrtlFlatSpec { case (_, Some(Nil)) => case _ => def config = s"$lb1$lv1,$uv1$ub1.$bp1 and $lb2$lv2,$uv2$ub2.$bp2" - + s"Configuration $config" should "pass" in { - + val input = s"""circuit Unit : - | module Unit : - | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 - | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 - | input amt : UInt<3> - | output sum : Interval - | output difference : Interval - | output product : Interval - | output shl : Interval - | output shr : Interval - | output dshl : Interval - | output dshr : Interval - | output lt : UInt - | output leq : UInt - | output gt : UInt - | output geq : UInt - | output eq : UInt - | output neq : UInt - | output cat : UInt - | sum <= add(in1, in2) - | difference <= sub(in1, in2) - | product <= mul(in1, in2) - | shl <= shl(in1, 3) - | shr <= shr(in1, 3) - | dshl <= dshl(in1, amt) - | dshr <= dshr(in1, amt) - | lt <= lt(in1, in2) - | leq <= leq(in1, in2) - | gt <= gt(in1, in2) - | geq <= geq(in1, in2) - | eq <= eq(in1, in2) - | neq <= lt(in1, in2) - | cat <= cat(in1, in2) - | """.stripMargin - + | module Unit : + | input in1 : Interval$lb1$lv1, $uv1$ub1.$bp1 + | input in2 : Interval$lb2$lv2, $uv2$ub2.$bp2 + | input amt : UInt<3> + | output sum : Interval + | output difference : Interval + | output product : Interval + | output shl : Interval + | output shr : Interval + | output dshl : Interval + | output dshr : Interval + | output lt : UInt + | output leq : UInt + | output gt : UInt + | output geq : UInt + | output eq : UInt + | output neq : UInt + | output cat : UInt + | sum <= add(in1, in2) + | difference <= sub(in1, in2) + | product <= mul(in1, in2) + | shl <= shl(in1, 3) + | shr <= shr(in1, 3) + | dshl <= dshl(in1, amt) + | dshr <= dshr(in1, amt) + | lt <= lt(in1, in2) + | leq <= leq(in1, in2) + | gt <= gt(in1, in2) + | geq <= geq(in1, in2) + | eq <= eq(in1, in2) + | neq <= lt(in1, in2) + | cat <= cat(in1, in2) + | """.stripMargin + val lowerer = new LowFirrtlCompiler val res = lowerer.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) - val output = res.getEmittedCircuit.value split "\n" + val output = res.getEmittedCircuit.value.split("\n") val min1 = Closed(it1.min.get) val max1 = Closed(it1.max.get) val min2 = Closed(it2.min.get) val max2 = Closed(it2.max.get) for (line <- output) { line match { - case SumPattern(varWidth) => + case SumPattern(varWidth) => val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) val it = IntervalType(IsAdd(min1, min2), IsAdd(max1, max2), bp) assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, s"$line,${it.range}") - case ProductPattern(varWidth) => + case ProductPattern(varWidth) => val bp = IntWidth(bp1.toInt + bp2.toInt) val lv = IsMin(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) val uv = IsMax(Seq(IsMul(min1, min2), IsMul(min1, max2), IsMul(max1, min2), IsMul(max1, max2))) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "product") - case DifferencePattern(varWidth) => + case DifferencePattern(varWidth) => val bp = IntWidth(Math.max(bp1.toInt, bp2.toInt)) val lv = min1 + max2.neg val uv = max1 + min2.neg assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "diff") - case ShiftLeftPattern(varWidth) => + case ShiftLeftPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 * Closed(8) val uv = max1 * Closed(8) val it = IntervalType(lv, uv, bp) assert(varWidth.toInt == it.width.asInstanceOf[IntWidth].width, "shl") - case ShiftRightPattern(varWidth) => + case ShiftRightPattern(varWidth) => val bp = IntWidth(bp1.toInt) - val lv = min1 * Closed(1/3) - val uv = max1 * Closed(1/3) + val lv = min1 * Closed(1 / 3) + val uv = max1 * Closed(1 / 3) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "shr") - case DShiftLeftPattern(varWidth) => + case DShiftLeftPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 * Closed(128) val uv = max1 * Closed(128) assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshl") - case DShiftRightPattern(varWidth) => + case DShiftRightPattern(varWidth) => val bp = IntWidth(bp1.toInt) val lv = min1 val uv = max1 assert(varWidth.toInt == IntervalType(lv, uv, bp).width.asInstanceOf[IntWidth].width, "dshr") case ComparisonPattern(varWidth) => assert(varWidth.toInt == 1, "==") case ArithAssignPattern(varName, operation, args) => - val arg1 = if(IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) """SInt<1>("h0")""" else "in1" - val arg2 = if(IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) """SInt<1>("h0")""" else "in2" + val arg1 = + if (IntervalType(getBound(lb1, lv1), getBound(ub1, uv1), IntWidth(bp1)).width == IntWidth(0)) + """SInt<1>("h0")""" + else "in1" + val arg2 = + if (IntervalType(getBound(lb2, lv2), getBound(ub2, uv2), IntWidth(bp2)).width == IntWidth(0)) + """SInt<1>("h0")""" + else "in2" varName match { case "sum" => assert(operation === "add", s"""var sum should be result of an add in ${output.mkString("\n")}""") if (bp1 > bp2) { - if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") - assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), - s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), s"$config second arg incorrect in $line") } else if (bp1 < bp2) { - assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), - s"$config second arg incorrect in $line") + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), s"$config second arg incorrect in $line") assert(!args.contains("shl($arg2"), s"$config second arg should be just $arg2 in $line") } else { assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") @@ -156,13 +161,13 @@ class IntervalMathSpec extends FirrtlFlatSpec { case "difference" => assert(operation === "sub", s"var difference should be result of an sub in $line") if (bp1 > bp2) { - if (arg1 != arg2) assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") - assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), - s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") + assert(args.contains(s"shl($arg2, ${bp1 - bp2})"), s"$config second arg incorrect in $line") } else if (bp1 < bp2) { - assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), - s"$config second arg incorrect in $line") - if (arg1 != arg2) assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") + assert(args.contains(s"shl($arg1, ${(bp1 - bp2).abs})"), s"$config second arg incorrect in $line") + if (arg1 != arg2) + assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") } else { assert(!args.contains(s"shl($arg1"), s"$config first arg should be just $arg1 in $line") assert(!args.contains(s"shl($arg2"), s"$config second arg should be just $arg2 in $line") @@ -170,12 +175,11 @@ class IntervalMathSpec extends FirrtlFlatSpec { case _ => } case _ => + } } } - } } } } - // vim: set ts=4 sw=4 et: diff --git a/src/test/scala/firrtlTests/interval/IntervalSpec.scala b/src/test/scala/firrtlTests/interval/IntervalSpec.scala index 5d82f6b5..1a39e98e 100644 --- a/src/test/scala/firrtlTests/interval/IntervalSpec.scala +++ b/src/test/scala/firrtlTests/interval/IntervalSpec.scala @@ -10,13 +10,12 @@ import firrtl.testutils.FirrtlFlatSpec class IntervalSpec extends FirrtlFlatSpec { private def executeTest(input: String, expected: Seq[String], passes: Seq[Transform]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Transform) => - p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { (c: Circuit, p: Transform) => + p.runTransform(CircuitState(c, UnknownForm, AnnotationSeq(Nil), None)).circuit } - val lines = c.serialize.split("\n") map normalized + val lines = c.serialize.split("\n").map(normalized) - expected foreach { e => + expected.foreach { e => lines should contain(e) } } @@ -37,7 +36,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out1 : Interval | out0 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6)))))) | out1 <= add(in0, add(in1, add(in2, add(in3, add(in4, add(in5, in6))))))""".stripMargin - executeTest(input, input.split("\n") map normalized, passes) + executeTest(input, input.split("\n").map(normalized), passes) } "Interval types" should "infer bp correctly" in { @@ -58,7 +57,7 @@ class IntervalSpec extends FirrtlFlatSpec { | input in2 : Interval(-0.32, 10].2 | output out0 : Interval.4 | out0 <= add(in0, add(in1, in2))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "trim known intervals correctly" in { @@ -79,11 +78,12 @@ class IntervalSpec extends FirrtlFlatSpec { | input in2 : Interval[-0.25, 10].2 | output out0 : Interval.4 | out0 <= add(in0, incp(add(in1, incp(in2, 1)), 1))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer intervals correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) val input = """circuit Unit : | module Unit : @@ -100,11 +100,19 @@ class IntervalSpec extends FirrtlFlatSpec { """output out0 : Interval[-0.5625, 22.9375].4 |output out1 : Interval[-74.53125, 298.125].9 |output out2 : Interval[-10.6875, 12.8125].4""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "be removed correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals(), + new InferWidths(), + new RemoveIntervals() + ) val input = """circuit Unit : | module Unit : @@ -129,209 +137,227 @@ class IntervalSpec extends FirrtlFlatSpec { | out0 <= add(in0, shl(add(in1, shl(in2, 1)), 1)) | out1 <= mul(in0, mul(in1, in2)) | out2 <= sub(in0, shl(sub(in1, shl(in2, 1)), 1))""".stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } -"Interval types" should "infer multiplication by zero correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + "Interval types" should "infer multiplication by zero correctly" in { + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) val input = s"""circuit Unit : - | module Unit : - | input in1 : Interval[0, 0.5].1 - | input in2 : Interval[0, 0].1 - | output mul : Interval - | mul <= mul(in2, in1) - | """.stripMargin - val check = s"""output mul : Interval[0, 0].2 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) -} + | module Unit : + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output mul : Interval + | mul <= mul(in2, in1) + | """.stripMargin + val check = s"""output mul : Interval[0, 0].2 """.stripMargin + executeTest(input, check.split("\n").map(normalized), passes) + } "Interval types" should "infer muxes correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<1> - | input in1 : Interval[0, 0.5].1 - | input in2 : Interval[0, 0].1 - | output out : Interval - | out <= mux(p, in2, in1) - | """.stripMargin + val passes = + Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths()) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<1> + | input in1 : Interval[0, 0.5].1 + | input in2 : Interval[0, 0].1 + | output out : Interval + | out <= mux(p, in2, in1) + | """.stripMargin val check = s"""output out : Interval[0, 0.5].1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer dshl correctly" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveKinds, ResolveFlows, new InferBinaryPoints(), new TrimIntervals, new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<3> - | input in1 : Interval[-1, 1].0 - | output out : Interval - | out <= dshl(in1, p) - | """.stripMargin + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveKinds, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals, + new InferWidths() + ) + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | input in1 : Interval[-1, 1].0 + | output out : Interval + | out <= dshl(in1, p) + | """.stripMargin val check = s"""output out : Interval[-128, 128].0 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer asInterval correctly" in { val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferWidths()) - val input = - s"""circuit Unit : - | module Unit : - | input p : UInt<3> - | output out : Interval - | out <= asInterval(p, 0, 4, 1) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input p : UInt<3> + | output out : Interval + | out <= asInterval(p, 0, 4, 1) + | """.stripMargin val check = s"""output out : Interval[0, 2].1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "do wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input u: UInt<3> - | input in1: Interval[-3, 5].0 - | output wrap3: Interval - | output wrap4: Interval - | output wrap5: Interval - | output wrap6: Interval - | output wrap7: Interval - | output clip3: Interval - | output clip4: Interval - | output clip5: Interval - | output clip6: Interval - | output clip7: Interval - | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) - | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) - | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) - | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) - | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) - | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) - | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) - | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) - | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) - | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap4: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap4 <= wrap(in1, asInterval(s, -1, 1, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) """.stripMargin - //| output wrap1: Interval - //| output wrap2: Interval - //| output clip1: Interval - //| output clip2: Interval - //| wrap1 <= wrap(in1, u, 0) - //| wrap2 <= wrap(in1, s, 0) - //| clip1 <= clip(in1, u) - //| clip2 <= clip(in1, s) + //| output wrap1: Interval + //| output wrap2: Interval + //| output clip1: Interval + //| output clip2: Interval + //| wrap1 <= wrap(in1, u, 0) + //| wrap2 <= wrap(in1, s, 0) + //| clip1 <= clip(in1, u) + //| clip2 <= clip(in1, s) val check = s""" - | output wrap3 : Interval[-2, 4].0 - | output wrap4 : Interval[-1, 1].0 - | output wrap5 : Interval[-4, 4].0 - | output wrap6 : Interval[-1, 7].0 - | output wrap7 : Interval[-4, 7].0 - | output clip3 : Interval[-2, 4].0 - | output clip4 : Interval[-1, 1].0 - | output clip5 : Interval[-3, 4].0 - | output clip6 : Interval[-1, 5].0 - | output clip7 : Interval[-3, 5].0 """.stripMargin - // TODO: this optimization - //| output wrap1 : Interval[0, 7].0 - //| output wrap2 : Interval[-2, 1].0 - //| output clip1 : Interval[0, 5].0 - //| output clip2 : Interval[-2, 1].0 - //| output wrap7 : Interval[-3, 5].0 - executeTest(input, check.split("\n") map normalized, passes) + | output wrap3 : Interval[-2, 4].0 + | output wrap4 : Interval[-1, 1].0 + | output wrap5 : Interval[-4, 4].0 + | output wrap6 : Interval[-1, 7].0 + | output wrap7 : Interval[-4, 7].0 + | output clip3 : Interval[-2, 4].0 + | output clip4 : Interval[-1, 1].0 + | output clip5 : Interval[-3, 4].0 + | output clip6 : Interval[-1, 5].0 + | output clip7 : Interval[-3, 5].0 """.stripMargin + // TODO: this optimization + //| output wrap1 : Interval[0, 7].0 + //| output wrap2 : Interval[-2, 1].0 + //| output clip1 : Interval[0, 5].0 + //| output clip2 : Interval[-2, 1].0 + //| output wrap7 : Interval[-3, 5].0 + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "remove wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck(), new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input u: UInt<3> - | input in1: Interval[-3, 5].0 - | output wrap3: Interval - | output wrap5: Interval - | output wrap6: Interval - | output wrap7: Interval - | output clip3: Interval - | output clip4: Interval - | output clip5: Interval - | output clip6: Interval - | output clip7: Interval - | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) - | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) - | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) - | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) - | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) - | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) - | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) - | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) - | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input u: UInt<3> + | input in1: Interval[-3, 5].0 + | output wrap3: Interval + | output wrap5: Interval + | output wrap6: Interval + | output wrap7: Interval + | output clip3: Interval + | output clip4: Interval + | output clip5: Interval + | output clip6: Interval + | output clip7: Interval + | wrap3 <= wrap(in1, asInterval(s, -2, 4, 0)) + | wrap5 <= wrap(in1, asInterval(s, -4, 4, 0)) + | wrap6 <= wrap(in1, asInterval(s, -1, 7, 0)) + | wrap7 <= wrap(in1, asInterval(s, -4, 7, 0)) + | clip3 <= clip(in1, asInterval(s, -2, 4, 0)) + | clip4 <= clip(in1, asInterval(s, -1, 1, 0)) + | clip5 <= clip(in1, asInterval(s, -4, 4, 0)) + | clip6 <= clip(in1, asInterval(s, -1, 7, 0)) + | clip7 <= clip(in1, asInterval(s, -4, 7, 0)) + | """.stripMargin val check = s""" - | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) - | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) - | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) - | wrap7 <= in1 - | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) - | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) - | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) - | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) - | clip7 <= in1 + | wrap3 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<4>("h7")), mux(lt(in1, SInt<2>("h-2")), add(in1, SInt<4>("h7")), in1)) + | wrap5 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), in1) + | wrap6 <= mux(lt(in1, SInt<1>("h-1")), add(in1, SInt<5>("h9")), in1) + | wrap7 <= in1 + | clip3 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<2>("h-2")), SInt<2>("h-2"), in1)) + | clip4 <= mux(gt(in1, SInt<2>("h1")), SInt<2>("h1"), mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1)) + | clip5 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), in1) + | clip6 <= mux(lt(in1, SInt<1>("h-1")), SInt<1>("h-1"), in1) + | clip7 <= in1 """.stripMargin - //| output wrap4: Interval - //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) - //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) - executeTest(input, check.split("\n") map normalized, passes) + //| output wrap4: Interval + //| wrap4 <= wrap(in1, asInterval(s, -1, 1, 0), 0) + //| wrap4 <= add(rem(sub(in1, SInt<1>("h-1")), sub(SInt<2>("h1"), SInt<1>("h-1"))), SInt<1>("h-1")) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "shift wrap/clip correctly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input s: SInt<2> - | input in1: Interval[-3, 5].1 - | output wrap1: Interval - | output clip1: Interval - | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) - | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input s: SInt<2> + | input in1: Interval[-3, 5].1 + | output wrap1: Interval + | output clip1: Interval + | wrap1 <= wrap(in1, asInterval(s, -2, 2, 0)) + | clip1 <= clip(in1, asInterval(s, -2, 2, 0)) + | """.stripMargin val check = s""" - | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) - | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) + | wrap1 <= mux(gt(in1, SInt<4>("h4")), sub(in1, SInt<5>("h9")), mux(lt(in1, SInt<3>("h-4")), add(in1, SInt<5>("h9")), in1)) + | clip1 <= mux(gt(in1, SInt<4>("h4")), SInt<4>("h4"), mux(lt(in1, SInt<3>("h-4")), SInt<3>("h-4"), in1)) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "infer negative binary points" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck()) - val input = - s"""circuit Unit : - | module Unit : - | input in1: Interval[-2, 4].-1 - | input in2: Interval[-4, 8].-2 - | output out: Interval - | out <= add(in1, in2) - | """.stripMargin + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval + | out <= add(in1, in2) + | """.stripMargin val check = s""" - | output out : Interval[-6, 12].-1 + | output out : Interval[-6, 12].-1 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "remove negative binary points" in { - val passes = Seq(ToWorkingIR, InferTypes, ResolveFlows, new InferBinaryPoints(), new TrimIntervals(), new InferWidths(), new RemoveIntervals()) - val input = - s"""circuit Unit : - | module Unit : - | input in1: Interval[-2, 4].-1 - | input in2: Interval[-4, 8].-2 - | output out: Interval.0 - | out <= add(in1, in2) - | """.stripMargin + val passes = Seq( + ToWorkingIR, + InferTypes, + ResolveFlows, + new InferBinaryPoints(), + new TrimIntervals(), + new InferWidths(), + new RemoveIntervals() + ) + val input = + s"""circuit Unit : + | module Unit : + | input in1: Interval[-2, 4].-1 + | input in2: Interval[-4, 8].-2 + | output out: Interval.0 + | out <= add(in1, in2) + | """.stripMargin val check = s""" - | output out : SInt<5> - | out <= shl(add(in1, shl(in2, 1)), 1) + | output out : SInt<5> + | out <= shl(add(in1, shl(in2, 1)), 1) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "implement squz properly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck) @@ -372,7 +398,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output minOff : Interval[-1, 4].1 | output offMin : Interval[-1, 4].2 """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Interval types" should "lower squz properly" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) @@ -413,7 +439,7 @@ class IntervalSpec extends FirrtlFlatSpec { | minOff <= asSInt(bits(min, 4, 0)) | offMin <= asSInt(bits(off, 5, 0)) """.stripMargin - executeTest(input, check.split("\n") map normalized, passes) + executeTest(input, check.split("\n").map(normalized), passes) } "Assigning a larger interval to a smaller interval" should "error!" in { val passes = Seq(ToWorkingIR, new ResolveAndCheck, new RemoveIntervals) @@ -424,7 +450,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out: Interval[2, 3].1 | out <= in | """.stripMargin - intercept[InvalidConnect]{ + intercept[InvalidConnect] { executeTest(input, Nil, passes) } } @@ -437,7 +463,7 @@ class IntervalSpec extends FirrtlFlatSpec { | output out: Interval[2, 3].1 | out <= in | """.stripMargin - intercept[InvalidConnect]{ + intercept[InvalidConnect] { executeTest(input, Nil, passes) } } @@ -512,7 +538,6 @@ class IntervalSpec extends FirrtlFlatSpec { ) } - "Wrap with remainder" should "error" in { intercept[WrapWithRemainder] { val input = diff --git a/src/test/scala/firrtlTests/options/OptionParserSpec.scala b/src/test/scala/firrtlTests/options/OptionParserSpec.scala index e93c9b2c..452e6cb7 100644 --- a/src/test/scala/firrtlTests/options/OptionParserSpec.scala +++ b/src/test/scala/firrtlTests/options/OptionParserSpec.scala @@ -19,66 +19,69 @@ class OptionParserSpec extends AnyFlatSpec with Matchers with firrtl.testutils.U /* An option parser that prepends to a Seq[Int] */ class IntParser extends OptionParser[AnnotationSeq]("Int Parser") { - opt[Int]("integer").abbr("n").unbounded.action( (x, c) => IntAnnotation(x) +: c ) + opt[Int]("integer").abbr("n").unbounded.action((x, c) => IntAnnotation(x) +: c) help("help") } trait DuplicateShortOption { this: OptionParser[AnnotationSeq] => - opt[Int]("not-an-integer").abbr("n").unbounded.action( (x, c) => IntAnnotation(x) +: c ) + opt[Int]("not-an-integer").abbr("n").unbounded.action((x, c) => IntAnnotation(x) +: c) } trait DuplicateLongOption { this: OptionParser[AnnotationSeq] => - opt[Int]("integer").abbr("m").unbounded.action( (x, c) => IntAnnotation(x) +: c ) + opt[Int]("integer").abbr("m").unbounded.action((x, c) => IntAnnotation(x) +: c) } trait WithIntParser { val parser = new IntParser } - behavior of "A default OptionsParser" + behavior.of("A default OptionsParser") it should "call sys.exit if terminate is called" in new WithIntParser { info("exit status of 1 for failure") - catchStatus { parser.terminate(Left("some message")) } should be (Left(1)) + catchStatus { parser.terminate(Left("some message")) } should be(Left(1)) info("exit status of 0 for success") - catchStatus { parser.terminate(Right(())) } should be (Left(0)) + catchStatus { parser.terminate(Right(())) } should be(Left(0)) } it should "print to stderr on an invalid option" in new WithIntParser { - grabStdOutErr{ parser.parse(Array("--foo"), Seq[Annotation]()) }._2 should include ("Unknown option --foo") + grabStdOutErr { parser.parse(Array("--foo"), Seq[Annotation]()) }._2 should include("Unknown option --foo") } - behavior of "An OptionParser with DoNotTerminateOnExit mixed in" + behavior.of("An OptionParser with DoNotTerminateOnExit mixed in") it should "disable sys.exit for terminate method" in { val parser = new IntParser with DoNotTerminateOnExit info("no exit for failure") - catchStatus { parser.terminate(Left("some message")) } should be (Right(())) + catchStatus { parser.terminate(Left("some message")) } should be(Right(())) info("no exit for success") - catchStatus { parser.terminate(Right(())) } should be (Right(())) + catchStatus { parser.terminate(Right(())) } should be(Right(())) } - behavior of "An OptionParser with DuplicateHandling mixed in" + behavior.of("An OptionParser with DuplicateHandling mixed in") it should "detect short duplicates" in { val parser = new IntParser with DuplicateHandling with DuplicateShortOption - intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) } - .getMessage should startWith ("Duplicate short option") + intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) }.getMessage should startWith( + "Duplicate short option" + ) } it should "detect long duplicates" in { val parser = new IntParser with DuplicateHandling with DuplicateLongOption - intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) } - .getMessage should startWith ("Duplicate long option") + intercept[OptionsException] { parser.parse(Array[String](), Seq[Annotation]()) }.getMessage should startWith( + "Duplicate long option" + ) } - behavior of "An OptionParser with ExceptOnError mixed in" + behavior.of("An OptionParser with ExceptOnError mixed in") it should "cause an OptionsException on an invalid option" in { val parser = new IntParser with ExceptOnError - intercept[OptionsException] { parser.parse(Array("--foo"), Seq[Annotation]()) } - .getMessage should include ("Unknown option") + intercept[OptionsException] { parser.parse(Array("--foo"), Seq[Annotation]()) }.getMessage should include( + "Unknown option" + ) } } diff --git a/src/test/scala/firrtlTests/options/OptionsViewSpec.scala b/src/test/scala/firrtlTests/options/OptionsViewSpec.scala index 0c868cb2..504dcdf6 100644 --- a/src/test/scala/firrtlTests/options/OptionsViewSpec.scala +++ b/src/test/scala/firrtlTests/options/OptionsViewSpec.scala @@ -2,10 +2,9 @@ package firrtlTests.options - import firrtl.options.OptionsView import firrtl.AnnotationSeq -import firrtl.annotations.{Annotation,NoTargetAnnotation} +import firrtl.annotations.{Annotation, NoTargetAnnotation} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers @@ -22,7 +21,7 @@ class OptionsViewSpec extends AnyFlatSpec with Matchers { /* An OptionsView that converts an AnnotationSeq to Option[Foo] */ implicit object FooView extends OptionsView[Foo] { private def append(foo: Foo, anno: Annotation): Foo = anno match { - case NameAnnotation(n) => foo.copy(name = Some(n)) + case NameAnnotation(n) => foo.copy(name = Some(n)) case ValueAnnotation(v) => foo.copy(value = Some(v)) case _ => foo } @@ -40,20 +39,20 @@ class OptionsViewSpec extends AnyFlatSpec with Matchers { def view(options: AnnotationSeq): Bar = options.foldLeft(Bar())(append) } - behavior of "OptionsView" + behavior.of("OptionsView") it should "convert annotations to one of two types" in { /* Some default annotations */ val annos = Seq(NameAnnotation("foo"), ValueAnnotation(42)) info("Foo conversion okay") - FooView.view(annos) should be (Foo(Some("foo"), Some(42))) + FooView.view(annos) should be(Foo(Some("foo"), Some(42))) info("Bar conversion okay") - BarView.view(annos) should be (Bar("foo")) + BarView.view(annos) should be(Bar("foo")) } - behavior of "Viewer" + behavior.of("Viewer") it should "implicitly view annotations as the specified type" in { import firrtl.options.Viewer._ @@ -62,9 +61,9 @@ class OptionsViewSpec extends AnyFlatSpec with Matchers { val annos = Seq[Annotation]() info("Foo view okay") - view[Foo](annos) should be (Foo(None, None)) + view[Foo](annos) should be(Foo(None, None)) info("Bar view okay") - view[Bar](annos) should be (Bar()) + view[Bar](annos) should be(Bar()) } } diff --git a/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala b/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala index 108f3730..f31b96fd 100644 --- a/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala +++ b/src/test/scala/firrtlTests/options/PhaseManagerSpec.scala @@ -2,9 +2,8 @@ package firrtlTests.options - import firrtl.AnnotationSeq -import firrtl.options.{DependencyManagerException, Phase, PhaseManager, Dependency} +import firrtl.options.{Dependency, DependencyManagerException, Phase, PhaseManager} import java.io.{File, PrintWriter} @@ -62,7 +61,6 @@ class F extends IdentityPhase { } } - /** [[Phase]] that requires [[C]] and invalidates [[F]] */ class G extends IdentityPhase { override def prerequisites = Seq(Dependency[C]) @@ -235,7 +233,7 @@ object UnrelatedFixture { trait InvalidatesB8Dep { this: Phase => override def invalidates(a: Phase) = a match { case _: B8Dep => true - case _ => false + case _ => false } } @@ -368,7 +366,7 @@ object OrderingFixture { class B extends IdentityPhase { override def invalidates(phase: Phase): Boolean = phase match { case _: A => true - case _ => false + case _ => false } } @@ -376,7 +374,7 @@ object OrderingFixture { override def prerequisites = Seq(Dependency[A], Dependency[B]) override def invalidates(phase: Phase): Boolean = phase match { case _: B => true - case _ => false + case _ => false } } @@ -423,7 +421,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { } - behavior of this.getClass.getName + behavior.of(this.getClass.getName) it should "do nothing if all targets are reached" in { val targets = Seq(Dependency[A], Dependency[B], Dependency[C], Dependency[D]) @@ -431,7 +429,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/DoNothing") - pm.flattenedTransformOrder should be (empty) + pm.flattenedTransformOrder should be(empty) } it should "handle a simple dependency" in { @@ -441,7 +439,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/SimpleDependency") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "handle a simple dependency with an invalidation" in { @@ -451,7 +449,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/OneInvalidate") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "handle a dependency with two invalidates optimally" in { @@ -460,7 +458,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/TwoInvalidates") - pm.flattenedTransformOrder.size should be (targets.size) + pm.flattenedTransformOrder.size should be(targets.size) } it should "throw an exception for cyclic prerequisites" in { @@ -469,8 +467,9 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/CyclicPrerequisites") - intercept[DependencyManagerException]{ pm.flattenedTransformOrder } - .getMessage should startWith ("No transform ordering possible") + intercept[DependencyManagerException] { pm.flattenedTransformOrder }.getMessage should startWith( + "No transform ordering possible" + ) } it should "throw an exception for cyclic invalidates" in { @@ -479,8 +478,9 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/CyclicInvalidates") - intercept[DependencyManagerException]{ pm.flattenedTransformOrder } - .getMessage should startWith ("No transform ordering possible") + intercept[DependencyManagerException] { pm.flattenedTransformOrder }.getMessage should startWith( + "No transform ordering possible" + ) } it should "handle a complicated graph" in { @@ -491,41 +491,31 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/Complicated") info("only one phase was recomputed") - pm.flattenedTransformOrder.size should be (targets.size + 1) + pm.flattenedTransformOrder.size should be(targets.size + 1) } it should "handle repeated recomputed analyses" in { val f = RepeatedAnalysisFixture val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.C]) val order = - Seq( classOf[f.Analysis], - classOf[f.A], - classOf[f.Analysis], - classOf[f.B], - classOf[f.Analysis], - classOf[f.C]) + Seq(classOf[f.Analysis], classOf[f.A], classOf[f.Analysis], classOf[f.B], classOf[f.Analysis], classOf[f.C]) val pm = new PhaseManager(targets) writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/RepeatedAnalysis") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "handle inverted repeated recomputed analyses" in { val f = InvertedAnalysisFixture val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.C]) val order = - Seq( classOf[f.Analysis], - classOf[f.C], - classOf[f.Analysis], - classOf[f.B], - classOf[f.Analysis], - classOf[f.A]) + Seq(classOf[f.Analysis], classOf[f.C], classOf[f.Analysis], classOf[f.B], classOf[f.Analysis], classOf[f.A]) val pm = new PhaseManager(targets) writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/InvertedRepeatedAnalysis") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } /** This test shows how the optionalPrerequisiteOf member can be used to run one transform before another. */ @@ -535,7 +525,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { info("without the custom transform it runs: First -> Second") val pm = new PhaseManager(Seq(Dependency[f.Second])) val orderNoCustom = Seq(classOf[f.First], classOf[f.Second]) - pm.flattenedTransformOrder.map(_.getClass) should be (orderNoCustom) + pm.flattenedTransformOrder.map(_.getClass) should be(orderNoCustom) info("with the custom transform it runs: First -> Custom -> Second") val pmCustom = new PhaseManager(Seq(Dependency[f.Custom], Dependency[f.Second])) @@ -543,7 +533,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pmCustom, "test_run_dir/PhaseManagerSpec/SingleDependent") - pmCustom.flattenedTransformOrder.map(_.getClass) should be (orderCustom) + pmCustom.flattenedTransformOrder.map(_.getClass) should be(orderCustom) } it should "handle chained invalidation" in { @@ -553,11 +543,11 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { val current = Seq(Dependency[f.B], Dependency[f.C], Dependency[f.D]) val pm = new PhaseManager(targets, current) - val order = Seq( classOf[f.A], classOf[f.B], classOf[f.C], classOf[f.D], classOf[f.E] ) + val order = Seq(classOf[f.A], classOf[f.B], classOf[f.C], classOf[f.D], classOf[f.E]) writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/ChainedInvalidate") - pm.flattenedTransformOrder.map(_.getClass) should be (order) + pm.flattenedTransformOrder.map(_.getClass) should be(order) } it should "maintain the order of input targets" in { @@ -565,62 +555,70 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { /** A bunch of unrelated Phases. This ensures that these run in the order in which they are specified. */ val targets = - Seq( Dependency[f.B0], - Dependency[f.B1], - Dependency[f.B2], - Dependency[f.B3], - Dependency[f.B4], - Dependency[f.B5], - Dependency[f.B6], - Dependency[f.B7], - Dependency[f.B8], - Dependency[f.B9], - Dependency[f.B10], - Dependency[f.B11], - Dependency[f.B12], - Dependency[f.B13], - Dependency[f.B14], - Dependency[f.B15] ) + Seq( + Dependency[f.B0], + Dependency[f.B1], + Dependency[f.B2], + Dependency[f.B3], + Dependency[f.B4], + Dependency[f.B5], + Dependency[f.B6], + Dependency[f.B7], + Dependency[f.B8], + Dependency[f.B9], + Dependency[f.B10], + Dependency[f.B11], + Dependency[f.B12], + Dependency[f.B13], + Dependency[f.B14], + Dependency[f.B15] + ) + /** A sequence of custom transforms that should all run after B6 and before B7. This exercises correct ordering of the * prerequisiteGraph and optionalPrerequisiteOfGraph. */ val prerequisiteTargets = - Seq( Dependency[f.B6_0], - Dependency[f.B6_1], - Dependency[f.B6_2], - Dependency[f.B6_3], - Dependency[f.B6_4], - Dependency[f.B6_5], - Dependency[f.B6_6], - Dependency[f.B6_7], - Dependency[f.B6_8], - Dependency[f.B6_9], - Dependency[f.B6_10], - Dependency[f.B6_11], - Dependency[f.B6_12], - Dependency[f.B6_13], - Dependency[f.B6_14], - Dependency[f.B6_15] ) + Seq( + Dependency[f.B6_0], + Dependency[f.B6_1], + Dependency[f.B6_2], + Dependency[f.B6_3], + Dependency[f.B6_4], + Dependency[f.B6_5], + Dependency[f.B6_6], + Dependency[f.B6_7], + Dependency[f.B6_8], + Dependency[f.B6_9], + Dependency[f.B6_10], + Dependency[f.B6_11], + Dependency[f.B6_12], + Dependency[f.B6_13], + Dependency[f.B6_14], + Dependency[f.B6_15] + ) + /** A sequence of transforms that are invalidated by B0 and only define optionalPrerequisiteOf on B8. This exercises * the ordering defined by "otherPrerequisites". */ val current = - Seq( Dependency[f.B8_0], - Dependency[f.B8_1], - Dependency[f.B8_2], - Dependency[f.B8_3], - Dependency[f.B8_4], - Dependency[f.B8_5], - Dependency[f.B8_6], - Dependency[f.B8_7], - Dependency[f.B8_8], - Dependency[f.B8_9], - Dependency[f.B8_10], - Dependency[f.B8_11], - Dependency[f.B8_12], - Dependency[f.B8_13], - Dependency[f.B8_14], - Dependency[f.B8_15] ) + Seq( + Dependency[f.B8_0], + Dependency[f.B8_1], + Dependency[f.B8_2], + Dependency[f.B8_3], + Dependency[f.B8_4], + Dependency[f.B8_5], + Dependency[f.B8_6], + Dependency[f.B8_7], + Dependency[f.B8_8], + Dependency[f.B8_9], + Dependency[f.B8_10], + Dependency[f.B8_11], + Dependency[f.B8_12], + Dependency[f.B8_13], + Dependency[f.B8_14], + Dependency[f.B8_15] + ) /** The resulting order: B0--B6, B6_0--B6_B15, B7, B8_0--B8_15, B8--B15 */ val expectedDeps = targets.slice(0, 7) ++ prerequisiteTargets ++ Some(targets(7)) ++ current ++ targets.drop(8) @@ -630,7 +628,7 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { writeGraphviz(pm, "test_run_dir/PhaseManagerSpec/DeterministicOrder") - pm.flattenedTransformOrder.map(_.getClass) should be (expectedClasses) + pm.flattenedTransformOrder.map(_.getClass) should be(expectedClasses) } it should "allow conditional placement of custom transforms" in { @@ -642,13 +640,21 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { val targetsFull = Seq(Dependency[f.Custom], Dependency[f.DoneFull]) val pmFull = new PhaseManager(targetsFull) - val expectedMinimum = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.AfterOpt], classOf[f.Custom], classOf[f.DoneMinimum]) + val expectedMinimum = + Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.AfterOpt], classOf[f.Custom], classOf[f.DoneMinimum]) writeGraphviz(pmMinimum, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/minimum") - pmMinimum.flattenedTransformOrder.map(_.getClass) should be (expectedMinimum) - - val expectedFull = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.OptFull], classOf[f.AfterOpt], classOf[f.Custom], classOf[f.DoneFull]) + pmMinimum.flattenedTransformOrder.map(_.getClass) should be(expectedMinimum) + + val expectedFull = Seq( + classOf[f.Root], + classOf[f.OptMinimum], + classOf[f.OptFull], + classOf[f.AfterOpt], + classOf[f.Custom], + classOf[f.DoneFull] + ) writeGraphviz(pmFull, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/full") - pmFull.flattenedTransformOrder.map(_.getClass) should be (expectedFull) + pmFull.flattenedTransformOrder.map(_.getClass) should be(expectedFull) } it should "support optional prerequisites" in { @@ -662,11 +668,12 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { val expectedMinimum = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.Custom], classOf[f.DoneMinimum]) writeGraphviz(pmMinimum, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/minimum") - pmMinimum.flattenedTransformOrder.map(_.getClass) should be (expectedMinimum) + pmMinimum.flattenedTransformOrder.map(_.getClass) should be(expectedMinimum) - val expectedFull = Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.OptFull], classOf[f.Custom], classOf[f.DoneFull]) + val expectedFull = + Seq(classOf[f.Root], classOf[f.OptMinimum], classOf[f.OptFull], classOf[f.Custom], classOf[f.DoneFull]) writeGraphviz(pmFull, "test_run_dir/PhaseManagerSpec/CustomAfterOptimization/full") - pmFull.flattenedTransformOrder.map(_.getClass) should be (expectedFull) + pmFull.flattenedTransformOrder.map(_.getClass) should be(expectedFull) } /** This tests a situation the ordering of edges matters. Namely, this test is dependent on the ordering in which @@ -678,13 +685,13 @@ class PhaseManagerSpec extends AnyFlatSpec with Matchers { { val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.C]) val order = Seq(classOf[f.B], classOf[f.A], classOf[f.C], classOf[f.B], classOf[f.A]) - (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be (order) + (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be(order) } { val targets = Seq(Dependency[f.A], Dependency[f.B], Dependency[f.Cx]) val order = Seq(classOf[f.B], classOf[f.A], classOf[f.Cx], classOf[f.B], classOf[f.A]) - (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be (order) + (new PhaseManager(targets)).flattenedTransformOrder.map(_.getClass) should be(order) } } diff --git a/src/test/scala/firrtlTests/options/RegistrationSpec.scala b/src/test/scala/firrtlTests/options/RegistrationSpec.scala index fa6b0fa0..821ac8b3 100644 --- a/src/test/scala/firrtlTests/options/RegistrationSpec.scala +++ b/src/test/scala/firrtlTests/options/RegistrationSpec.scala @@ -6,7 +6,7 @@ import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import java.util.ServiceLoader -import firrtl.options.{RegisteredTransform, RegisteredLibrary, ShellOption} +import firrtl.options.{RegisteredLibrary, RegisteredTransform, ShellOption} import firrtl.passes.Pass import firrtl.ir.Circuit import firrtl.annotations.NoTargetAnnotation @@ -19,10 +19,8 @@ class FooTransform extends Pass with RegisteredTransform { def run(c: Circuit): Circuit = c val options = Seq( - new ShellOption[Unit]( - longOption = "hello", - toAnnotationSeq = _ => Seq(HelloAnnotation), - helpText = "Hello option") ) + new ShellOption[Unit](longOption = "hello", toAnnotationSeq = _ => Seq(HelloAnnotation), helpText = "Hello option") + ) } @@ -30,15 +28,13 @@ class BarLibrary extends RegisteredLibrary { def name: String = "Bar" val options = Seq( - new ShellOption[Unit]( - longOption = "world", - toAnnotationSeq = _ => Seq(HelloAnnotation), - helpText = "World option") ) + new ShellOption[Unit](longOption = "world", toAnnotationSeq = _ => Seq(HelloAnnotation), helpText = "World option") + ) } class RegistrationSpec extends AnyFlatSpec with Matchers { - behavior of "RegisteredTransform" + behavior.of("RegisteredTransform") it should "FooTransform should be discovered by Java.util.ServiceLoader" in { val iter = ServiceLoader.load(classOf[RegisteredTransform]).iterator() @@ -46,10 +42,10 @@ class RegistrationSpec extends AnyFlatSpec with Matchers { while (iter.hasNext) { transforms += iter.next() } - transforms.map(_.getClass.getName) should contain ("firrtlTests.options.FooTransform") + transforms.map(_.getClass.getName) should contain("firrtlTests.options.FooTransform") } - behavior of "RegisteredLibrary" + behavior.of("RegisteredLibrary") it should "BarLibrary be discovered by Java.util.ServiceLoader" in { val iter = ServiceLoader.load(classOf[RegisteredLibrary]).iterator() @@ -57,6 +53,6 @@ class RegistrationSpec extends AnyFlatSpec with Matchers { while (iter.hasNext) { transforms += iter.next() } - transforms.map(_.getClass.getName) should contain ("firrtlTests.options.BarLibrary") + transforms.map(_.getClass.getName) should contain("firrtlTests.options.BarLibrary") } } diff --git a/src/test/scala/firrtlTests/options/ShellSpec.scala b/src/test/scala/firrtlTests/options/ShellSpec.scala index af6b2669..178b1128 100644 --- a/src/test/scala/firrtlTests/options/ShellSpec.scala +++ b/src/test/scala/firrtlTests/options/ShellSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options - import firrtl.annotations.NoTargetAnnotation import firrtl.options.Shell import org.scalatest.flatspec.AnyFlatSpec @@ -17,25 +16,26 @@ class ShellSpec extends AnyFlatSpec with Matchers { case object E extends NoTargetAnnotation trait AlphabeticalCli { this: Shell => - parser.opt[Unit]('c', "c-option").unbounded().action( (x, c) => C +: c ) - parser.opt[Unit]('d', "d-option").unbounded().action( (x, c) => D +: c ) - parser.opt[Unit]('e', "e-option").unbounded().action( (x, c) => E +: c ) } + parser.opt[Unit]('c', "c-option").unbounded().action((x, c) => C +: c) + parser.opt[Unit]('d', "d-option").unbounded().action((x, c) => D +: c) + parser.opt[Unit]('e', "e-option").unbounded().action((x, c) => E +: c) + } - behavior of "Shell" + behavior.of("Shell") it should "detect all registered libraries and transforms" in { val shell = new Shell("foo") info("Found FooTransform") - shell.registeredTransforms.map(_.getClass.getName) should contain ("firrtlTests.options.FooTransform") + shell.registeredTransforms.map(_.getClass.getName) should contain("firrtlTests.options.FooTransform") info("Found BarLibrary") - shell.registeredLibraries.map(_.getClass.getName) should contain ("firrtlTests.options.BarLibrary") + shell.registeredLibraries.map(_.getClass.getName) should contain("firrtlTests.options.BarLibrary") } it should "correctly order annotations and options" in { val shell = new Shell("foo") with AlphabeticalCli - shell.parse(Array("-c", "-d", "-e"), Seq(A, B)).toSeq should be (Seq(A, B, C, D, E)) + shell.parse(Array("-c", "-d", "-e"), Seq(A, B)).toSeq should be(Seq(A, B, C, D, E)) } } diff --git a/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala b/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala index 3401a408..f625f991 100644 --- a/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/AddDefaultsSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options.phases - import firrtl.options.{Phase, TargetDirAnnotation} import firrtl.options.phases.AddDefaults import org.scalatest.flatspec.AnyFlatSpec @@ -16,13 +15,13 @@ class AddDefaultsSpec extends AnyFlatSpec with Matchers { val defaultDir = TargetDirAnnotation(".") } - behavior of classOf[AddDefaults].toString + behavior.of(classOf[AddDefaults].toString) it should "add a TargetDirAnnotation if it does not exist" in new Fixture { - phase.transform(Seq.empty).toSeq should be (Seq(defaultDir)) + phase.transform(Seq.empty).toSeq should be(Seq(defaultDir)) } it should "don't add a TargetDirAnnotation if it exists" in new Fixture { - phase.transform(Seq(targetDir)).toSeq should be (Seq(targetDir)) + phase.transform(Seq(targetDir)).toSeq should be(Seq(targetDir)) } } diff --git a/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala b/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala index 96d6569d..62afed94 100644 --- a/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/ChecksSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options.phases - import firrtl.AnnotationSeq import firrtl.options.{OptionsException, OutputAnnotationFileAnnotation, Phase, TargetDirAnnotation} import firrtl.options.phases.Checks @@ -20,9 +19,9 @@ class ChecksSpec extends AnyFlatSpec with Matchers { val min = Seq(targetDir) def checkExceptionMessage(phase: Phase, annotations: AnnotationSeq, messageStart: String): Unit = - intercept[OptionsException]{ phase.transform(annotations) }.getMessage should startWith(messageStart) + intercept[OptionsException] { phase.transform(annotations) }.getMessage should startWith(messageStart) - behavior of classOf[Checks].toString + behavior.of(classOf[Checks].toString) it should "enforce exactly one TargetDirAnnotation" in new Fixture { info("0 target directories throws an exception") diff --git a/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala b/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala index 7d20ac89..95c2a435 100644 --- a/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/GetIncludesSpec.scala @@ -2,12 +2,10 @@ package firrtlTests.options.phases - import java.io.{File, PrintWriter} import firrtl.AnnotationSeq -import firrtl.annotations.{AnnotationFileNotFoundException, JsonProtocol, - NoTargetAnnotation} +import firrtl.annotations.{AnnotationFileNotFoundException, JsonProtocol, NoTargetAnnotation} import firrtl.options.phases.GetIncludes import firrtl.options.{InputAnnotationFileAnnotation, Phase} import firrtl.util.BackendCompilationUtilities @@ -29,10 +27,10 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU def checkAnnos(a: AnnotationSeq, b: AnnotationSeq): Unit = { info("read the expected number of annotations") - a.size should be (b.size) + a.size should be(b.size) info("annotations match exact order") - a.zip(b).foreach{ case (ax, bx) => ax should be (bx) } + a.zip(b).foreach { case (ax, bx) => ax should be(bx) } } val files = Seq( @@ -43,19 +41,21 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU new File(dir + "/e.anno.json") -> Seq(E) ) - files.foreach{ case (file, annotations) => - val pw = new PrintWriter(file) - pw.write(JsonProtocol.serialize(annotations)) - pw.close() + files.foreach { + case (file, annotations) => + val pw = new PrintWriter(file) + pw.write(JsonProtocol.serialize(annotations)) + pw.close() } class Fixture { val phase: Phase = new GetIncludes } - behavior of classOf[GetIncludes].toString + behavior.of(classOf[GetIncludes].toString) it should "throw an exception if the annotation file doesn't exit" in new Fixture { - intercept[AnnotationFileNotFoundException]{ phase.transform(Seq(ref("f"))) } - .getMessage should startWith("Annotation file") + intercept[AnnotationFileNotFoundException] { phase.transform(Seq(ref("f"))) }.getMessage should startWith( + "Annotation file" + ) } it should "read annotations from a file" in new Fixture { @@ -75,9 +75,9 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU checkAnnos(out, expect) - Seq("d", "e").foreach{ x => + Seq("d", "e").foreach { x => info(s"a warning about '$x.anno.json' was printed") - stdout should include (s"Warning: Annotation file ($dir/$x.anno.json) already included!") + stdout should include(s"Warning: Annotation file ($dir/$x.anno.json) already included!") } } @@ -90,7 +90,7 @@ class GetIncludesSpec extends AnyFlatSpec with Matchers with BackendCompilationU checkAnnos(out, expect) info("a warning about 'a.anno.json' was printed") - stdout should include (s"Warning: Annotation file ($dir/a.anno.json)") + stdout should include(s"Warning: Annotation file ($dir/a.anno.json)") } } diff --git a/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala b/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala index 0a3cce67..4fe16041 100644 --- a/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala +++ b/src/test/scala/firrtlTests/options/phases/WriteOutputAnnotationsSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.options.phases - import java.io.File import firrtl.AnnotationSeq @@ -15,7 +14,8 @@ import firrtl.options.{ PhaseException, StageOptions, TargetDirAnnotation, - WriteDeletedAnnotation} + WriteDeletedAnnotation +} import firrtl.options.Viewer.view import firrtl.options.phases.{GetIncludes, WriteOutputAnnotations} import org.scalatest.flatspec.AnyFlatSpec @@ -37,33 +37,38 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t info(s"reading '$f' works") val read = (new GetIncludes) .transform(Seq(InputAnnotationFileAnnotation(f.toString))) - .filterNot{ + .filterNot { case a @ DeletedAnnotation(_, _: InputAnnotationFileAnnotation) => true - case _ => false } + case _ => false + } info(s"annotations in file are expected size") - read.size should be (a.size) + read.size should be(a.size) read .zip(a) - .foreach{ case (read, expected) => - info(s"$read matches") - read should be (expected) } + .foreach { + case (read, expected) => + info(s"$read matches") + read should be(expected) + } f.delete() } class Fixture { val phase: Phase = new WriteOutputAnnotations } - behavior of classOf[WriteOutputAnnotations].toString + behavior.of(classOf[WriteOutputAnnotations].toString) it should "write annotations to a file (excluding DeletedAnnotations)" in new Fixture { val file = new File(dir + "/should-write-annotations-to-a-file.anno.json") - val annotations = Seq( OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.FooAnnotation, - WriteOutputAnnotationsSpec.BarAnnotation(0), - WriteOutputAnnotationsSpec.BarAnnotation(1), - DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation) ) + val annotations = Seq( + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.FooAnnotation, + WriteOutputAnnotationsSpec.BarAnnotation(0), + WriteOutputAnnotationsSpec.BarAnnotation(1), + DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation) + ) val expected = annotations.filter { case a: DeletedAnnotation => false case a => true @@ -71,31 +76,35 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t val out = phase.transform(annotations) info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) fileContainsAnnotations(file, expected) } it should "include DeletedAnnotations if a WriteDeletedAnnotation is present" in new Fixture { val file = new File(dir + "should-include-deleted.anno.json") - val annotations = Seq( OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.FooAnnotation, - WriteOutputAnnotationsSpec.BarAnnotation(0), - WriteOutputAnnotationsSpec.BarAnnotation(1), - DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation), - WriteDeletedAnnotation ) + val annotations = Seq( + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.FooAnnotation, + WriteOutputAnnotationsSpec.BarAnnotation(0), + WriteOutputAnnotationsSpec.BarAnnotation(1), + DeletedAnnotation("foo", WriteOutputAnnotationsSpec.FooAnnotation), + WriteDeletedAnnotation + ) val out = phase.transform(annotations) info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) fileContainsAnnotations(file, annotations) } it should "do nothing if no output annotation file is specified" in new Fixture { - val annotations = Seq( WriteOutputAnnotationsSpec.FooAnnotation, - WriteOutputAnnotationsSpec.BarAnnotation(0), - WriteOutputAnnotationsSpec.BarAnnotation(1) ) + val annotations = Seq( + WriteOutputAnnotationsSpec.FooAnnotation, + WriteOutputAnnotationsSpec.BarAnnotation(0), + WriteOutputAnnotationsSpec.BarAnnotation(1) + ) val out = catchWrites { phase.transform(annotations) } match { case Right(a) => @@ -106,14 +115,16 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t } info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) } it should "write CustomFileEmission annotations" in new Fixture { val file = new File("write-CustomFileEmission-annotations.anno.json") - val annotations = Seq( TargetDirAnnotation(dir), - OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.Custom("hello!") ) + val annotations = Seq( + TargetDirAnnotation(dir), + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.Custom("hello!") + ) val serializedFileName = view[StageOptions](annotations).getBuildFileName("Custom", Some(".Emission")) val expected = annotations.map { case _: WriteOutputAnnotationsSpec.Custom => WriteOutputAnnotationsSpec.Replacement(serializedFileName) @@ -123,7 +134,7 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t val out = phase.transform(annotations) info("annotations are unmodified") - out.toSeq should be (annotations) + out.toSeq should be(annotations) fileContainsAnnotations(new File(dir, file.toString), expected) @@ -133,13 +144,15 @@ class WriteOutputAnnotationsSpec extends AnyFlatSpec with Matchers with firrtl.t it should "error if multiple annotations try to write to the same file" in new Fixture { val file = new File("write-CustomFileEmission-annotations-error.anno.json") - val annotations = Seq( TargetDirAnnotation(dir), - OutputAnnotationFileAnnotation(file.toString), - WriteOutputAnnotationsSpec.Custom("foo"), - WriteOutputAnnotationsSpec.Custom("bar") ) + val annotations = Seq( + TargetDirAnnotation(dir), + OutputAnnotationFileAnnotation(file.toString), + WriteOutputAnnotationsSpec.Custom("foo"), + WriteOutputAnnotationsSpec.Custom("bar") + ) intercept[PhaseException] { phase.transform(annotations) - }.getMessage should startWith ("Multiple CustomFileEmission annotations") + }.getMessage should startWith("Multiple CustomFileEmission annotations") } } diff --git a/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala index bfc72f49..b628c1b7 100644 --- a/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala +++ b/src/test/scala/firrtlTests/passes/InferTypesFlowsAndKindsSpec.scala @@ -6,48 +6,50 @@ import firrtl.ir.SubField import firrtl.options.Dependency import firrtl.stage.TransformManager import firrtl.{InstanceKind, MemKind, NodeKind, PortKind, RegKind, WireKind} -import firrtl.{CircuitState, SinkFlow, SourceFlow, ir, passes} +import firrtl.{ir, passes, CircuitState, SinkFlow, SourceFlow} import org.scalatest.flatspec.AnyFlatSpec /** Tests the combined results of ResolveKinds, InferTypes and ResolveFlows */ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { - private val deps = Seq( - Dependency(passes.ResolveKinds), - Dependency(passes.InferTypes), - Dependency(passes.ResolveFlows)) + private val deps = + Seq(Dependency(passes.ResolveKinds), Dependency(passes.InferTypes), Dependency(passes.ResolveFlows)) private val manager = new TransformManager(deps) private def infer(src: String): ir.Circuit = manager.execute(CircuitState(firrtl.Parser.parse(src), Seq())).circuit private def getNodes(s: ir.Statement): Seq[(String, ir.Expression)] = s match { - case ir.DefNode(_, name, value) => Seq((name, value)) - case ir.Block(stmts) => stmts.flatMap(getNodes) - case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getNodes) - case _ => Seq() + case ir.DefNode(_, name, value) => Seq((name, value)) + case ir.Block(stmts) => stmts.flatMap(getNodes) + case ir.Conditionally(_, _, a, b) => Seq(a, b).flatMap(getNodes) + case _ => Seq() } private def getConnects(s: ir.Statement): Seq[ir.Connect] = s match { - case c : ir.Connect => Seq(c) - case ir.Block(stmts) => stmts.flatMap(getConnects) - case ir.Conditionally(_, _, a, b) => Seq(a,b).flatMap(getConnects) - case _ => Seq() + case c: ir.Connect => Seq(c) + case ir.Block(stmts) => stmts.flatMap(getConnects) + case ir.Conditionally(_, _, a, b) => Seq(a, b).flatMap(getConnects) + case _ => Seq() } private def getModule(c: ir.Circuit, name: String): ir.Module = c.modules.find(_.name == name).get.asInstanceOf[ir.Module] it should "infer references to ports, wires, nodes and registers" in { - val node = getNodes(getModule(infer( - """circuit m: - | module m: - | input clk: Clock - | input a: UInt<4> - | wire b : SInt<5> - | reg c: UInt<5>, clk - | node na = a - | node nb = b - | node nc = c - | node nna = na - | node na2 = a - | node a_plus_c = add(a, c) - |""".stripMargin), "m").body).toMap + val node = getNodes( + getModule( + infer("""circuit m: + | module m: + | input clk: Clock + | input a: UInt<4> + | wire b : SInt<5> + | reg c: UInt<5>, clk + | node na = a + | node nb = b + | node nc = c + | node nna = na + | node na2 = a + | node a_plus_c = add(a, c) + |""".stripMargin), + "m" + ).body + ).toMap assert(node("na").tpe == ir.UIntType(ir.IntWidth(4))) assert(node("na").asInstanceOf[ir.Reference].flow == SourceFlow) @@ -74,29 +76,29 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { } it should "infer types for references to instances" in { - val m = getModule(infer( - """circuit m: - | module other: - | output x: { y: UInt, flip z: UInt<1> } - | module m: - | inst i of other - | node i_x = i.x - | node i_x_y = i.x.y - | node i_x_y_2 = i_x.y - | node a = UInt<1>(1) - | i.x.z <= a - |""".stripMargin), "m") + val m = getModule( + infer("""circuit m: + | module other: + | output x: { y: UInt, flip z: UInt<1> } + | module m: + | inst i of other + | node i_x = i.x + | node i_x_y = i.x.y + | node i_x_y_2 = i_x.y + | node a = UInt<1>(1) + | i.x.z <= a + |""".stripMargin), + "m" + ) val node = getNodes(m.body).toMap val con = getConnects(m.body) - // node i_x_y = i.x.y assert(node("i_x_y").tpe.isInstanceOf[ir.UIntType]) // the type inference replaces all unknown widths with a variable assert(node("i_x_y").tpe.asInstanceOf[ir.UIntType].width.isInstanceOf[ir.VarWidth]) assert(node("i_x_y").asInstanceOf[ir.SubField].flow == SourceFlow) - // node i_x = i.x val x = node("i_x").asInstanceOf[ir.SubField] assert(x.tpe.isInstanceOf[ir.BundleType]) @@ -110,12 +112,10 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { assert(i.kind == InstanceKind) assert(i.flow == SourceFlow) - // node i_x_y_2 = i_x.y assert(node("i_x_y").tpe == node("i_x_y_2").tpe) assert(node("i_x_y").asInstanceOf[ir.SubField].flow == node("i_x_y_2").asInstanceOf[ir.SubField].flow) - // i.x.z <= a val (left, right) = (con.head.loc.asInstanceOf[ir.SubField], con.head.expr.asInstanceOf[ir.Reference]) @@ -131,29 +131,27 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { } it should "infer types for references to memories" in { - val c = infer( - """circuit m: - | module m: - | mem m: - | data-type => UInt - | depth => 30 - | reader => r - | writer => w - | read-latency => 1 - | write-latency => 1 - | read-under-write => undefined - | - | node m_r_addr = m.r.addr - | node m_r_data = m.r.data - | node m_w_addr = m.w.addr - | node m_w_data = m.w.data - |""".stripMargin) + val c = infer("""circuit m: + | module m: + | mem m: + | data-type => UInt + | depth => 30 + | reader => r + | writer => w + | read-latency => 1 + | write-latency => 1 + | read-under-write => undefined + | + | node m_r_addr = m.r.addr + | node m_r_data = m.r.data + | node m_w_addr = m.w.addr + | node m_w_data = m.w.data + |""".stripMargin) val m = getModule(c, "m") val node = getNodes(m.body).toMap // this might be a little flaky... val memory = m.body.asInstanceOf[ir.Block].stmts.head.asInstanceOf[ir.DefMemory] - // after InferTypes, all expressions referring to the `data` should have this type: val dataTpe = memory.dataType.asInstanceOf[ir.UIntType] val addrTpe = ir.UIntType(ir.IntWidth(5)) @@ -163,8 +161,12 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { assert(node("m_w_addr").tpe == addrTpe) assert(node("m_w_data").tpe == dataTpe) - val memory_ref = node("m_r_addr").asInstanceOf[ir.SubField].expr - .asInstanceOf[ir.SubField].expr.asInstanceOf[ir.Reference] + val memory_ref = node("m_r_addr") + .asInstanceOf[ir.SubField] + .expr + .asInstanceOf[ir.SubField] + .expr + .asInstanceOf[ir.Reference] assert(memory_ref.kind == MemKind) val mem_ref_tpe = memory_ref.tpe.asInstanceOf[ir.BundleType] val r_tpe = mem_ref_tpe.fields.find(_.name == "r").get.tpe.asInstanceOf[ir.BundleType] @@ -176,18 +178,17 @@ class InferTypesFlowsAndKindsSpec extends AnyFlatSpec { } it should "infer different instances of the same module to have the same width variable" in { - val c = infer( - """circuit m: - | module other: - | input x: UInt - | module x: - | inst i of other - | i.x <= UInt<16>(3) - | module m: - | inst x of x - | inst i of other - | i.x <= UInt<1>(1) - |""".stripMargin) + val c = infer("""circuit m: + | module other: + | input x: UInt + | module x: + | inst i of other + | i.x <= UInt<16>(3) + | module m: + | inst x of x + | inst i of other + | i.x <= UInt<1>(1) + |""".stripMargin) val m_con = getConnects(getModule(c, "m").body).head val x_con = getConnects(getModule(c, "x").body).head val other = getModule(c, "other") diff --git a/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala b/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala index 157520ea..d4caa546 100644 --- a/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala +++ b/src/test/scala/firrtlTests/stage/FirrtlCliSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage - import firrtl.stage.RunFirrtlTransformAnnotation import firrtl.options.Shell import firrtl.stage.FirrtlCli @@ -11,25 +10,30 @@ import org.scalatest.matchers.should.Matchers class FirrtlCliSpec extends AnyFlatSpec with Matchers { - behavior of "FirrtlCli for RunFirrtlTransformAnnotation / -fct / --custom-transforms" + behavior.of("FirrtlCli for RunFirrtlTransformAnnotation / -fct / --custom-transforms") it should "preserver transform order" in { val shell = new Shell("foo") with FirrtlCli val args = Array( - "--custom-transforms", "firrtl.transforms.BlackBoxSourceHelper,firrtl.transforms.CheckCombLoops", - "--custom-transforms", "firrtl.transforms.CombineCats", - "--custom-transforms", "firrtl.transforms.ConstantPropagation" ) + "--custom-transforms", + "firrtl.transforms.BlackBoxSourceHelper,firrtl.transforms.CheckCombLoops", + "--custom-transforms", + "firrtl.transforms.CombineCats", + "--custom-transforms", + "firrtl.transforms.ConstantPropagation" + ) val expected = Seq( classOf[firrtl.transforms.BlackBoxSourceHelper], classOf[firrtl.transforms.CheckCombLoops], classOf[firrtl.transforms.CombineCats], - classOf[firrtl.transforms.ConstantPropagation] ) + classOf[firrtl.transforms.ConstantPropagation] + ) shell .parse(args) - .collect{ case a: RunFirrtlTransformAnnotation => a } + .collect { case a: RunFirrtlTransformAnnotation => a } .zip(expected) - .map{ case (RunFirrtlTransformAnnotation(a), b) => a.getClass should be (b) } + .map { case (RunFirrtlTransformAnnotation(a), b) => a.getClass should be(b) } } } diff --git a/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala b/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala index 7d57f7ed..9274bac6 100644 --- a/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala +++ b/src/test/scala/firrtlTests/stage/FirrtlMainSpec.scala @@ -21,7 +21,11 @@ import org.scalatest.matchers.should.Matchers * This test uses the [[org.scalatest.FeatureSpec FeatureSpec]] intentionally as this test exercises the top-level * interface and is more suitable to an Acceptance Testing style. */ -class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers with firrtl.testutils.Utils +class FirrtlMainSpec + extends AnyFeatureSpec + with GivenWhenThen + with Matchers + with firrtl.testutils.Utils with BackendCompilationUtilities { /** Parameterizes one test of [[FirrtlMain]]. Running the [[FirrtlMain]] `main` with certain args should produce @@ -36,13 +40,14 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit * @param result expected exit code */ case class FirrtlMainTest( - args: Array[String], - circuit: Option[FirrtlCircuitFixture] = Some(new SimpleFirrtlCircuitFixture), - files: Seq[String] = Seq.empty, + args: Array[String], + circuit: Option[FirrtlCircuitFixture] = Some(new SimpleFirrtlCircuitFixture), + files: Seq[String] = Seq.empty, notFiles: Seq[String] = Seq.empty, - stdout: Option[String] = None, - stderr: Option[String] = None, - result: Int = 0) { + stdout: Option[String] = None, + stderr: Option[String] = None, + result: Int = 0) { + /** Generate a name for the test based on the arguments */ def testName: String = "args" + args.mkString("_") @@ -70,8 +75,8 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit case None => Array.empty } - p.files.foreach( f => new File(td.buildDir + s"/$f").delete() ) - p.notFiles.foreach( f => new File(td.buildDir + s"/$f").delete() ) + p.files.foreach(f => new File(td.buildDir + s"/$f").delete()) + p.notFiles.foreach(f => new File(td.buildDir + s"/$f").delete()) When(s"""the user tries to compile with '${p.argsString}'""") val (stdout, stderr, result) = @@ -80,25 +85,25 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit p.stdout match { case Some(a) => Then(s"""STDOUT should include "$a"""") - stdout should include (a) + stdout should include(a) case None => Then(s"nothing should print to STDOUT") - stdout should be (empty) + stdout should be(empty) } p.stderr match { case Some(a) => And(s"""STDERR should include "$a"""") - stderr should include (a) + stderr should include(a) case None => And(s"nothing should print to STDERR") - stderr should be (empty) + stderr should be(empty) } p.result match { case 0 => And(s"the exit code should be 0") - result shouldBe a [Right[_,_]] + result shouldBe a[Right[_, _]] case a => And(s"the exit code should be $a") result shouldBe (Left(a)) @@ -113,12 +118,11 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit p.notFiles.foreach { f => And(s"file '$f' should NOT be emitted in the target directory") val out = new File(td.buildDir + s"/$f") - out should not (exist) + out should not(exist) } } } - /** Test fixture that links to the [[FirrtlMain]] object. This could be done without, but its use matches the * Given/When/Then style more accurately. */ @@ -137,7 +141,7 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit } trait FirrtlCircuitFixture { - val main: String + val main: String val input: String } @@ -185,13 +189,13 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val (out, _, result) = grabStdOutErr { catchStatus { f.stage.main(Array("--help")) } } Then("the usage text should be shown") - out should include ("Usage: firrtl") + out should include("Usage: firrtl") And("usage text should show known registered transforms") - out should include ("--no-dce") + out should include("--no-dce") And("usage text should show known registered libraries") - out should include ("MemLib Options") + out should include("MemLib Options") info("""And the exit code should be 0, but scopt catches all throwable, so we can't check this... ¯\_(ツ)_/¯""") // And("the exit code should be zero") @@ -200,67 +204,89 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit Seq( /* Test all standard emitters with and without annotation file outputs */ - FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl"), - files = Seq("Top.fir")), - FirrtlMainTest(args = Array("-X", "high", "-E", "high"), - stdout = defaultStdOut, - files = Seq("Top.hi.fir")), - FirrtlMainTest(args = Array("-X", "middle", "-E", "middle", "-foaf", "Top"), - stdout = defaultStdOut, - files = Seq("Top.mid.fir", "Top.anno.json")), - FirrtlMainTest(args = Array("-X", "low", "-E", "low", "-foaf", "annotations.anno.json"), - stdout = defaultStdOut, - files = Seq("Top.lo.fir", "annotations.anno.json")), - FirrtlMainTest(args = Array("-X", "verilog", "-E", "verilog", "-foaf", "foo.anno"), - stdout = defaultStdOut, - files = Seq("Top.v", "foo.anno.anno.json")), - FirrtlMainTest(args = Array("-X", "sverilog", "-E", "sverilog", "-foaf", "foo.json"), - stdout = defaultStdOut, - files = Seq("Top.sv", "foo.json.anno.json")), - + FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl"), files = Seq("Top.fir")), + FirrtlMainTest(args = Array("-X", "high", "-E", "high"), stdout = defaultStdOut, files = Seq("Top.hi.fir")), + FirrtlMainTest( + args = Array("-X", "middle", "-E", "middle", "-foaf", "Top"), + stdout = defaultStdOut, + files = Seq("Top.mid.fir", "Top.anno.json") + ), + FirrtlMainTest( + args = Array("-X", "low", "-E", "low", "-foaf", "annotations.anno.json"), + stdout = defaultStdOut, + files = Seq("Top.lo.fir", "annotations.anno.json") + ), + FirrtlMainTest( + args = Array("-X", "verilog", "-E", "verilog", "-foaf", "foo.anno"), + stdout = defaultStdOut, + files = Seq("Top.v", "foo.anno.anno.json") + ), + FirrtlMainTest( + args = Array("-X", "sverilog", "-E", "sverilog", "-foaf", "foo.json"), + stdout = defaultStdOut, + files = Seq("Top.sv", "foo.json.anno.json") + ), /* Test all one file per module emitters */ - FirrtlMainTest(args = Array("-X", "none", "-e", "chirrtl"), - files = Seq("Top.fir", "Child.fir")), - FirrtlMainTest(args = Array("-X", "high", "-e", "high"), - stdout = defaultStdOut, - files = Seq("Top.hi.fir", "Child.hi.fir")), - FirrtlMainTest(args = Array("-X", "middle", "-e", "middle"), - stdout = defaultStdOut, - files = Seq("Top.mid.fir", "Child.mid.fir")), - FirrtlMainTest(args = Array("-X", "low", "-e", "low"), - stdout = defaultStdOut, - files = Seq("Top.lo.fir", "Child.lo.fir")), - FirrtlMainTest(args = Array("-X", "verilog", "-e", "verilog"), - stdout = defaultStdOut, - files = Seq("Top.v", "Child.v")), - FirrtlMainTest(args = Array("-X", "sverilog", "-e", "sverilog"), - stdout = defaultStdOut, - files = Seq("Top.sv", "Child.sv")), - + FirrtlMainTest(args = Array("-X", "none", "-e", "chirrtl"), files = Seq("Top.fir", "Child.fir")), + FirrtlMainTest( + args = Array("-X", "high", "-e", "high"), + stdout = defaultStdOut, + files = Seq("Top.hi.fir", "Child.hi.fir") + ), + FirrtlMainTest( + args = Array("-X", "middle", "-e", "middle"), + stdout = defaultStdOut, + files = Seq("Top.mid.fir", "Child.mid.fir") + ), + FirrtlMainTest( + args = Array("-X", "low", "-e", "low"), + stdout = defaultStdOut, + files = Seq("Top.lo.fir", "Child.lo.fir") + ), + FirrtlMainTest( + args = Array("-X", "verilog", "-e", "verilog"), + stdout = defaultStdOut, + files = Seq("Top.v", "Child.v") + ), + FirrtlMainTest( + args = Array("-X", "sverilog", "-e", "sverilog"), + stdout = defaultStdOut, + files = Seq("Top.sv", "Child.sv") + ), /* Test mixing of -E with -e */ - FirrtlMainTest(args = Array("-X", "middle", "-E", "high", "-e", "middle"), - stdout = defaultStdOut, - files = Seq("Top.hi.fir", "Top.mid.fir", "Child.mid.fir"), - notFiles = Seq("Child.hi.fir")), - + FirrtlMainTest( + args = Array("-X", "middle", "-E", "high", "-e", "middle"), + stdout = defaultStdOut, + files = Seq("Top.hi.fir", "Top.mid.fir", "Child.mid.fir"), + notFiles = Seq("Child.hi.fir") + ), /* Test changes to output file name */ - FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl", "-o", "foo"), - files = Seq("foo.fir")), - FirrtlMainTest(args = Array("-X", "high", "-E", "high", "-o", "foo"), - stdout = defaultStdOut, - files = Seq("foo.hi.fir")), - FirrtlMainTest(args = Array("-X", "middle", "-E", "middle", "-o", "foo.middle"), - stdout = defaultStdOut, - files = Seq("foo.middle.mid.fir")), - FirrtlMainTest(args = Array("-X", "low", "-E", "low", "-o", "foo.lo.fir"), - stdout = defaultStdOut, - files = Seq("foo.lo.fir")), - FirrtlMainTest(args = Array("-X", "verilog", "-E", "verilog", "-o", "foo.sv"), - stdout = defaultStdOut, - files = Seq("foo.sv.v")), - FirrtlMainTest(args = Array("-X", "sverilog", "-E", "sverilog", "-o", "Foo"), - stdout = defaultStdOut, - files = Seq("Foo.sv")) + FirrtlMainTest(args = Array("-X", "none", "-E", "chirrtl", "-o", "foo"), files = Seq("foo.fir")), + FirrtlMainTest( + args = Array("-X", "high", "-E", "high", "-o", "foo"), + stdout = defaultStdOut, + files = Seq("foo.hi.fir") + ), + FirrtlMainTest( + args = Array("-X", "middle", "-E", "middle", "-o", "foo.middle"), + stdout = defaultStdOut, + files = Seq("foo.middle.mid.fir") + ), + FirrtlMainTest( + args = Array("-X", "low", "-E", "low", "-o", "foo.lo.fir"), + stdout = defaultStdOut, + files = Seq("foo.lo.fir") + ), + FirrtlMainTest( + args = Array("-X", "verilog", "-E", "verilog", "-o", "foo.sv"), + stdout = defaultStdOut, + files = Seq("foo.sv.v") + ), + FirrtlMainTest( + args = Array("-X", "sverilog", "-E", "sverilog", "-o", "Foo"), + stdout = defaultStdOut, + files = Seq("Foo.sv") + ) ) .foreach(runStageExpectFiles) @@ -272,15 +298,17 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val out = new File(s"$outName.hi.fir") out.delete() val result = catchStatus { - f.stage.main(Array("-i", "src/test/resources/integration/GCDTester.fir", "-o", outName, "-X", "high", - "-E", "high")) } + f.stage.main( + Array("-i", "src/test/resources/integration/GCDTester.fir", "-o", outName, "-X", "high", "-E", "high") + ) + } Then("outputs should be written to current directory") out should (exist) out.delete() And("the exit code should be 0") - result shouldBe a [Right[_,_]] + result shouldBe a[Right[_, _]] } Scenario("User provides Protocol Buffer input") { @@ -292,8 +320,9 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit copyResourceToFile("/integration/GCDTester.pb", protobufIn) When("the user tries to compile to High FIRRTL") - f.stage.main(Array("-i", protobufIn.toString, "-X", "high", "-E", "high", "-td", td.buildDir.toString, - "-o", "Foo")) + f.stage.main( + Array("-i", protobufIn.toString, "-X", "high", "-E", "high", "-td", td.buildDir.toString, "-o", "Foo") + ) Then("the output should be the same as using FIRRTL input") new File(td.buildDir + "/Foo.hi.fir") should (exist) @@ -311,16 +340,16 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val (out, err, result) = grabStdOutErr { catchStatus { f.stage.main(Array.empty) } } Then("an error should be printed on stdout") - out should include (s"Error: Unable to determine FIRRTL source to read") + out should include(s"Error: Unable to determine FIRRTL source to read") And("no usage text should be shown") - out should not include ("Usage: firrtl") + (out should not).include("Usage: firrtl") And("nothing should print to stderr") - err should be (empty) + err should be(empty) And("the exit code should be 1") - result should be (Left(1)) + result should be(Left(1)) } } @@ -333,22 +362,30 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit Seq( /* Erroneous inputs */ - FirrtlMainTest(args = Array("--thisIsNotASupportedOption"), - circuit = None, - stdout = Some("Error: Unknown option"), - result = 1), - FirrtlMainTest(args = Array("-i", "foo", "--info-mode", "Use"), - circuit = None, - stdout = Some("Unknown info mode 'Use'! (Did you misspell it?)"), - result = 1), - FirrtlMainTest(args = Array("-i", "test_run_dir/I-DO-NOT-EXIST"), - circuit = None, - stdout = Some("Input file 'test_run_dir/I-DO-NOT-EXIST' not found!"), - result = 1), - FirrtlMainTest(args = Array("-i", "foo", "-X", "Verilog"), - circuit = None, - stdout = Some("Unknown compiler name 'Verilog'! (Did you misspell it?)"), - result = 1) + FirrtlMainTest( + args = Array("--thisIsNotASupportedOption"), + circuit = None, + stdout = Some("Error: Unknown option"), + result = 1 + ), + FirrtlMainTest( + args = Array("-i", "foo", "--info-mode", "Use"), + circuit = None, + stdout = Some("Unknown info mode 'Use'! (Did you misspell it?)"), + result = 1 + ), + FirrtlMainTest( + args = Array("-i", "test_run_dir/I-DO-NOT-EXIST"), + circuit = None, + stdout = Some("Input file 'test_run_dir/I-DO-NOT-EXIST' not found!"), + result = 1 + ), + FirrtlMainTest( + args = Array("-i", "foo", "-X", "Verilog"), + circuit = None, + stdout = Some("Unknown compiler name 'Verilog'! (Did you misspell it?)"), + result = 1 + ) ) .foreach(runStageExpectFiles) @@ -364,13 +401,13 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit val (out, _, result) = grabStdOutErr { catchStatus { f.stage.main(Array("--show-registrations")) } } Then("stdout should show registered transforms") - out should include ("firrtl.passes.InlineInstances") + out should include("firrtl.passes.InlineInstances") And("stdout should show registered libraries") out should include("firrtl.passes.memlib.MemLibOptions") And("the exit code should be 1") - result should be (Left(1)) + result should be(Left(1)) } } @@ -380,23 +417,21 @@ class FirrtlMainSpec extends AnyFeatureSpec with GivenWhenThen with Matchers wit def optionRemoved(a: String): Option[String] = Some(s"Option '$a' was removed as part of the FIRRTL Stage refactor") Seq( /* Removed --top-name/-tn handling */ - FirrtlMainTest(args = Array("--top-name", "foo"), - circuit = None, - stdout = optionRemoved("--top-name/-tn"), - result = 1), - FirrtlMainTest(args = Array("-tn"), - circuit = None, - stdout = optionRemoved("--top-name/-tn"), - result = 1), + FirrtlMainTest( + args = Array("--top-name", "foo"), + circuit = None, + stdout = optionRemoved("--top-name/-tn"), + result = 1 + ), + FirrtlMainTest(args = Array("-tn"), circuit = None, stdout = optionRemoved("--top-name/-tn"), result = 1), /* Removed --split-modules/-fsm handling */ - FirrtlMainTest(args = Array("--split-modules"), - circuit = None, - stdout = optionRemoved("--split-modules/-fsm"), - result = 1), - FirrtlMainTest(args = Array("-fsm"), - circuit = None, - stdout = optionRemoved("--split-modules/-fsm"), - result = 1) + FirrtlMainTest( + args = Array("--split-modules"), + circuit = None, + stdout = optionRemoved("--split-modules/-fsm"), + result = 1 + ), + FirrtlMainTest(args = Array("-fsm"), circuit = None, stdout = optionRemoved("--split-modules/-fsm"), result = 1) ) .foreach(runStageExpectFiles) } diff --git a/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala b/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala index 4161d29b..00aa8e6a 100644 --- a/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala +++ b/src/test/scala/firrtlTests/stage/FirrtlOptionsViewSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage - import firrtl.stage._ import firrtl.{ir, NoneCompiler, Parser} @@ -17,7 +16,7 @@ class Baz_Compiler extends NoneCompiler class FirrtlOptionsViewSpec extends AnyFlatSpec with Matchers { - behavior of FirrtlOptionsView.getClass.getName + behavior.of(FirrtlOptionsView.getClass.getName) def circuitString(main: String): String = s"""|circuit $main: | module $main: @@ -37,9 +36,9 @@ class FirrtlOptionsViewSpec extends AnyFlatSpec with Matchers { it should "construct a view from an AnnotationSeq" in { val out = view[FirrtlOptions](annotations) - out.outputFileName should be (Some("bar")) - out.infoModeName should be ("use") - out.firrtlCircuit should be (Some(grault)) + out.outputFileName should be(Some("bar")) + out.infoModeName should be("use") + out.firrtlCircuit should be(Some(grault)) } /* This test only exists to catch changes to existing behavior. This test does not indicate that this is the correct @@ -57,9 +56,9 @@ class FirrtlOptionsViewSpec extends AnyFlatSpec with Matchers { val out = view[FirrtlOptions](annotations ++ overwrites) - out.outputFileName should be (Some("bar_")) - out.infoModeName should be ("gen") - out.firrtlCircuit should be (Some(grault_)) + out.outputFileName should be(Some("bar_")) + out.infoModeName should be("gen") + out.firrtlCircuit should be(Some(grault_)) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala index 58026ecd..aac18dee 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddCircuitSpec.scala @@ -2,12 +2,16 @@ package firrtlTests.stage.phases - import firrtl.Parser import firrtl.annotations.NoTargetAnnotation import firrtl.options.{OptionsException, Phase, PhasePrerequisiteException} -import firrtl.stage.{CircuitOption, FirrtlCircuitAnnotation, FirrtlSourceAnnotation, InfoModeAnnotation, - FirrtlFileAnnotation} +import firrtl.stage.{ + CircuitOption, + FirrtlCircuitAnnotation, + FirrtlFileAnnotation, + FirrtlSourceAnnotation, + InfoModeAnnotation +} import firrtl.stage.phases.AddCircuit import java.io.{File, FileWriter} @@ -21,7 +25,7 @@ class AddCircuitSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new AddCircuit } - behavior of classOf[AddCircuit].toString + behavior.of(classOf[AddCircuit].toString) def firrtlSource(name: String): String = s"""|circuit $name: @@ -32,15 +36,16 @@ class AddCircuitSpec extends AnyFlatSpec with Matchers { |""".stripMargin it should "throw a PhasePrerequisiteException if a CircuitOption exists without an InfoModeAnnotation" in - new Fixture { - {the [PhasePrerequisiteException] thrownBy phase.transform(Seq(FirrtlSourceAnnotation("foo")))} - .message should startWith ("An InfoModeAnnotation must be present") - } + new Fixture { + { + the[PhasePrerequisiteException] thrownBy phase.transform(Seq(FirrtlSourceAnnotation("foo"))) + }.message should startWith("An InfoModeAnnotation must be present") + } it should "do nothing if no CircuitOption annotations are present" in new Fixture { val annotations = (1 to 10).map(FooAnnotation) ++ ('a' to 'm').map(_.toString).map(BarAnnotation) :+ InfoModeAnnotation("ignore") - phase.transform(annotations).toSeq should be (annotations.toSeq) + phase.transform(annotations).toSeq should be(annotations.toSeq) } val (file, fileCircuit) = { @@ -66,39 +71,44 @@ class AddCircuitSpec extends AnyFlatSpec with Matchers { FirrtlFileAnnotation(file), FirrtlSourceAnnotation(source), FirrtlCircuitAnnotation(circuit), - InfoModeAnnotation("ignore") ) + InfoModeAnnotation("ignore") + ) val annotationsExpected = Set( FirrtlCircuitAnnotation(fileCircuit), FirrtlCircuitAnnotation(sourceCircuit), - FirrtlCircuitAnnotation(circuit) ) + FirrtlCircuitAnnotation(circuit) + ) val out = phase.transform(annotations).toSeq info("generated expected FirrtlCircuitAnnotations") - out.collect{ case a: FirrtlCircuitAnnotation => a}.toSet should be (annotationsExpected) + out.collect { case a: FirrtlCircuitAnnotation => a }.toSet should be(annotationsExpected) info("all CircuitOptions were removed") - out.collect{ case a: CircuitOption => a } should be (empty) + out.collect { case a: CircuitOption => a } should be(empty) } it should """add info for a FirrtlFileAnnotation with a "gen" info mode""" in new Fixture { - phase.transform(Seq(InfoModeAnnotation("gen"), FirrtlFileAnnotation(file))) - .collectFirst{ case a: FirrtlCircuitAnnotation => a.circuit.serialize } - .get should include ("AddCircuitSpec") + phase + .transform(Seq(InfoModeAnnotation("gen"), FirrtlFileAnnotation(file))) + .collectFirst { case a: FirrtlCircuitAnnotation => a.circuit.serialize } + .get should include("AddCircuitSpec") } it should """add info for a FirrtlSourceAnnotation with an "append" info mode""" in new Fixture { - phase.transform(Seq(InfoModeAnnotation("append"), FirrtlSourceAnnotation(source))) - .collectFirst{ case a: FirrtlCircuitAnnotation => a.circuit.serialize } - .get should include ("anonymous source") + phase + .transform(Seq(InfoModeAnnotation("append"), FirrtlSourceAnnotation(source))) + .collectFirst { case a: FirrtlCircuitAnnotation => a.circuit.serialize } + .get should include("anonymous source") } it should "throw an OptionsException if the specified file doesn't exist" in new Fixture { val a = Seq(InfoModeAnnotation("ignore"), FirrtlFileAnnotation("test_run_dir/I-DO-NOT-EXIST")) - {the [OptionsException] thrownBy phase.transform(a)} - .message should startWith (s"Input file 'test_run_dir/I-DO-NOT-EXIST' not found") + { the[OptionsException] thrownBy phase.transform(a) }.message should startWith( + s"Input file 'test_run_dir/I-DO-NOT-EXIST' not found" + ) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala index b600e6c5..686c42ad 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddDefaultsSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.NoneCompiler import firrtl.annotations.Annotation import firrtl.stage.phases.AddDefaults @@ -16,24 +15,25 @@ class AddDefaultsSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new AddDefaults } - behavior of classOf[AddDefaults].toString + behavior.of(classOf[AddDefaults].toString) it should "add expected default annotations and nothing else" in new Fixture { val expected = Seq( (a: Annotation) => a match { case BlackBoxTargetDirAnno(b) => b == TargetDirAnnotation().directory }, - (a: Annotation) => a match { case RunFirrtlTransformAnnotation(e: firrtl.Emitter) => - Dependency.fromTransform(e) == Dependency[firrtl.VerilogEmitter] }, - (a: Annotation) => a match { case InfoModeAnnotation(b) => b == InfoModeAnnotation().modeName } ) - - phase.transform(Seq.empty).zip(expected).map { case (x, f) => f(x) should be (true) } + (a: Annotation) => + a match { + case RunFirrtlTransformAnnotation(e: firrtl.Emitter) => + Dependency.fromTransform(e) == Dependency[firrtl.VerilogEmitter] + }, + (a: Annotation) => a match { case InfoModeAnnotation(b) => b == InfoModeAnnotation().modeName } + ) + + phase.transform(Seq.empty).zip(expected).map { case (x, f) => f(x) should be(true) } } it should "not overwrite existing annotations" in new Fixture { - val input = Seq( - BlackBoxTargetDirAnno("foo"), - CompilerAnnotation(new NoneCompiler()), - InfoModeAnnotation("ignore")) + val input = Seq(BlackBoxTargetDirAnno("foo"), CompilerAnnotation(new NoneCompiler()), InfoModeAnnotation("ignore")) - phase.transform(input).toSeq should be (input) + phase.transform(input).toSeq should be(input) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala index 941f1883..1252090b 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddImplicitEmitterSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.{EmitAllModulesAnnotation, EmitCircuitAnnotation, HighFirrtlEmitter, VerilogCompiler} import firrtl.annotations.NoTargetAnnotation import firrtl.options.Phase @@ -20,27 +19,26 @@ class AddImplicitEmitterSpec extends AnyFlatSpec with Matchers { val someAnnos = Seq(FooAnnotation(1), FooAnnotation(2), BarAnnotation("bar")) - behavior of classOf[AddImplicitEmitter].toString + behavior.of(classOf[AddImplicitEmitter].toString) it should "do nothing if no CompilerAnnotation is present" in new Fixture { - phase.transform(someAnnos).toSeq should be (someAnnos) + phase.transform(someAnnos).toSeq should be(someAnnos) } it should "add an EmitCircuitAnnotation derived from a CompilerAnnotation" in new Fixture { val input = CompilerAnnotation(new VerilogCompiler) +: someAnnos - val expected = input.flatMap{ - case a@ CompilerAnnotation(b) => Seq(a, - RunFirrtlTransformAnnotation(b.emitter), - EmitCircuitAnnotation(b.emitter.getClass)) + val expected = input.flatMap { + case a @ CompilerAnnotation(b) => + Seq(a, RunFirrtlTransformAnnotation(b.emitter), EmitCircuitAnnotation(b.emitter.getClass)) case a => Some(a) } - phase.transform(input).toSeq should be (expected) + phase.transform(input).toSeq should be(expected) } it should "not add an EmitCircuitAnnotation if an EmitAnnotation already exists" in new Fixture { - val input = Seq(CompilerAnnotation(new VerilogCompiler), - EmitAllModulesAnnotation(classOf[HighFirrtlEmitter])) ++ someAnnos - phase.transform(input).toSeq should be (input) + val input = + Seq(CompilerAnnotation(new VerilogCompiler), EmitAllModulesAnnotation(classOf[HighFirrtlEmitter])) ++ someAnnos + phase.transform(input).toSeq should be(input) } } diff --git a/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala b/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala index 5ec051f4..499b05ae 100644 --- a/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/AddImplicitOutputFileSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.{ChirrtlEmitter, EmitAllModulesAnnotation, Parser} import firrtl.options.Phase import firrtl.stage.{FirrtlCircuitAnnotation, OutputFileAnnotation} @@ -21,27 +20,27 @@ class AddImplicitOutputFileSpec extends AnyFlatSpec with Matchers { val circuit = Parser.parse(foo) - behavior of classOf[AddImplicitOutputFile].toString + behavior.of(classOf[AddImplicitOutputFile].toString) it should "default to an output file named 'a'" in new Fixture { - phase.transform(Seq.empty).toSeq should be (Seq(OutputFileAnnotation("a"))) + phase.transform(Seq.empty).toSeq should be(Seq(OutputFileAnnotation("a"))) } it should "set the output file based on a FirrtlCircuitAnnotation's main" in new Fixture { val in = Seq(FirrtlCircuitAnnotation(circuit)) val out = OutputFileAnnotation(circuit.main) +: in - phase.transform(in).toSeq should be (out) + phase.transform(in).toSeq should be(out) } it should "do nothing if an OutputFileAnnotation or EmitAllModulesAnnotation already exists" in new Fixture { info("OutputFileAnnotation works") val outputFile = Seq(OutputFileAnnotation("Bar"), FirrtlCircuitAnnotation(circuit)) - phase.transform(outputFile).toSeq should be (outputFile) + phase.transform(outputFile).toSeq should be(outputFile) info("EmitAllModulesAnnotation works") val eam = Seq(EmitAllModulesAnnotation(classOf[ChirrtlEmitter]), FirrtlCircuitAnnotation(circuit)) - phase.transform(eam).toSeq should be (eam) + phase.transform(eam).toSeq should be(eam) } } diff --git a/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala b/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala index e10bbe6d..65516da5 100644 --- a/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/ChecksSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import firrtl.stage._ import firrtl.{AnnotationSeq, ChirrtlEmitter, EmitAllModulesAnnotation, NoneCompiler} @@ -25,9 +24,9 @@ class ChecksSpec extends AnyFlatSpec with Matchers { val min = Seq(inputFile, goodCompiler, infoMode) def checkExceptionMessage(phase: Phase, annotations: AnnotationSeq, messageStart: String): Unit = - intercept[OptionsException]{ phase.transform(annotations) }.getMessage should startWith(messageStart) + intercept[OptionsException] { phase.transform(annotations) }.getMessage should startWith(messageStart) - behavior of classOf[Checks].toString + behavior.of(classOf[Checks].toString) it should "require exactly one input source" in new Fixture { info("0 input source causes an exception") @@ -74,8 +73,11 @@ class ChecksSpec extends AnyFlatSpec with Matchers { it should "enforce exactly one info mode" in new Fixture { info("0 info modes should throw an exception") - checkExceptionMessage(phase, Seq(inputFile, goodCompiler), - "Exactly one info mode must be specified, but none found") + checkExceptionMessage( + phase, + Seq(inputFile, goodCompiler), + "Exactly one info mode must be specified, but none found" + ) info("2 info modes should throw an exception") val i = infoMode.modeName diff --git a/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala b/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala index 0446d4a3..12ec66c2 100644 --- a/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/CompilerSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import scala.collection.mutable import firrtl.{Compiler => _, _} @@ -16,10 +15,10 @@ class CompilerSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new Compiler } - behavior of classOf[Compiler].toString + behavior.of(classOf[Compiler].toString) it should "do nothing for an empty AnnotationSeq" in new Fixture { - phase.transform(Seq.empty).toSeq should be (empty) + phase.transform(Seq.empty).toSeq should be(empty) } /** A circuit with a parameterized main (top name) that is different at High, Mid, and Low FIRRTL forms. */ @@ -36,11 +35,9 @@ class CompilerSpec extends AnyFlatSpec with Matchers { val circuitIn = Parser.parse(chirrtl("top")) val circuitOut = compiler.compile(CircuitState(circuitIn, ChirrtlForm), Seq.empty).circuit - val input = Seq( - FirrtlCircuitAnnotation(circuitIn), - CompilerAnnotation(compiler) ) + val input = Seq(FirrtlCircuitAnnotation(circuitIn), CompilerAnnotation(compiler)) - phase.transform(input).toSeq should be (Seq(FirrtlCircuitAnnotation(circuitOut))) + phase.transform(input).toSeq should be(Seq(FirrtlCircuitAnnotation(circuitOut))) } it should "compile multiple FirrtlCircuitAnnotations" in new Fixture { @@ -50,32 +47,31 @@ class CompilerSpec extends AnyFlatSpec with Matchers { new MiddleFirrtlCompiler, new LowFirrtlCompiler, new VerilogCompiler, - new SystemVerilogCompiler ) + new SystemVerilogCompiler + ) val (ce, hfe, mfe, lfe, ve, sve) = ( new ChirrtlEmitter, new HighFirrtlEmitter, new MiddleFirrtlEmitter, new LowFirrtlEmitter, new VerilogEmitter, - new SystemVerilogEmitter ) + new SystemVerilogEmitter + ) val a = Seq( /* Default Compiler is HighFirrtlCompiler */ CompilerAnnotation(hfc), - /* First compiler group, use NoneCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("a"))), CompilerAnnotation(nc), RunFirrtlTransformAnnotation(ce), EmitCircuitAnnotation(ce.getClass), - /* Second compiler group, use default HighFirrtlCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("b"))), RunFirrtlTransformAnnotation(ce), EmitCircuitAnnotation(ce.getClass), RunFirrtlTransformAnnotation(hfe), EmitCircuitAnnotation(hfe.getClass), - /* Third compiler group, use MiddleFirrtlCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("c"))), CompilerAnnotation(mfc), @@ -85,7 +81,6 @@ class CompilerSpec extends AnyFlatSpec with Matchers { EmitCircuitAnnotation(hfe.getClass), RunFirrtlTransformAnnotation(mfe), EmitCircuitAnnotation(mfe.getClass), - /* Fourth compiler group, use LowFirrtlCompiler*/ FirrtlCircuitAnnotation(Parser.parse(chirrtl("d"))), CompilerAnnotation(lfc), @@ -97,7 +92,6 @@ class CompilerSpec extends AnyFlatSpec with Matchers { EmitCircuitAnnotation(mfe.getClass), RunFirrtlTransformAnnotation(lfe), EmitCircuitAnnotation(lfe.getClass), - /* Fifth compiler group, use VerilogCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("e"))), CompilerAnnotation(vc), @@ -111,7 +105,6 @@ class CompilerSpec extends AnyFlatSpec with Matchers { EmitCircuitAnnotation(lfe.getClass), RunFirrtlTransformAnnotation(ve), EmitCircuitAnnotation(ve.getClass), - /* Sixth compiler group, use SystemVerilogCompiler */ FirrtlCircuitAnnotation(Parser.parse(chirrtl("f"))), CompilerAnnotation(svc), @@ -130,14 +123,10 @@ class CompilerSpec extends AnyFlatSpec with Matchers { val output = phase.transform(a) info("with the same number of output FirrtlCircuitAnnotations") - output - .collect{ case a: FirrtlCircuitAnnotation => a } - .size should be (6) + output.collect { case a: FirrtlCircuitAnnotation => a }.size should be(6) info("and all expected EmittedAnnotations should be generated") - output - .collect{ case a: EmittedAnnotation[_] => a } - .size should be (20) + output.collect { case a: EmittedAnnotation[_] => a }.size should be(20) } it should "run transforms in sequential order" in new Fixture { @@ -145,20 +134,23 @@ class CompilerSpec extends AnyFlatSpec with Matchers { val circuitIn = Parser.parse(chirrtl("top")) val annotations = - Seq( FirrtlCircuitAnnotation(circuitIn), - CompilerAnnotation(new VerilogCompiler), - RunFirrtlTransformAnnotation(new FirstTransform), - RunFirrtlTransformAnnotation(new SecondTransform) ) + Seq( + FirrtlCircuitAnnotation(circuitIn), + CompilerAnnotation(new VerilogCompiler), + RunFirrtlTransformAnnotation(new FirstTransform), + RunFirrtlTransformAnnotation(new SecondTransform) + ) phase.transform(annotations) - CompilerSpec.globalState.toSeq should be (Seq(classOf[FirstTransform], classOf[SecondTransform])) + CompilerSpec.globalState.toSeq should be(Seq(classOf[FirstTransform], classOf[SecondTransform])) } } object CompilerSpec { - private[CompilerSpec] val globalState: mutable.Queue[Class[_ <: Transform]] = mutable.Queue.empty[Class[_ <: Transform]] + private[CompilerSpec] val globalState: mutable.Queue[Class[_ <: Transform]] = + mutable.Queue.empty[Class[_ <: Transform]] class LoggingTransform extends Transform { override def inputForm = UnknownForm diff --git a/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala b/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala index bbec32fe..73ee455d 100644 --- a/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala +++ b/src/test/scala/firrtlTests/stage/phases/WriteEmittedSpec.scala @@ -2,7 +2,6 @@ package firrtlTests.stage.phases - import java.io.File import firrtl._ @@ -22,21 +21,22 @@ class WriteEmittedSpec extends AnyFlatSpec with Matchers { class Fixture { val phase: Phase = new WriteEmitted } - behavior of classOf[WriteEmitted].toString + behavior.of(classOf[WriteEmitted].toString) it should "write emitted circuits" in new Fixture { val annotations = Seq( TargetDirAnnotation("test_run_dir/WriteEmittedSpec"), EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("foo", "", ".foocircuit")), EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("bar", "", ".barcircuit")), - EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit("baz", "", ".bazcircuit")) ) + EmittedVerilogCircuitAnnotation(EmittedVerilogCircuit("baz", "", ".bazcircuit")) + ) val expected = Seq("foo.foocircuit", "bar.barcircuit", "baz.bazcircuit") .map(a => new File(s"test_run_dir/WriteEmittedSpec/$a")) info("annotations are unmodified") - phase.transform(annotations).toSeq should be (removeEmitted(annotations).toSeq) + phase.transform(annotations).toSeq should be(removeEmitted(annotations).toSeq) - expected.foreach{ a => + expected.foreach { a => info(s"$a was written") a should (exist) a.delete() @@ -47,11 +47,12 @@ class WriteEmittedSpec extends AnyFlatSpec with Matchers { val annotations = Seq( TargetDirAnnotation("test_run_dir/WriteEmittedSpec"), OutputFileAnnotation("quux"), - EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("qux", "", ".quxcircuit")) ) + EmittedFirrtlCircuitAnnotation(EmittedFirrtlCircuit("qux", "", ".quxcircuit")) + ) val expected = new File("test_run_dir/WriteEmittedSpec/quux.quxcircuit") info("annotations are unmodified") - phase.transform(annotations).toSeq should be (removeEmitted(annotations).toSeq) + phase.transform(annotations).toSeq should be(removeEmitted(annotations).toSeq) info(s"$expected was written") expected should (exist) @@ -63,14 +64,15 @@ class WriteEmittedSpec extends AnyFlatSpec with Matchers { TargetDirAnnotation("test_run_dir/WriteEmittedSpec"), EmittedFirrtlModuleAnnotation(EmittedFirrtlModule("foo", "", ".foomodule")), EmittedFirrtlModuleAnnotation(EmittedFirrtlModule("bar", "", ".barmodule")), - EmittedVerilogModuleAnnotation(EmittedVerilogModule("baz", "", ".bazmodule")) ) + EmittedVerilogModuleAnnotation(EmittedVerilogModule("baz", "", ".bazmodule")) + ) val expected = Seq("foo.foomodule", "bar.barmodule", "baz.bazmodule") .map(a => new File(s"test_run_dir/WriteEmittedSpec/$a")) info("EmittedComponent annotations are deleted") - phase.transform(annotations).toSeq should be (removeEmitted(annotations).toSeq) + phase.transform(annotations).toSeq should be(removeEmitted(annotations).toSeq) - expected.foreach{ a => + expected.foreach { a => info(s"$a was written") a should (exist) a.delete() diff --git a/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala b/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala index 2c746c99..a52df4a9 100644 --- a/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala +++ b/src/test/scala/firrtlTests/transforms/BlackBoxSourceHelperSpec.scala @@ -8,9 +8,8 @@ import firrtl.{Transform, VerilogEmitter} import firrtl.FileUtils import firrtl.testutils.LowTransformSpec - class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { - def transform: Transform = new BlackBoxSourceHelper + def transform: Transform = new BlackBoxSourceHelper private val moduleName = ModuleName("Top", CircuitName("Top")) private val input = """ @@ -31,21 +30,21 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { | y <= a1.bar """.stripMargin private val output = """ - |circuit Top : - | - | extmodule AdderExtModule : - | input foo : UInt<16> - | output bar : UInt<16> - | - | defname = BBFAdd - | - | module Top : - | input x : UInt<16> - | output y : UInt<16> - | - | inst a1 of AdderExtModule - | y <= a1.bar - | a1.foo <= x + |circuit Top : + | + | extmodule AdderExtModule : + | input foo : UInt<16> + | output bar : UInt<16> + | + | defname = BBFAdd + | + | module Top : + | input x : UInt<16> + | output y : UInt<16> + | + | inst a1 of AdderExtModule + | y <= a1.bar + | a1.foo <= x """.stripMargin "annotated external modules with absolute path" should "appear in output directory" in { @@ -61,8 +60,8 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { val module = new java.io.File("test_run_dir/AdderExtModule.v") val fileList = new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}") - module.exists should be (true) - fileList.exists should be (true) + module.exists should be(true) + fileList.exists should be(true) module.delete() fileList.delete() @@ -80,8 +79,8 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { val module = new java.io.File("test_run_dir/AdderExtModule.v") val fileList = new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}") - module.exists should be (true) - fileList.exists should be (true) + module.exists should be(true) + fileList.exists should be(true) module.delete() fileList.delete() @@ -96,8 +95,8 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { execute(input, output, annos) - new java.io.File("test_run_dir/AdderExtModule.v").exists should be (true) - new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}").exists should be (true) + new java.io.File("test_run_dir/AdderExtModule.v").exists should be(true) + new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}").exists should be(true) } "verilog header files" should "be available but not mentioned in the file list" in { @@ -114,40 +113,41 @@ class BlacklBoxSourceHelperTransformSpec extends LowTransformSpec { // We'll copy the following resources to the test_run_dir via BlackBoxResourceAnno's val resourceNames = Seq("ParameterizedViaHeaderAdderExtModule.v", "VerilogHeaderFile.vh") - val annos = Seq( - BlackBoxTargetDirAnno("test_run_dir")) ++ resourceNames.map{ n => BlackBoxResourceAnno(moduleName, "/blackboxes/" + n)} + val annos = Seq(BlackBoxTargetDirAnno("test_run_dir")) ++ resourceNames.map { n => + BlackBoxResourceAnno(moduleName, "/blackboxes/" + n) + } execute(pInput, pOutput, annos) // Our resource files should exist in the test_run_dir, for (n <- resourceNames) - new java.io.File("test_run_dir/" + n).exists should be (true) + new java.io.File("test_run_dir/" + n).exists should be(true) // but our file list should not include the verilog header file. val fileListFile = new java.io.File(s"test_run_dir/${BlackBoxSourceHelper.defaultFileListName}") - fileListFile.exists should be (true) + fileListFile.exists should be(true) val fileList = FileUtils.getText(fileListFile) - fileList.contains("ParameterizedViaHeaderAdderExtModule.v") should be (true) - fileList.contains("VerilogHeaderFile.vh") should be (false) + fileList.contains("ParameterizedViaHeaderAdderExtModule.v") should be(true) + fileList.contains("VerilogHeaderFile.vh") should be(false) } - behavior of "BlackBox resources that do not exist" + behavior.of("BlackBox resources that do not exist") it should "provide a useful error message for BlackBoxResourceAnno" in { - val annos = Seq( BlackBoxTargetDirAnno("test_run_dir"), - BlackBoxResourceAnno(moduleName, "/blackboxes/IDontExist.v") ) + val annos = Seq(BlackBoxTargetDirAnno("test_run_dir"), BlackBoxResourceAnno(moduleName, "/blackboxes/IDontExist.v")) - (the [BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }) - .getMessage should include ("Did you misspell it?") + (the[BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }).getMessage should include( + "Did you misspell it?" + ) } it should "provide a useful error message for BlackBoxPathAnno" in { val absPath = new java.io.File("src/test/resources/blackboxes/IDontExist.v").getCanonicalPath - val annos = Seq( BlackBoxTargetDirAnno("test_run_dir"), - BlackBoxPathAnno(moduleName, absPath) ) + val annos = Seq(BlackBoxTargetDirAnno("test_run_dir"), BlackBoxPathAnno(moduleName, absPath)) - (the [BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }) - .getMessage should include ("Did you misspell it?") + (the[BlackBoxNotFoundException] thrownBy { execute(input, "", annos) }).getMessage should include( + "Did you misspell it?" + ) } } diff --git a/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala index f2672bce..a916eac5 100644 --- a/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala +++ b/src/test/scala/firrtlTests/transforms/CombineCatsSpec.scala @@ -14,9 +14,11 @@ class CombineCatsSpec extends FirrtlFlatSpec { private val annotations = Seq(new MaxCatLenAnnotation(12)) private def execute(input: String, transforms: Seq[Transform], annotations: AnnotationSeq): CircuitState = { - val c = transforms.foldLeft(CircuitState(parse(input), UnknownForm, annotations)) { - (c: CircuitState, t: Transform) => t.runTransform(c) - }.circuit + val c = transforms + .foldLeft(CircuitState(parse(input), UnknownForm, annotations)) { (c: CircuitState, t: Transform) => + t.runTransform(c) + } + .circuit CircuitState(c, UnknownForm, Seq(), None) } @@ -86,11 +88,24 @@ class CombineCatsSpec extends FirrtlFlatSpec { // temp5 should get cat(cat(cat(in3, in2), cat(in4, in3)), cat(cat(in3, in2), cat(in4, in3))) result should containTree { - case DoPrim(Cat, Seq( - DoPrim(Cat, Seq( - DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _), - DoPrim(Cat, Seq(WRef("in3", _, _, _), WRef("in2", _, _, _)), _, _)), _, _), - DoPrim(Cat, Seq(WRef("in4", _, _, _), WRef("in3", _, _, _)), _, _)), _, _) => true + case DoPrim( + Cat, + Seq( + DoPrim( + Cat, + Seq( + DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _), + DoPrim(Cat, Seq(WRef("in3", _, _, _), WRef("in2", _, _, _)), _, _) + ), + _, + _ + ), + DoPrim(Cat, Seq(WRef("in4", _, _, _), WRef("in3", _, _, _)), _, _) + ), + _, + _ + ) => + true } } @@ -117,17 +132,19 @@ class CombineCatsSpec extends FirrtlFlatSpec { // should not contain any cat chains greater than 3 result shouldNot containTree { - case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _)), _, _) => true } // temp2 should get cat(in3, cat(in2, in1)) result should containTree { - case DoPrim(Cat, Seq( - WRef("in3", _, _, _), - DoPrim(Cat, Seq( - WRef("in2", _, _, _), - WRef("in1", _, _, _)), _, _)), _, _) => true + case DoPrim( + Cat, + Seq(WRef("in3", _, _, _), DoPrim(Cat, Seq(WRef("in2", _, _, _), WRef("in1", _, _, _)), _, _)), + _, + _ + ) => + true } } @@ -152,8 +169,8 @@ class CombineCatsSpec extends FirrtlFlatSpec { val result = execute(input, transforms, Seq.empty) result shouldNot containTree { - case DoPrim(Cat, Seq(_, DoPrim(Add, _, _, _)), _, _) => true - case DoPrim(Cat, Seq(_, DoPrim(Sub, _, _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Add, _, _, _)), _, _) => true + case DoPrim(Cat, Seq(_, DoPrim(Sub, _, _, _)), _, _) => true case DoPrim(Cat, Seq(_, DoPrim(Cat, Seq(_, DoPrim(Cat, _, _, _)), _, _)), _, _) => true } } diff --git a/src/test/scala/firrtlTests/transforms/DedupTests.scala b/src/test/scala/firrtlTests/transforms/DedupTests.scala index 8ab3026c..8c2835dd 100644 --- a/src/test/scala/firrtlTests/transforms/DedupTests.scala +++ b/src/test/scala/firrtlTests/transforms/DedupTests.scala @@ -10,8 +10,8 @@ import firrtl.transforms.{DedupModules, NoCircuitDedupAnnotation} import firrtl.testutils._ /** - * Tests inline instances transformation - */ + * Tests inline instances transformation + */ class DedupModuleTests extends HighTransformSpec { case class MultiTargetDummyAnnotation(targets: Seq[Target], tag: Int) extends Annotation { override def update(renames: RenameMap): Seq[Annotation] = { @@ -24,234 +24,236 @@ class DedupModuleTests extends HighTransformSpec { } def transform = new DedupModules "The module A" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | module A_ : - | output x: UInt<1> - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | module A_ : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : - | output x: UInt<1> - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "The module A and B" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | inst b of B - | x <= b.x - | module A_ : - | output x: UInt<1> - | inst b of B_ - | x <= b.x - | module B : - | output x: UInt<1> - | x <= UInt(1) - | module B_ : - | output x: UInt<1> - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module A_ : + | output x: UInt<1> + | inst b of B_ + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : - | output x: UInt<1> - | inst b of B - | x <= b.x - | module B : - | output x: UInt<1> - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : + | output x: UInt<1> + | inst b of B + | x <= b.x + | module B : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "The module A and B with comments" should "be deduped" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B @[yy 2:2] - | x <= b.x @[yy 2:2] - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | inst b of B_ @[xx 1:1] - | x <= b.x @[xx 1:1] - | module B : - | output x: UInt<1> - | x <= UInt(1) - | module B_ : - | output x: UInt<1> - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B @[yy 2:2] + | x <= b.x @[yy 2:2] + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | inst b of B_ @[xx 1:1] + | x <= b.x @[xx 1:1] + | module B : + | output x: UInt<1> + | x <= UInt(1) + | module B_ : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B @[yy 2:2] - | x <= b.x @[yy 2:2] - | module B : - | output x: UInt<1> - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B @[yy 2:2] + | x <= b.x @[yy 2:2] + | module B : + | output x: UInt<1> + | x <= UInt(1) """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "A_ but not A" should "be deduped if not annotated" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | x <= UInt(1) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | x <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output x: UInt<1> @[xx 1:1] - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output x: UInt<1> @[xx 1:1] + | x <= UInt(1) """.stripMargin - execute(input, check, Seq(dontDedup("A"))) + execute(input, check, Seq(dontDedup("A"))) } "The module A and A_" should "be deduped even with different port names and info, and annotations should remapped" in { - val input = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A_ - | out <= and(a1.x, a2.y) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) - | module A_ : @[xx 1:1] - | output y: UInt<1> @[xx 1:1] - | y <= UInt(1) + val input = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A_ + | out <= and(a1.x, a2.y) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) + | module A_ : @[xx 1:1] + | output y: UInt<1> @[xx 1:1] + | y <= UInt(1) """.stripMargin - val check = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A - | out <= and(a1.x, a2.x) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | x <= UInt(1) + val check = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A + | out <= and(a1.x, a2.x) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | x <= UInt(1) """.stripMargin - val mname = ModuleName("Top", CircuitName("Top")) - val finalState = execute(input, check, Seq(SingleTargetDummyAnnotation(ComponentName("a2.y", mname)))) - finalState.annotations.collect({ case d: SingleTargetDummyAnnotation => d }).head should be(SingleTargetDummyAnnotation(ComponentName("a2.x", mname))) + val mname = ModuleName("Top", CircuitName("Top")) + val finalState = execute(input, check, Seq(SingleTargetDummyAnnotation(ComponentName("a2.y", mname)))) + finalState.annotations.collect({ case d: SingleTargetDummyAnnotation => d }).head should be( + SingleTargetDummyAnnotation(ComponentName("a2.x", mname)) + ) } "Extmodules" should "with the same defname and parameters should dedup" in { - val input = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A_ - | out <= and(a1.x, a2.y) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B - | x <= b.u - | module A_ : @[xx 1:1] - | output y: UInt<1> @[xx 1:1] - | inst c of C - | y <= c.u - | extmodule B : @[aa 3:3] - | output u : UInt<1> @[aa 4:4] - | defname = BB - | parameter N = 0 - | extmodule C : @[bb 5:5] - | output u : UInt<1> @[bb 6:6] - | defname = BB - | parameter N = 0 + val input = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A_ + | out <= and(a1.x, a2.y) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B + | x <= b.u + | module A_ : @[xx 1:1] + | output y: UInt<1> @[xx 1:1] + | inst c of C + | y <= c.u + | extmodule B : @[aa 3:3] + | output u : UInt<1> @[aa 4:4] + | defname = BB + | parameter N = 0 + | extmodule C : @[bb 5:5] + | output u : UInt<1> @[bb 6:6] + | defname = BB + | parameter N = 0 """.stripMargin - val check = - """circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A - | out <= and(a1.x, a2.x) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B - | x <= b.u - | extmodule B : @[aa 3:3] - | output u : UInt<1> @[aa 4:4] - | defname = BB - | parameter N = 0 + val check = + """circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A + | out <= and(a1.x, a2.x) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B + | x <= b.u + | extmodule B : @[aa 3:3] + | output u : UInt<1> @[aa 4:4] + | defname = BB + | parameter N = 0 """.stripMargin - execute(input, check, Seq.empty) + execute(input, check, Seq.empty) } "Extmodules" should "with the different defname or parameters should NOT dedup" in { - def mkfir(defnames: (String, String), params: (String, String)) = - s"""circuit Top : - | module Top : - | output out: UInt<1> - | inst a1 of A - | inst a2 of A_ - | out <= and(a1.x, a2.y) - | module A : @[yy 2:2] - | output x: UInt<1> @[yy 2:2] - | inst b of B - | x <= b.u - | module A_ : @[xx 1:1] - | output y: UInt<1> @[xx 1:1] - | inst c of C - | y <= c.u - | extmodule B : @[aa 3:3] - | output u : UInt<1> @[aa 4:4] - | defname = ${defnames._1} - | parameter N = ${params._1} - | extmodule C : @[bb 5:5] - | output u : UInt<1> @[bb 6:6] - | defname = ${defnames._2} - | parameter N = ${params._2} + def mkfir(defnames: (String, String), params: (String, String)) = + s"""circuit Top : + | module Top : + | output out: UInt<1> + | inst a1 of A + | inst a2 of A_ + | out <= and(a1.x, a2.y) + | module A : @[yy 2:2] + | output x: UInt<1> @[yy 2:2] + | inst b of B + | x <= b.u + | module A_ : @[xx 1:1] + | output y: UInt<1> @[xx 1:1] + | inst c of C + | y <= c.u + | extmodule B : @[aa 3:3] + | output u : UInt<1> @[aa 4:4] + | defname = ${defnames._1} + | parameter N = ${params._1} + | extmodule C : @[bb 5:5] + | output u : UInt<1> @[bb 6:6] + | defname = ${defnames._2} + | parameter N = ${params._2} """.stripMargin - val diff_defname = mkfir(("BB", "CC"), ("0", "0")) - execute(diff_defname, diff_defname, Seq.empty) - val diff_params = mkfir(("BB", "BB"), ("0", "1")) - execute(diff_params, diff_params, Seq.empty) + val diff_defname = mkfir(("BB", "CC"), ("0", "0")) + execute(diff_defname, diff_defname, Seq.empty) + val diff_params = mkfir(("BB", "BB"), ("0", "1")) + execute(diff_params, diff_params, Seq.empty) } "Modules with aggregate ports that are bulk connected" should "NOT dedup if their port names differ" in { @@ -426,12 +428,16 @@ class DedupModuleTests extends HighTransformSpec { | wire b: UInt<1> | x <= b """.stripMargin - val cs = execute(input, check, Seq( - dontTouch(ReferenceTarget("Top", "A", Nil, "b", Nil)), - dontTouch(ReferenceTarget("Top", "A_", Nil, "b", Nil)) - )) - cs.annotations.toSeq should contain (dontTouch(ModuleTarget("Top", "Top").instOf("a1", "A").ref("b"))) - cs.annotations.toSeq should contain (dontTouch(ModuleTarget("Top", "Top").instOf("a2", "A").ref("b"))) + val cs = execute( + input, + check, + Seq( + dontTouch(ReferenceTarget("Top", "A", Nil, "b", Nil)), + dontTouch(ReferenceTarget("Top", "A_", Nil, "b", Nil)) + ) + ) + cs.annotations.toSeq should contain(dontTouch(ModuleTarget("Top", "Top").instOf("a1", "A").ref("b"))) + cs.annotations.toSeq should contain(dontTouch(ModuleTarget("Top", "Top").instOf("a2", "A").ref("b"))) cs.annotations.toSeq should not contain dontTouch(ReferenceTarget("Top", "A_", Nil, "b", Nil)) } "The module A and A_" should "be deduped with same annotation targets when there are a lot" in { @@ -508,12 +514,24 @@ class DedupModuleTests extends HighTransformSpec { val annoAB = MultiTargetDummyAnnotation(Seq(A, B), 0) val annoA_B_ = MultiTargetDummyAnnotation(Seq(A_, B_), 1) val cs = execute(input, check, Seq(annoAB, annoA_B_)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top_a1, Top_a1_b - ), 0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top_a2, Top_a2_b - ), 1)) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top_a1, + Top_a1_b + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top_a2, + Top_a2_b + ), + 1 + ) + ) } "The module A and A_" should "be deduped with same annotations with same multi-targets, that share roots" in { val input = @@ -555,15 +573,25 @@ class DedupModuleTests extends HighTransformSpec { val annoA = MultiTargetDummyAnnotation(Seq(A, A.instOf("b", "B")), 0) val annoA_ = MultiTargetDummyAnnotation(Seq(A_, A_.instOf("b", "B_")), 1) val cs = execute(input, check, Seq(annoA, annoA_)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a1", "A"), - Top.module("Top").instOf("a1", "A").instOf("b", "B") - ),0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a2", "A"), - Top.module("Top").instOf("a2", "A").instOf("b", "B") - ),1)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a1", "A"), + Top.module("Top").instOf("a1", "A").instOf("b", "B") + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a2", "A"), + Top.module("Top").instOf("a2", "A").instOf("b", "B") + ), + 1 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "The deduping module A and A_" should "rename internal signals that have different names" in { val input = @@ -600,12 +628,12 @@ class DedupModuleTests extends HighTransformSpec { val Top = CircuitTarget("Top") val A = Top.module("A") val A_ = Top.module("A_") - val annoA = SingleTargetDummyAnnotation(A.ref("a")) + val annoA = SingleTargetDummyAnnotation(A.ref("a")) val annoA_ = SingleTargetDummyAnnotation(A_.ref("b")) val cs = execute(input, check, Seq(annoA, annoA_)) - cs.annotations.toSeq should contain (annoA) + cs.annotations.toSeq should contain(annoA) cs.annotations.toSeq should not contain (SingleTargetDummyAnnotation(A.ref("b"))) - cs.deletedAnnotations.isEmpty should be (true) + cs.deletedAnnotations.isEmpty should be(true) } "main" should "not be deduped even if it's the last module" in { val input = @@ -691,14 +719,25 @@ class DedupModuleTests extends HighTransformSpec { val anno1 = MultiTargetDummyAnnotation(Seq(inst1, ref1), 0) val anno2 = MultiTargetDummyAnnotation(Seq(inst2, ref2), 1) val cs = execute(input, check, Seq(anno1, anno2)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - inst1, ref1 - ),0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a_", "A").instOf("b", "B"), - Top.module("Top").instOf("a_", "A").instOf("b", "B").ref("foo") - ),1)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + inst1, + ref1 + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a_", "A").instOf("b", "B"), + Top.module("Top").instOf("a_", "A").instOf("b", "B").ref("foo") + ), + 1 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "The deduping module A and A_" should "rename nested instances that have different names" in { @@ -746,14 +785,25 @@ class DedupModuleTests extends HighTransformSpec { val anno1 = MultiTargetDummyAnnotation(Seq(inst1, ref1), 0) val anno2 = MultiTargetDummyAnnotation(Seq(inst2, ref2), 1) val cs = execute(input, check, Seq(anno1, anno2)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - inst1, ref1 - ),0)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D"), - Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D").ref("foo") - ),1)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + inst1, + ref1 + ), + 0 + ) + ) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D"), + Top.module("Top").instOf("a_", "A").instOf("b", "B").instOf("c", "C").instOf("d", "D").ref("foo") + ), + 1 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "Deduping modules with multiple instances" should "corectly rename instances" in { @@ -801,50 +851,55 @@ class DedupModuleTests extends HighTransformSpec { val cInstances = bInstances.map(_.instOf("c", "C")) val annos = MultiTargetDummyAnnotation(bInstances ++ cInstances, 0) val cs = execute(input, check, Seq(annos)) - cs.annotations.toSeq should contain (MultiTargetDummyAnnotation(Seq( - Top.instOf("b", "B"), - Top.instOf("b_", "B"), - Top.instOf("a1", "A").instOf("b_", "B"), - Top.instOf("a2", "A").instOf("b_", "B"), - Top.instOf("a1", "A").instOf("b", "B"), - Top.instOf("a2", "A").instOf("b", "B"), - Top.instOf("b", "B").instOf("c", "C"), - Top.instOf("b_", "B").instOf("c", "C"), - Top.instOf("a1", "A").instOf("b_", "B").instOf("c", "C"), - Top.instOf("a2", "A").instOf("b_", "B").instOf("c", "C"), - Top.instOf("a1", "A").instOf("b", "B").instOf("c", "C"), - Top.instOf("a2", "A").instOf("b", "B").instOf("c", "C") - ),0)) - cs.deletedAnnotations.isEmpty should be (true) + cs.annotations.toSeq should contain( + MultiTargetDummyAnnotation( + Seq( + Top.instOf("b", "B"), + Top.instOf("b_", "B"), + Top.instOf("a1", "A").instOf("b_", "B"), + Top.instOf("a2", "A").instOf("b_", "B"), + Top.instOf("a1", "A").instOf("b", "B"), + Top.instOf("a2", "A").instOf("b", "B"), + Top.instOf("b", "B").instOf("c", "C"), + Top.instOf("b_", "B").instOf("c", "C"), + Top.instOf("a1", "A").instOf("b_", "B").instOf("c", "C"), + Top.instOf("a2", "A").instOf("b_", "B").instOf("c", "C"), + Top.instOf("a1", "A").instOf("b", "B").instOf("c", "C"), + Top.instOf("a2", "A").instOf("b", "B").instOf("c", "C") + ), + 0 + ) + ) + cs.deletedAnnotations.isEmpty should be(true) } "dedup" should "properly rename target components after retyping" in { val input = """ - |circuit top: - | module top: - | input ia: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | input ib: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | output oa: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | output ob: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | inst a of a - | a.i.z.y.x <= ia.z.y.x - | a.i.a <= ia.a - | oa.z.y.x <= a.o.z.y.x - | oa.a <= a.o.a - | inst b of b - | b.q.a.b.c <= ib.a.b.c - | b.q.z <= ib.z - | ob.a.b.c <= b.r.a.b.c - | ob.z <= b.r.z - | module a: - | input i: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | output o: {z: {y: {x: UInt<1>}}, a: UInt<1>} - | o <= i - | module b: - | input q: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | output r: {a: {b: {c: UInt<1>}}, z: UInt<1>} - | r <= q - |""".stripMargin + |circuit top: + | module top: + | input ia: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | input ib: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | output oa: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | output ob: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | inst a of a + | a.i.z.y.x <= ia.z.y.x + | a.i.a <= ia.a + | oa.z.y.x <= a.o.z.y.x + | oa.a <= a.o.a + | inst b of b + | b.q.a.b.c <= ib.a.b.c + | b.q.z <= ib.z + | ob.a.b.c <= b.r.a.b.c + | ob.z <= b.r.z + | module a: + | input i: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | output o: {z: {y: {x: UInt<1>}}, a: UInt<1>} + | o <= i + | module b: + | input q: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | output r: {a: {b: {c: UInt<1>}}, z: UInt<1>} + | r <= q + |""".stripMargin case class DummyRTAnnotation(target: ReferenceTarget) extends SingleTargetAnnotation[ReferenceTarget] { def duplicate(n: ReferenceTarget) = DummyRTAnnotation(n) @@ -853,7 +908,6 @@ class DedupModuleTests extends HighTransformSpec { val annA = DummyRTAnnotation(ReferenceTarget("top", "a", Nil, "i", Seq(TargetToken.Field("a")))) val annB = DummyRTAnnotation(ReferenceTarget("top", "b", Nil, "q", Seq(TargetToken.Field("a")))) - val cs = CircuitState(Parser.parseString(input, Parser.IgnoreInfo), Seq(annA, annB)) val deduper = new stage.transforms.Compiler(stage.Forms.Deduped, Nil) @@ -871,7 +925,7 @@ class DedupModuleTests extends HighTransformSpec { val bPath = Seq((TargetToken.Instance("b"), TargetToken.OfModule("a"))) val expectedAnnA = DummyRTAnnotation(ReferenceTarget("top", "top", aPath, "i", Seq(TargetToken.Field("a")))) val expectedAnnB = DummyRTAnnotation(ReferenceTarget("top", "top", aPath, "i", Seq(TargetToken.Field("a")))) - csDeduped.annotations.toSeq should contain (expectedAnnA) - csDeduped.annotations.toSeq should contain (expectedAnnB) + csDeduped.annotations.toSeq should contain(expectedAnnA) + csDeduped.annotations.toSeq should contain(expectedAnnB) } } diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala index fdb129a1..65544764 100644 --- a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala +++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala @@ -17,75 +17,75 @@ class GroupComponentsSpec extends MiddleTransformSpec { def topComp(name: String): ComponentName = ComponentName(name, ModuleName(top, CircuitName(top))) "The register r" should "be grouped" in { val input = - s"""circuit $top : - | module $top : - | input clk: Clock - | input data: UInt<16> - | output out: UInt<16> - | reg r: UInt<16>, clk - | r <= data - | out <= r + s"""circuit $top : + | module $top : + | input clk: Clock + | input data: UInt<16> + | output out: UInt<16> + | reg r: UInt<16>, clk + | r <= data + | out <= r """.stripMargin val groups = Seq( GroupAnnotation(Seq(topComp("r")), "MyReg", "rInst", Some("_OUT"), Some("_IN")) ) val check = - s"""circuit Top : - | module $top : - | input clk: Clock - | input data: UInt<16> - | output out: UInt<16> - | inst rInst of MyReg - | rInst.clk_IN <= clk - | out <= rInst.r_OUT - | rInst.data_IN <= data - | module MyReg : - | input clk_IN: Clock - | output r_OUT: UInt<16> - | input data_IN: UInt<16> - | reg r: UInt<16>, clk_IN - | r_OUT <= r - | r <= data_IN + s"""circuit Top : + | module $top : + | input clk: Clock + | input data: UInt<16> + | output out: UInt<16> + | inst rInst of MyReg + | rInst.clk_IN <= clk + | out <= rInst.r_OUT + | rInst.data_IN <= data + | module MyReg : + | input clk_IN: Clock + | output r_OUT: UInt<16> + | input data_IN: UInt<16> + | reg r: UInt<16>, clk_IN + | r_OUT <= r + | r <= data_IN """.stripMargin execute(input, check, groups) } "Grouping" should "work even when there are unused nodes" in { val input = - s"""circuit $top : - | module $top : - | input in: UInt<16> - | output out: UInt<16> - | node n = UInt<16>("h0") - | wire w : UInt<16> - | wire a : UInt<16> - | wire b : UInt<16> - | a <= UInt<16>("h0") - | b <= a - | w <= in - | out <= w + s"""circuit $top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | node n = UInt<16>("h0") + | wire w : UInt<16> + | wire a : UInt<16> + | wire b : UInt<16> + | a <= UInt<16>("h0") + | b <= a + | w <= in + | out <= w """.stripMargin val groups = Seq( GroupAnnotation(Seq(topComp("w")), "Child", "inst", Some("_OUT"), Some("_IN")) ) val check = - s"""circuit Top : - | module $top : - | input in: UInt<16> - | output out: UInt<16> - | inst inst of Child - | node n = UInt<16>("h0") - | wire a : UInt<16> - | wire b : UInt<16> - | out <= inst.w_OUT - | inst.in_IN <= in - | a <= UInt<16>("h0") - | b <= a - | module Child : - | output w_OUT : UInt<16> - | input in_IN : UInt<16> - | wire w : UInt<16> - | w_OUT <= w - | w <= in_IN + s"""circuit Top : + | module $top : + | input in: UInt<16> + | output out: UInt<16> + | inst inst of Child + | node n = UInt<16>("h0") + | wire a : UInt<16> + | wire b : UInt<16> + | out <= inst.w_OUT + | inst.in_IN <= in + | a <= UInt<16>("h0") + | b <= a + | module Child : + | output w_OUT : UInt<16> + | input in_IN : UInt<16> + | wire w : UInt<16> + | w_OUT <= w + | w <= in_IN """.stripMargin execute(input, check, groups) } @@ -116,8 +116,8 @@ class GroupComponentsSpec extends MiddleTransformSpec { | out <= UInt(2) """.stripMargin val annotations = Seq( - GroupAnnotation(Seq(topComp("c1a"), topComp("c2a")/*, topComp("asum")*/), "A", "cA", Some("_OUT"), Some("_IN")), - GroupAnnotation(Seq(topComp("c1b"), topComp("c2b")/*, topComp("bsum")*/), "B", "cB", Some("_OUT"), Some("_IN")), + GroupAnnotation(Seq(topComp("c1a"), topComp("c2a") /*, topComp("asum")*/ ), "A", "cA", Some("_OUT"), Some("_IN")), + GroupAnnotation(Seq(topComp("c1b"), topComp("c2b") /*, topComp("bsum")*/ ), "B", "cB", Some("_OUT"), Some("_IN")), NoCircuitDedupAnnotation ) val check = @@ -380,7 +380,7 @@ class GroupComponentsIntegrationSpec extends FirrtlFlatSpec { def topComp(name: String): ComponentName = ComponentName(name, ModuleName("Top", CircuitName("Top"))) "Grouping" should "properly set kinds" in { val input = - """circuit Top : + """circuit Top : | module Top : | input clk: Clock | input data: UInt<16> @@ -397,13 +397,13 @@ class GroupComponentsIntegrationSpec extends FirrtlFlatSpec { Seq(new GroupComponents) ) result should containTree { - case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "data_IN", _,_), WRef("data",_,_,_)) => true + case Connect(_, WSubField(WRef("inst", _, InstanceKind, _), "data_IN", _, _), WRef("data", _, _, _)) => true } result should containTree { - case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "clk_IN", _,_), WRef("clk",_,_,_)) => true + case Connect(_, WSubField(WRef("inst", _, InstanceKind, _), "clk_IN", _, _), WRef("clk", _, _, _)) => true } result should containTree { - case Connect(_, WRef("out",_,_,_), WSubField(WRef("inst",_, InstanceKind,_), "r_OUT", _,_)) => true + case Connect(_, WRef("out", _, _, _), WSubField(WRef("inst", _, InstanceKind, _), "r_OUT", _, _)) => true } } } diff --git a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala index c5847364..0043cb1f 100644 --- a/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala +++ b/src/test/scala/firrtlTests/transforms/InferWidthsWithAnnosSpec.scala @@ -5,36 +5,25 @@ package firrtlTests.transforms import firrtl.testutils.FirrtlFlatSpec import firrtl._ import firrtl.passes._ -import firrtl.passes.wiring.{WiringTransform, SourceAnnotation, SinkAnnotation} +import firrtl.passes.wiring.{SinkAnnotation, SourceAnnotation, WiringTransform} import firrtl.annotations._ import firrtl.annotations.TargetToken.{Field, Index} - class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { - private def executeTest(input: String, - check: String, - transforms: Seq[Transform], - annotations: Seq[Annotation]) = { + private def executeTest(input: String, check: String, transforms: Seq[Transform], annotations: Seq[Annotation]) = { val start = CircuitState(parse(input), ChirrtlForm, annotations) - val end = transforms.foldLeft(start) { - (c: CircuitState, t: Transform) => t.runTransform(c) + val end = transforms.foldLeft(start) { (c: CircuitState, t: Transform) => + t.runTransform(c) } - val resLines = end.circuit.serialize.split("\n") map normalized - val checkLines = parse(check).serialize.split("\n") map normalized + val resLines = end.circuit.serialize.split("\n").map(normalized) + val checkLines = parse(check).serialize.split("\n").map(normalized) - resLines should be (checkLines) + resLines should be(checkLines) } "CheckWidths on wires with unknown widths" should "result in an error" in { - val transforms = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val transforms = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths, CheckWidths) val input = """circuit Top : @@ -55,19 +44,15 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { } "InferWidthsWithAnnos" should "infer widths using WidthGeqConstraintAnnotation" in { - val transforms = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val transforms = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths, CheckWidths) - val annos = Seq(WidthGeqConstraintAnnotation( - ReferenceTarget("Top", "A", Nil, "y", Nil), - ReferenceTarget("Top", "B", Nil, "x", Nil))) + val annos = Seq( + WidthGeqConstraintAnnotation( + ReferenceTarget("Top", "A", Nil, "y", Nil), + ReferenceTarget("Top", "B", Nil, "x", Nil) + ) + ) val input = """circuit Top : @@ -98,15 +83,8 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { } "InferWidthsWithAnnos" should "work with token paths" in { - val transforms = Seq( - ToWorkingIR, - CheckHighForm, - ResolveKinds, - InferTypes, - CheckTypes, - ResolveFlows, - new InferWidths, - CheckWidths) + val transforms = + Seq(ToWorkingIR, CheckHighForm, ResolveKinds, InferTypes, CheckTypes, ResolveFlows, new InferWidths, CheckWidths) val tokenLists = Seq( Seq(Field("x")), @@ -117,7 +95,8 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { val annos = tokenLists.map { tokens => WidthGeqConstraintAnnotation( ReferenceTarget("Top", "A", Nil, "bundle", tokens), - ReferenceTarget("Top", "B", Nil, "bundle", tokens)) + ReferenceTarget("Top", "B", Nil, "bundle", tokens) + ) } val input = @@ -174,7 +153,8 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { val wgeqAnnos = tokenLists.map { tokens => WidthGeqConstraintAnnotation( ReferenceTarget("Top", "A", Nil, "bundle", tokens), - ReferenceTarget("Top", "B", Nil, "bundle", tokens)) + ReferenceTarget("Top", "B", Nil, "bundle", tokens) + ) } val failAnnos = Seq(source, sink) @@ -209,8 +189,7 @@ class InferWidthsWithAnnosSpec extends FirrtlFlatSpec { | module A : | output bundle_0 : {x : UInt<1>, y: {yy : UInt<3>}[2] } | wire bundle : {x : UInt<1>, y: {yy : UInt<3>}[2] } - | bundle_0 <= bundle""" - .stripMargin + | bundle_0 <= bundle""".stripMargin // should fail without extra constraint annos due to UninferredWidths val exceptions = intercept[PassExceptions] { diff --git a/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala b/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala index f57586f6..6ee0f5a0 100644 --- a/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala +++ b/src/test/scala/firrtlTests/transforms/LegalizeClocks.scala @@ -10,7 +10,7 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { def compile(input: String): CircuitState = (new MinimumVerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), Nil) - behavior of "LegalizeClocksTransform" + behavior.of("LegalizeClocksTransform") it should "not emit @(posedge 1'h0) for stop" in { val input = @@ -19,8 +19,8 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | stop(asClock(UInt(1)), UInt(1), 1) |""".stripMargin val result = compile(input) - result should containLine (s"always @(posedge _GEN_0) begin") - result.getEmittedCircuit.value shouldNot include ("always @(posedge 1") + result should containLine(s"always @(posedge _GEN_0) begin") + result.getEmittedCircuit.value shouldNot include("always @(posedge 1") } it should "not emit @(posedge 1'h0) for printf" in { @@ -30,8 +30,8 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | printf(asClock(UInt(1)), UInt(1), "hi") |""".stripMargin val result = compile(input) - result should containLine (s"always @(posedge _GEN_0) begin") - result.getEmittedCircuit.value shouldNot include ("always @(posedge 1") + result should containLine(s"always @(posedge _GEN_0) begin") + result.getEmittedCircuit.value shouldNot include("always @(posedge 1") } it should "not emit @(posedge 1'h0) for reg" in { @@ -45,8 +45,8 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | out <= r |""".stripMargin val result = compile(input) - result should containLine (s"always @(posedge _GEN_0) begin") - result.getEmittedCircuit.value shouldNot include ("always @(posedge 1") + result should containLine(s"always @(posedge _GEN_0) begin") + result.getEmittedCircuit.value shouldNot include("always @(posedge 1") } it should "deduplicate injected nodes for literal clocks" in { @@ -57,11 +57,11 @@ class LegalizeClocksTransformSpec extends FirrtlFlatSpec { | stop(asClock(UInt(1)), UInt(1), 1) |""".stripMargin val result = compile(input) - result should containLine (s"wire _GEN_0 = 1'h1;") + result should containLine(s"wire _GEN_0 = 1'h1;") // Check that there's only 1 _GEN_0 instantiation val verilog = result.getEmittedCircuit.value val matches = "wire\\s+_GEN_0\\s+=\\s+1'h1".r.findAllIn(verilog) - matches.size should be (1) + matches.size should be(1) } } diff --git a/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala b/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala index 5368c54c..3df47f1d 100644 --- a/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala +++ b/src/test/scala/firrtlTests/transforms/LegalizeReductions.scala @@ -12,12 +12,11 @@ import java.io.File object LegalizeAndReductionsTransformSpec extends FirrtlRunners { private case class Test( - name: String, - op: String, - input: BigInt, - expected: BigInt, - forceWidth: Option[Int] = None - ) { + name: String, + op: String, + input: BigInt, + expected: BigInt, + forceWidth: Option[Int] = None) { def toFirrtl: String = { val width = forceWidth.getOrElse(input.bitLength) val inputLit = s"""UInt("h${input.toString(16)}")""" @@ -62,9 +61,9 @@ circuit $name : // Run FIRRTL val annos = FirrtlSourceAnnotation(test.toFirrtl) :: - TargetDirAnnotation(testDir.toString) :: - CompilerAnnotation(new MinimumVerilogCompiler) :: - Nil + TargetDirAnnotation(testDir.toString) :: + CompilerAnnotation(new MinimumVerilogCompiler) :: + Nil val resultAnnos = (new FirrtlStage).transform(annos) val outputFilename = resultAnnos.collectFirst { case OutputFileAnnotation(f) => f } outputFilename.toRight(s"Output file not found!") @@ -73,8 +72,8 @@ circuit $name : copyResourceToFile(cppHarnessResourceName, harness) // Run Verilator verilogToCpp(prefix, testDir, Nil, harness, suppressVcd = true) #&& - cppToExe(prefix, testDir) ! - loggingProcessLogger + cppToExe(prefix, testDir) ! + loggingProcessLogger // Run binary if (!executeExpectingSuccess(prefix, testDir)) { throw new Exception("Test failed!") with scala.util.control.NoStackTrace @@ -82,24 +81,23 @@ circuit $name : } } - class LegalizeAndReductionsTransformSpec extends AnyFlatSpec { import LegalizeAndReductionsTransformSpec._ - behavior of "LegalizeAndReductionsTransform" + behavior.of("LegalizeAndReductionsTransform") private val tests = // name primop input expected width - Test("andreduce_ones", "andr", BigInt("1"*68, 2), 1) :: - Test("andreduce_zero", "andr", 0, 0, Some(68)) :: - Test("orreduce_ones", "orr", BigInt("1"*68, 2), 1) :: - Test("orreduce_high_one", "orr", BigInt("1" + "0"*67, 2), 1) :: - Test("orreduce_zero", "orr", 0, 0, Some(68)) :: - Test("xorreduce_high_one", "xorr", BigInt("1" + "0"*67, 2), 1) :: - Test("xorreduce_high_low_one", "xorr", BigInt("1" + "0"*66 + "1", 2), 0) :: - Test("xorreduce_zero", "xorr", 0, 0, Some(68)) :: - Nil + Test("andreduce_ones", "andr", BigInt("1" * 68, 2), 1) :: + Test("andreduce_zero", "andr", 0, 0, Some(68)) :: + Test("orreduce_ones", "orr", BigInt("1" * 68, 2), 1) :: + Test("orreduce_high_one", "orr", BigInt("1" + "0" * 67, 2), 1) :: + Test("orreduce_zero", "orr", 0, 0, Some(68)) :: + Test("xorreduce_high_one", "xorr", BigInt("1" + "0" * 67, 2), 1) :: + Test("xorreduce_high_low_one", "xorr", BigInt("1" + "0" * 66 + "1", 2), 0) :: + Test("xorreduce_zero", "xorr", 0, 0, Some(68)) :: + Nil for (test <- tests) { it should s"support ${test.name}" in { diff --git a/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala b/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala index ec1b505b..a616b4bd 100644 --- a/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala +++ b/src/test/scala/firrtlTests/transforms/ManipulateNamesSpec.scala @@ -2,22 +2,15 @@ package firrtlTests.transforms -import firrtl.{ - ir, - CircuitState, - FirrtlUserException, - Namespace, - Parser, - RenameMap -} +import firrtl.{ir, CircuitState, FirrtlUserException, Namespace, Parser, RenameMap} import firrtl.annotations.CircuitTarget import firrtl.options.Dependency import firrtl.testutils.FirrtlCheckers._ import firrtl.transforms.{ ManipulateNames, - ManipulateNamesBlocklistAnnotation, ManipulateNamesAllowlistAnnotation, - ManipulateNamesAllowlistResultAnnotation + ManipulateNamesAllowlistResultAnnotation, + ManipulateNamesBlocklistAnnotation } import org.scalatest.flatspec.AnyFlatSpec @@ -57,24 +50,24 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val tm = new firrtl.stage.transforms.Compiler(Seq(Dependency[AddPrefix])) } - behavior of "ManipulateNames" + behavior.of("ManipulateNames") it should "rename everything by default" in new CircuitFixture { val state = CircuitState(Parser.parse(input), Seq.empty) val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "prefix_Foo") => true }, - { case ir.Module(_, "prefix_Foo", _, _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true} + { case ir.Module(_, "prefix_Foo", _, _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "do nothing if the circuit is blocklisted" in new CircuitFixture { val annotations = Seq(ManipulateNamesBlocklistAnnotation(Seq(Seq(`~Foo`)), Dependency[AddPrefix])) val state = CircuitState(Parser.parse(input), annotations) val statex = tm.execute(state) - state.circuit.serialize should be (statex.circuit.serialize) + state.circuit.serialize should be(statex.circuit.serialize) } it should "not rename the circuit if the top module is blocklisted" in new CircuitFixture { @@ -82,31 +75,31 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val state = CircuitState(Parser.parse(input), annotations) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "Foo") => true }, - { case ir.Module(_, "Foo", _, _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true} + { case ir.Module(_, "Foo", _, _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true } ) val statex = tm.execute(state) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "not rename instances if blocklisted" in new CircuitFixture { val annotations = Seq(ManipulateNamesBlocklistAnnotation(Seq(Seq(`~Foo|Foo/bar:Bar`)), Dependency[AddPrefix])) val state = CircuitState(Parser.parse(input), annotations) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( - { case ir.DefInstance(_, "bar", "prefix_Bar", _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true} + { case ir.DefInstance(_, "bar", "prefix_Bar", _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true } ) val statex = tm.execute(state) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } - it should "do nothing if the circuit is not allowlisted" in new CircuitFixture { + it should "do nothing if the circuit is not allowlisted" in new CircuitFixture { val annotations = Seq( ManipulateNamesAllowlistAnnotation(Seq(Seq(`~Foo|Foo`)), Dependency[AddPrefix]) ) val state = CircuitState(Parser.parse(input), annotations) val statex = tm.execute(state) - state.circuit.serialize should be (statex.circuit.serialize) + state.circuit.serialize should be(statex.circuit.serialize) } it should "rename only the circuit if allowlisted" in new CircuitFixture { @@ -118,13 +111,13 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "prefix_Foo") => true }, - { case ir.Module(_, "prefix_Foo", _, _) => true}, - { case ir.DefInstance(_, "bar", "Bar", _) => true}, - { case ir.DefInstance(_, "bar2", "Bar", _) => true}, - { case ir.Module(_, "Bar", _, _) => true}, - { case ir.DefNode(_, "a", _) => true} + { case ir.Module(_, "prefix_Foo", _, _) => true }, + { case ir.DefInstance(_, "bar", "Bar", _) => true }, + { case ir.DefInstance(_, "bar2", "Bar", _) => true }, + { case ir.Module(_, "Bar", _, _) => true }, + { case ir.DefNode(_, "a", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "rename an instance via allowlisting" in new CircuitFixture { @@ -136,13 +129,13 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "Foo") => true }, - { case ir.Module(_, "Foo", _, _) => true}, - { case ir.DefInstance(_, "prefix_bar", "Bar", _) => true}, - { case ir.DefInstance(_, "bar2", "Bar", _) => true}, - { case ir.Module(_, "Bar", _, _) => true}, - { case ir.DefNode(_, "a", _) => true} + { case ir.Module(_, "Foo", _, _) => true }, + { case ir.DefInstance(_, "prefix_bar", "Bar", _) => true }, + { case ir.DefInstance(_, "bar2", "Bar", _) => true }, + { case ir.Module(_, "Bar", _, _) => true }, + { case ir.DefNode(_, "a", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "rename a node via allowlisting" in new CircuitFixture { @@ -154,13 +147,13 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "Foo") => true }, - { case ir.Module(_, "Foo", _, _) => true}, - { case ir.DefInstance(_, "bar", "Bar", _) => true}, - { case ir.DefInstance(_, "bar2", "Bar", _) => true}, - { case ir.Module(_, "Bar", _, _) => true}, - { case ir.DefNode(_, "prefix_a", _) => true} + { case ir.Module(_, "Foo", _, _) => true }, + { case ir.DefInstance(_, "bar", "Bar", _) => true }, + { case ir.DefInstance(_, "bar2", "Bar", _) => true }, + { case ir.Module(_, "Bar", _, _) => true }, + { case ir.DefNode(_, "prefix_a", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } it should "throw user errors on circuits that haven't been run through LowerTypes" in { @@ -171,9 +164,9 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { | node baz = bar.a |""".stripMargin val state = CircuitState(Parser.parse(input), Seq.empty) - intercept [FirrtlUserException] { + intercept[FirrtlUserException] { (new AddPrefix).transform(state) - }.getMessage should include ("LowerTypes") + }.getMessage should include("LowerTypes") } it should "only consume annotations whose type parameter matches" in new CircuitFixture { @@ -187,25 +180,25 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val statex = tm.execute(state) val expected: Seq[PartialFunction[Any, Boolean]] = Seq( { case ir.Circuit(_, _, "prefix_Foo") => true }, - { case ir.Module(_, "prefix_Foo", _, _) => true}, - { case ir.DefInstance(_, "prefix_bar", "prefix_Bar", _) => true}, - { case ir.DefInstance(_, "prefix_bar2", "prefix_Bar", _) => true}, - { case ir.Module(_, "prefix_Bar", _, _) => true}, - { case ir.DefNode(_, "a_suffix", _) => true} + { case ir.Module(_, "prefix_Foo", _, _) => true }, + { case ir.DefInstance(_, "prefix_bar", "prefix_Bar", _) => true }, + { case ir.DefInstance(_, "prefix_bar2", "prefix_Bar", _) => true }, + { case ir.Module(_, "prefix_Bar", _, _) => true }, + { case ir.DefNode(_, "a_suffix", _) => true } ) - expected.foreach(statex should containTree (_)) + expected.foreach(statex should containTree(_)) } - behavior of "ManipulateNamesBlocklistAnnotation" + behavior.of("ManipulateNamesBlocklistAnnotation") it should "throw an exception if a non-local target is skipped" in new CircuitFixture { val barA = CircuitTarget("Foo").module("Foo").instOf("bar", "Bar").ref("a") - assertThrows[java.lang.IllegalArgumentException]{ + assertThrows[java.lang.IllegalArgumentException] { Seq(ManipulateNamesBlocklistAnnotation(Seq(Seq(barA)), Dependency[AddPrefix])) } } - behavior of "ManipulateNamesAllowlistResultAnnotation" + behavior.of("ManipulateNamesAllowlistResultAnnotation") it should "delete itself if the new target is deleted" in { val `~Foo|Bar` = CircuitTarget("Foo").module("Bar") @@ -220,7 +213,7 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { val r = RenameMap() r.delete(`~Foo|prefix_Bar`) - a.update(r) should be (empty) + a.update(r) should be(empty) } it should "drop a deleted target" in { @@ -242,12 +235,12 @@ class ManipulateNamesSpec extends AnyFlatSpec with Matchers { case b: ManipulateNamesAllowlistResultAnnotation[_] => b } - ax should not be length (1) + ax should not be length(1) val keys = ax.head.toRenameMap.getUnderlying.keys keys should not contain (`~Foo|Bar`) - keys should contain (`~Foo|Baz`) + keys should contain(`~Foo|Baz`) } } diff --git a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala index 299a4f48..d603db69 100644 --- a/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala +++ b/src/test/scala/firrtlTests/transforms/RemoveResetSpec.scala @@ -8,7 +8,7 @@ import firrtl.testutils.FirrtlFlatSpec import firrtl.testutils.FirrtlCheckers._ import firrtl.{CircuitState, WRef} -import firrtl.ir.{Connect, Mux, DefRegister} +import firrtl.ir.{Connect, DefRegister, Mux} import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlSourceAnnotation, FirrtlStage} class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { @@ -17,12 +17,12 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { When("the circuit is compiled to low FIRRTL") (new FirrtlStage) .execute(Array("-X", "low"), Seq(FirrtlSourceAnnotation(string))) - .collectFirst{ case FirrtlCircuitAnnotation(a) => a } + .collectFirst { case FirrtlCircuitAnnotation(a) => a } .map(a => firrtl.CircuitState(a, firrtl.UnknownForm)) .get } - behavior of "RemoveReset" + behavior.of("RemoveReset") it should "not generate a reset mux for an invalid init" in { Given("a 1-bit register 'foo' initialized to invalid, 1-bit wire 'bar'") @@ -44,7 +44,7 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("'foo' is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo", _, _, _), Mux(_, _, _, _)) => true } } it should "generate a reset mux for only the portion of an invalid aggregate that is reset" in { @@ -71,11 +71,11 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("foo.a[0] is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo_a_0",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo_a_0", _, _, _), Mux(_, _, _, _)) => true } And("foo.a[1] is connected to a reset mux") - outputState should containTree { case Connect(_, WRef("foo_a_1",_,_,_), Mux(_,_,_,_)) => true } + outputState should containTree { case Connect(_, WRef("foo_a_1", _, _, _), Mux(_, _, _, _)) => true } And("foo.b is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo_b", _, _, _), Mux(_, _, _, _)) => true } } it should "propagate invalidations across connects" in { @@ -107,9 +107,9 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("'foo.a' is connected to a reset mux") - outputState should containTree { case Connect(_, WRef("foo_a",_,_,_), Mux(_,_,_,_)) => true } + outputState should containTree { case Connect(_, WRef("foo_a", _, _, _), Mux(_, _, _, _)) => true } And("'foo.b' is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo_b",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo_b", _, _, _), Mux(_, _, _, _)) => true } } it should "canvert a reset wired to UInt<0> to a canonical non-reset" in { @@ -128,8 +128,8 @@ class RemoveResetSpec extends FirrtlFlatSpec with GivenWhenThen { val outputState = toLowFirrtl(input) Then("foo has a canonical non-reset declaration after RemoveReset") - outputState should containTree { case DefRegister(_, "foo", _,_, firrtl.Utils.zero, WRef("foo", _,_,_)) => true } + outputState should containTree { case DefRegister(_, "foo", _, _, firrtl.Utils.zero, WRef("foo", _, _, _)) => true } And("foo is NOT connected to a reset mux") - outputState shouldNot containTree { case Connect(_, WRef("foo",_,_,_), Mux(_,_,_,_)) => true } + outputState shouldNot containTree { case Connect(_, WRef("foo", _, _, _), Mux(_, _, _, _)) => true } } } diff --git a/src/test/scala/firrtlTests/transforms/TopWiringTest.scala b/src/test/scala/firrtlTests/transforms/TopWiringTest.scala index 0ac12ef8..97fafe41 100644 --- a/src/test/scala/firrtlTests/transforms/TopWiringTest.scala +++ b/src/test/scala/firrtlTests/transforms/TopWiringTest.scala @@ -6,724 +6,718 @@ package transforms import java.io._ import firrtl._ -import firrtl.ir.{Type, GroundType, IntWidth} +import firrtl.ir.{GroundType, IntWidth, Type} import firrtl.Parser -import firrtl.annotations.{ - CircuitName, - ModuleName, - ComponentName, - Target -} +import firrtl.annotations.{CircuitName, ComponentName, ModuleName, Target} import firrtl.transforms.TopWiring._ import firrtl.testutils._ - trait TopWiringTestsCommon extends FirrtlRunners { - val testDir = createTestDirectory("TopWiringTests") - val testDirName = testDir.getPath - def transform = new TopWiringTransform + val testDir = createTestDirectory("TopWiringTests") + val testDirName = testDir.getPath + def transform = new TopWiringTransform - def topWiringDummyOutputFilesFunction(dir: String, - mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], - state: CircuitState): CircuitState = { - state - } + def topWiringDummyOutputFilesFunction( + dir: String, + mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], + state: CircuitState + ): CircuitState = { + state + } - def topWiringTestOutputFilesFunction(dir: String, - mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], - state: CircuitState): CircuitState = { - val testOutputFile = new PrintWriter(new File(dir, "TopWiringOutputTest.txt" )) - mapping map { - case ((_, tpe, _, path, prefix), index) => { - val portwidth = tpe match { case GroundType(IntWidth(w)) => w } - val portnum = index - val portname = prefix + path.mkString("_") - testOutputFile.append(s"new top level port $portnum : $portname, with width $portwidth \n") - } - } - testOutputFile.close() - state - } + def topWiringTestOutputFilesFunction( + dir: String, + mapping: Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)], + state: CircuitState + ): CircuitState = { + val testOutputFile = new PrintWriter(new File(dir, "TopWiringOutputTest.txt")) + mapping.map { + case ((_, tpe, _, path, prefix), index) => { + val portwidth = tpe match { case GroundType(IntWidth(w)) => w } + val portnum = index + val portname = prefix + path.mkString("_") + testOutputFile.append(s"new top level port $portnum : $portname, with width $portwidth \n") + } + } + testOutputFile.close() + state + } } /** - * Tests TopWiring transformation - */ -class TopWiringTests extends MiddleTransformSpec with TopWiringTestsCommon { + * Tests TopWiring transformation + */ +class TopWiringTests extends MiddleTransformSpec with TopWiringTestsCommon { - "The signal x in module C" should s"be connected to Top port with topwiring prefix and outputfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C" should s"be connected to Top port with topwiring prefix and outputfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | inst c1 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | inst c1 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C inst c1 and c2" should + "The signal x in module C inst c1 and c2" should s"be connected to Top port with topwiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | output topwiring_a1_b1_c2_x: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | output topwiring_b1_c2_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | topwiring_b1_c2_x <= b1.topwiring_c2_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | output topwiring_c2_x: UInt<1> - | inst c1 of C - | inst c2 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | topwiring_c2_x <= c2.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | output topwiring_a1_b1_c2_x: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | output topwiring_b1_c2_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | topwiring_b1_c2_x <= b1.topwiring_c2_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | output topwiring_c2_x: UInt<1> + | inst c1 of C + | inst c2 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | topwiring_c2_x <= c2.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C" should - s"be connected to Top port with topwiring prefix and outputfile in $testDirName, after name colission" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | wire topwiring_a1_b1_c1_x : UInt<1> - | topwiring_a1_b1_c1_x <= UInt(0) - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | wire topwiring_b1_c1_x : UInt<1> - | topwiring_b1_c1_x <= UInt(0) - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C" should + s"be connected to Top port with topwiring prefix and outputfile in $testDirName, after name colission" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | wire topwiring_a1_b1_c1_x : UInt<1> + | topwiring_a1_b1_c1_x <= UInt(0) + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | wire topwiring_b1_c1_x : UInt<1> + | topwiring_b1_c1_x <= UInt(0) + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x_0: UInt<1> - | inst a1 of A - | inst a2 of A_ - | wire topwiring_a1_b1_c1_x : UInt<1> - | topwiring_a1_b1_c1_x <= UInt<1>("h0") - | topwiring_a1_b1_c1_x_0 <= a1.topwiring_b1_c1_x_0 - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x_0: UInt<1> - | inst b1 of B - | wire topwiring_b1_c1_x : UInt<1> - | x <= UInt(1) - | topwiring_b1_c1_x <= UInt<1>("h0") - | topwiring_b1_c1_x_0 <= b1.topwiring_c1_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | inst c1 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x_0: UInt<1> + | inst a1 of A + | inst a2 of A_ + | wire topwiring_a1_b1_c1_x : UInt<1> + | topwiring_a1_b1_c1_x <= UInt<1>("h0") + | topwiring_a1_b1_c1_x_0 <= a1.topwiring_b1_c1_x_0 + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x_0: UInt<1> + | inst b1 of B + | wire topwiring_b1_c1_x : UInt<1> + | x <= UInt(1) + | topwiring_b1_c1_x <= UInt<1>("h0") + | topwiring_b1_c1_x_0 <= b1.topwiring_c1_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | inst c1 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C" should - "be connected to Top port with topwiring prefix and no output function" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C" should + "be connected to Top port with topwiring prefix and no output function" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | module A_ : - | output x: UInt<1> - | x <= UInt(1) - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | inst c1 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = + Seq(TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_")) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | module A_ : + | output x: UInt<1> + | x <= UInt(1) + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | inst c1 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C inst c1 and c2 and signal y in module A_" should - s"be connected to Top port with topwiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C inst c1 and c2 and signal y in module A_" should + s"be connected to Top port with topwiring prefix and outfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"y", - ModuleName(s"A_", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | output topwiring_a1_b1_c2_x: UInt<1> - | output topwiring_a2_y: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x - | topwiring_a2_y <= a2.topwiring_y - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | output topwiring_b1_c2_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | topwiring_b1_c2_x <= b1.topwiring_c2_x - | module A_ : - | output x: UInt<1> - | output topwiring_y: UInt<1> - | wire y : UInt<1> - | x <= UInt(1) - | y <= UInt<1>("h1") - | topwiring_y <= y - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | output topwiring_c2_x: UInt<1> - | inst c1 of C - | inst c2 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | topwiring_c2_x <= c2.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"y", ModuleName(s"A_", CircuitName(s"Top"))), s"topwiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | output topwiring_a1_b1_c2_x: UInt<1> + | output topwiring_a2_y: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x + | topwiring_a2_y <= a2.topwiring_y + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | output topwiring_b1_c2_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | topwiring_b1_c2_x <= b1.topwiring_c2_x + | module A_ : + | output x: UInt<1> + | output topwiring_y: UInt<1> + | wire y : UInt<1> + | x <= UInt(1) + | y <= UInt<1>("h1") + | topwiring_y <= y + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | output topwiring_c2_x: UInt<1> + | inst c1 of C + | inst c2 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | topwiring_c2_x <= c2.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal x in module C inst c1 and c2 and signal y in module A_" should - s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output x: UInt<1> - | x <= UInt(1) - | inst b1 of B - | module A_ : - | output x: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | x <= UInt(1) - | module B : - | output x: UInt<1> - | x <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output x: UInt<1> - | x <= UInt(0) + "The signal x in module C inst c1 and c2 and signal y in module A_" should + s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output x: UInt<1> + | x <= UInt(1) + | inst b1 of B + | module A_ : + | output x: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | x <= UInt(1) + | module B : + | output x: UInt<1> + | x <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output x: UInt<1> + | x <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"x", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"y", - ModuleName(s"A_", CircuitName(s"Top"))), - s"top2wiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_x: UInt<1> - | output topwiring_a1_b1_c2_x: UInt<1> - | output top2wiring_a2_y: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x - | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x - | top2wiring_a2_y <= a2.top2wiring_y - | module A : - | output x: UInt<1> - | output topwiring_b1_c1_x: UInt<1> - | output topwiring_b1_c2_x: UInt<1> - | inst b1 of B - | x <= UInt(1) - | topwiring_b1_c1_x <= b1.topwiring_c1_x - | topwiring_b1_c2_x <= b1.topwiring_c2_x - | module A_ : - | output x: UInt<1> - | output top2wiring_y: UInt<1> - | wire y : UInt<1> - | x <= UInt(1) - | y <= UInt<1>("h1") - | top2wiring_y <= y - | module B : - | output x: UInt<1> - | output topwiring_c1_x: UInt<1> - | output topwiring_c2_x: UInt<1> - | inst c1 of C - | inst c2 of C - | x <= UInt(1) - | topwiring_c1_x <= c1.topwiring_x - | topwiring_c2_x <= c2.topwiring_x - | module C: - | output x: UInt<1> - | output topwiring_x: UInt<1> - | x <= UInt(0) - | topwiring_x <= x + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"x", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"y", ModuleName(s"A_", CircuitName(s"Top"))), s"top2wiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_x: UInt<1> + | output topwiring_a1_b1_c2_x: UInt<1> + | output top2wiring_a2_y: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_x <= a1.topwiring_b1_c1_x + | topwiring_a1_b1_c2_x <= a1.topwiring_b1_c2_x + | top2wiring_a2_y <= a2.top2wiring_y + | module A : + | output x: UInt<1> + | output topwiring_b1_c1_x: UInt<1> + | output topwiring_b1_c2_x: UInt<1> + | inst b1 of B + | x <= UInt(1) + | topwiring_b1_c1_x <= b1.topwiring_c1_x + | topwiring_b1_c2_x <= b1.topwiring_c2_x + | module A_ : + | output x: UInt<1> + | output top2wiring_y: UInt<1> + | wire y : UInt<1> + | x <= UInt(1) + | y <= UInt<1>("h1") + | top2wiring_y <= y + | module B : + | output x: UInt<1> + | output topwiring_c1_x: UInt<1> + | output topwiring_c2_x: UInt<1> + | inst c1 of C + | inst c2 of C + | x <= UInt(1) + | topwiring_c1_x <= c1.topwiring_x + | topwiring_c2_x <= c2.topwiring_x + | module C: + | output x: UInt<1> + | output topwiring_x: UInt<1> + | x <= UInt(0) + | topwiring_x <= x """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "The signal fullword in module C inst c1 and c2 and signal y in module A_" should - s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst b1 of B - | module A_ : - | output fullword: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | fullword <= UInt(1) - | module B : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output fullword: UInt<1> - | fullword <= UInt(0) + "The signal fullword in module C inst c1 and c2 and signal y in module A_" should + s"be connected to Top port with topwiring and top2wiring prefix and outfile in $testDirName" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst b1 of B + | module A_ : + | output fullword: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | fullword <= UInt(1) + | module B : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output fullword: UInt<1> + | fullword <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"fullword", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"y", - ModuleName(s"A_", CircuitName(s"Top"))), - s"top2wiring_"), - TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction)) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_c1_fullword: UInt<1> - | output topwiring_a1_b1_c2_fullword: UInt<1> - | output top2wiring_a2_y: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword - | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword - | top2wiring_a2_y <= a2.top2wiring_y - | module A : - | output fullword: UInt<1> - | output topwiring_b1_c1_fullword: UInt<1> - | output topwiring_b1_c2_fullword: UInt<1> - | inst b1 of B - | fullword <= UInt(1) - | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword - | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword - | module A_ : - | output fullword: UInt<1> - | output top2wiring_y: UInt<1> - | wire y : UInt<1> - | fullword <= UInt(1) - | y <= UInt<1>("h1") - | top2wiring_y <= y - | module B : - | output fullword: UInt<1> - | output topwiring_c1_fullword: UInt<1> - | output topwiring_c2_fullword: UInt<1> - | inst c1 of C - | inst c2 of C - | fullword <= UInt(1) - | topwiring_c1_fullword <= c1.topwiring_fullword - | topwiring_c2_fullword <= c2.topwiring_fullword - | module C: - | output fullword: UInt<1> - | output topwiring_fullword: UInt<1> - | fullword <= UInt(0) - | topwiring_fullword <= fullword + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"fullword", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"y", ModuleName(s"A_", CircuitName(s"Top"))), s"top2wiring_"), + TopWiringOutputFilesAnnotation(testDirName, topWiringTestOutputFilesFunction) + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_c1_fullword: UInt<1> + | output topwiring_a1_b1_c2_fullword: UInt<1> + | output top2wiring_a2_y: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword + | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword + | top2wiring_a2_y <= a2.top2wiring_y + | module A : + | output fullword: UInt<1> + | output topwiring_b1_c1_fullword: UInt<1> + | output topwiring_b1_c2_fullword: UInt<1> + | inst b1 of B + | fullword <= UInt(1) + | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword + | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword + | module A_ : + | output fullword: UInt<1> + | output top2wiring_y: UInt<1> + | wire y : UInt<1> + | fullword <= UInt(1) + | y <= UInt<1>("h1") + | top2wiring_y <= y + | module B : + | output fullword: UInt<1> + | output topwiring_c1_fullword: UInt<1> + | output topwiring_c2_fullword: UInt<1> + | inst c1 of C + | inst c2 of C + | fullword <= UInt(1) + | topwiring_c1_fullword <= c1.topwiring_fullword + | topwiring_c2_fullword <= c2.topwiring_fullword + | module C: + | output fullword: UInt<1> + | output topwiring_fullword: UInt<1> + | fullword <= UInt(0) + | topwiring_fullword <= fullword """.stripMargin - execute(input, check, topwiringannos) - } - + execute(input, check, topwiringannos) + } - "The signal fullword in module C inst c1 and c2 and signal fullword in module B" should - s"be connected to Top port with topwiring prefix" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst a2 of A_ - | module A : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst b1 of B - | module A_ : - | output fullword: UInt<1> - | wire y : UInt<1> - | y <= UInt(1) - | fullword <= UInt(1) - | module B : - | output fullword: UInt<1> - | fullword <= UInt(1) - | inst c1 of C - | inst c2 of C - | module C: - | output fullword: UInt<1> - | fullword <= UInt(0) + "The signal fullword in module C inst c1 and c2 and signal fullword in module B" should + s"be connected to Top port with topwiring prefix" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst a2 of A_ + | module A : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst b1 of B + | module A_ : + | output fullword: UInt<1> + | wire y : UInt<1> + | y <= UInt(1) + | fullword <= UInt(1) + | module B : + | output fullword: UInt<1> + | fullword <= UInt(1) + | inst c1 of C + | inst c2 of C + | module C: + | output fullword: UInt<1> + | fullword <= UInt(0) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"fullword", - ModuleName(s"C", CircuitName(s"Top"))), - s"topwiring_"), - TopWiringAnnotation(ComponentName(s"fullword", - ModuleName(s"B", CircuitName(s"Top"))), - s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_b1_fullword: UInt<1> - | output topwiring_a1_b1_c1_fullword: UInt<1> - | output topwiring_a1_b1_c2_fullword: UInt<1> - | inst a1 of A - | inst a2 of A_ - | topwiring_a1_b1_fullword <= a1.topwiring_b1_fullword - | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword - | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword - | module A : - | output fullword: UInt<1> - | output topwiring_b1_fullword: UInt<1> - | output topwiring_b1_c1_fullword: UInt<1> - | output topwiring_b1_c2_fullword: UInt<1> - | inst b1 of B - | fullword <= UInt(1) - | topwiring_b1_fullword <= b1.topwiring_fullword - | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword - | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword - | module A_ : - | output fullword: UInt<1> - | wire y : UInt<1> - | fullword <= UInt(1) - | y <= UInt<1>("h1") - | module B : - | output fullword: UInt<1> - | output topwiring_fullword: UInt<1> - | output topwiring_c1_fullword: UInt<1> - | output topwiring_c2_fullword: UInt<1> - | inst c1 of C - | inst c2 of C - | fullword <= UInt(1) - | topwiring_fullword <= fullword - | topwiring_c1_fullword <= c1.topwiring_fullword - | topwiring_c2_fullword <= c2.topwiring_fullword - | module C: - | output fullword: UInt<1> - | output topwiring_fullword: UInt<1> - | fullword <= UInt(0) - | topwiring_fullword <= fullword + val topwiringannos = Seq( + TopWiringAnnotation(ComponentName(s"fullword", ModuleName(s"C", CircuitName(s"Top"))), s"topwiring_"), + TopWiringAnnotation(ComponentName(s"fullword", ModuleName(s"B", CircuitName(s"Top"))), s"topwiring_") + ) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_b1_fullword: UInt<1> + | output topwiring_a1_b1_c1_fullword: UInt<1> + | output topwiring_a1_b1_c2_fullword: UInt<1> + | inst a1 of A + | inst a2 of A_ + | topwiring_a1_b1_fullword <= a1.topwiring_b1_fullword + | topwiring_a1_b1_c1_fullword <= a1.topwiring_b1_c1_fullword + | topwiring_a1_b1_c2_fullword <= a1.topwiring_b1_c2_fullword + | module A : + | output fullword: UInt<1> + | output topwiring_b1_fullword: UInt<1> + | output topwiring_b1_c1_fullword: UInt<1> + | output topwiring_b1_c2_fullword: UInt<1> + | inst b1 of B + | fullword <= UInt(1) + | topwiring_b1_fullword <= b1.topwiring_fullword + | topwiring_b1_c1_fullword <= b1.topwiring_c1_fullword + | topwiring_b1_c2_fullword <= b1.topwiring_c2_fullword + | module A_ : + | output fullword: UInt<1> + | wire y : UInt<1> + | fullword <= UInt(1) + | y <= UInt<1>("h1") + | module B : + | output fullword: UInt<1> + | output topwiring_fullword: UInt<1> + | output topwiring_c1_fullword: UInt<1> + | output topwiring_c2_fullword: UInt<1> + | inst c1 of C + | inst c2 of C + | fullword <= UInt(1) + | topwiring_fullword <= fullword + | topwiring_c1_fullword <= c1.topwiring_fullword + | topwiring_c2_fullword <= c2.topwiring_fullword + | module C: + | output fullword: UInt<1> + | output topwiring_fullword: UInt<1> + | fullword <= UInt(0) + | topwiring_fullword <= fullword """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "TopWiringTransform" should "do nothing if run without TopWiring* annotations" in { - val input = """|circuit Top : - | module Top : - | input foo : UInt<1>""".stripMargin - val inputFile = { - val fileName = s"${testDir.getAbsolutePath}/input-no-sources.fir" - val w = new PrintWriter(fileName) - w.write(input) - w.close() - fileName - } - val args = Array( - "--custom-transforms", "firrtl.transforms.TopWiring.TopWiringTransform", - "--input-file", inputFile, - "--top-name", "Top", - "--compiler", "low", - "--info-mode", "ignore" - ) - firrtl.Driver.execute(args) match { - case FirrtlExecutionSuccess(_, emitted) => - parse(emitted).serialize should be (parse(input).serialize) - case _ => fail - } - } + "TopWiringTransform" should "do nothing if run without TopWiring* annotations" in { + val input = """|circuit Top : + | module Top : + | input foo : UInt<1>""".stripMargin + val inputFile = { + val fileName = s"${testDir.getAbsolutePath}/input-no-sources.fir" + val w = new PrintWriter(fileName) + w.write(input) + w.close() + fileName + } + val args = Array( + "--custom-transforms", + "firrtl.transforms.TopWiring.TopWiringTransform", + "--input-file", + inputFile, + "--top-name", + "Top", + "--compiler", + "low", + "--info-mode", + "ignore" + ) + firrtl.Driver.execute(args) match { + case FirrtlExecutionSuccess(_, emitted) => + parse(emitted).serialize should be(parse(input).serialize) + case _ => fail + } + } - "TopWiringTransform" should "remove TopWiringAnnotations" in { - val input = - """|circuit Top: - | module Top: - | wire foo: UInt<1>""".stripMargin + "TopWiringTransform" should "remove TopWiringAnnotations" in { + val input = + """|circuit Top: + | module Top: + | wire foo: UInt<1>""".stripMargin - val bar = - Target - .deserialize("~Top|Top>foo") - .toNamed match { case a: ComponentName => a } + val bar = + Target + .deserialize("~Top|Top>foo") + .toNamed match { case a: ComponentName => a } - val annotations = Seq(TopWiringAnnotation(bar, "bar_")) - val outputState = (new TopWiringTransform).execute(CircuitState(Parser.parse(input), MidForm, annotations, None)) + val annotations = Seq(TopWiringAnnotation(bar, "bar_")) + val outputState = (new TopWiringTransform).execute(CircuitState(Parser.parse(input), MidForm, annotations, None)) - outputState.circuit.serialize should include ("output bar_foo") - outputState.annotations.toSeq should be (empty) - } + outputState.circuit.serialize should include("output bar_foo") + outputState.annotations.toSeq should be(empty) + } } class AggregateTopWiringTests extends MiddleTransformSpec with TopWiringTestsCommon { - "An aggregate wire named myAgg in A" should s"be wired to Top's IO as topwiring_a1_myAgg" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | module A: - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) + "An aggregate wire named myAgg in A" should s"be wired to Top's IO as topwiring_a1_myAgg" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | module A: + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) """.stripMargin - val topwiringannos = Seq(TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), - s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | inst a1 of A - | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a - | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b - | module A : - | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) - | topwiring_myAgg.a <= myAgg.a - | topwiring_myAgg.b <= myAgg.b + val topwiringannos = + Seq(TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), s"topwiring_")) + val check = + """circuit Top : + | module Top : + | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | inst a1 of A + | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a + | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b + | module A : + | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) + | topwiring_myAgg.a <= myAgg.a + | topwiring_myAgg.b <= myAgg.b """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } - "Aggregate wires myAgg in Top.a1, Top.b.a1 and Top.b.a2" should - s"be wired to Top's IO as topwiring_a1_myAgg, topwiring_b_a1_myAgg, and topwiring_b_a2_myAgg" in { - val input = - """circuit Top : - | module Top : - | inst a1 of A - | inst b of B - | module B: - | inst a1 of A - | inst a2 of A - | module A: - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) + "Aggregate wires myAgg in Top.a1, Top.b.a1 and Top.b.a2" should + s"be wired to Top's IO as topwiring_a1_myAgg, topwiring_b_a1_myAgg, and topwiring_b_a2_myAgg" in { + val input = + """circuit Top : + | module Top : + | inst a1 of A + | inst b of B + | module B: + | inst a1 of A + | inst a2 of A + | module A: + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) """.stripMargin - val topwiringannos = Seq( - TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), s"topwiring_")) + val topwiringannos = + Seq(TopWiringAnnotation(ComponentName(s"myAgg", ModuleName(s"A", CircuitName(s"Top"))), s"topwiring_")) - val check = - """circuit Top : - | module Top : - | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | output topwiring_b_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | output topwiring_b_a2_myAgg: { a: UInt<1>, b: SInt<8> } - | inst a1 of A - | inst b of B - | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a - | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b - | topwiring_b_a1_myAgg.a <= b.topwiring_a1_myAgg.a - | topwiring_b_a1_myAgg.b <= b.topwiring_a1_myAgg.b - | topwiring_b_a2_myAgg.a <= b.topwiring_a2_myAgg.a - | topwiring_b_a2_myAgg.b <= b.topwiring_a2_myAgg.b - | module B: - | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } - | output topwiring_a2_myAgg: { a: UInt<1>, b: SInt<8> } - | inst a1 of A - | inst a2 of A - | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a - | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b - | topwiring_a2_myAgg.a <= a2.topwiring_myAgg.a - | topwiring_a2_myAgg.b <= a2.topwiring_myAgg.b - | module A : - | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } - | wire myAgg: { a: UInt<1>, b: SInt<8> } - | myAgg.a <= UInt(0) - | myAgg.b <= SInt(-1) - | topwiring_myAgg.a <= myAgg.a - | topwiring_myAgg.b <= myAgg.b + val check = + """circuit Top : + | module Top : + | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | output topwiring_b_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | output topwiring_b_a2_myAgg: { a: UInt<1>, b: SInt<8> } + | inst a1 of A + | inst b of B + | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a + | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b + | topwiring_b_a1_myAgg.a <= b.topwiring_a1_myAgg.a + | topwiring_b_a1_myAgg.b <= b.topwiring_a1_myAgg.b + | topwiring_b_a2_myAgg.a <= b.topwiring_a2_myAgg.a + | topwiring_b_a2_myAgg.b <= b.topwiring_a2_myAgg.b + | module B: + | output topwiring_a1_myAgg: { a: UInt<1>, b: SInt<8> } + | output topwiring_a2_myAgg: { a: UInt<1>, b: SInt<8> } + | inst a1 of A + | inst a2 of A + | topwiring_a1_myAgg.a <= a1.topwiring_myAgg.a + | topwiring_a1_myAgg.b <= a1.topwiring_myAgg.b + | topwiring_a2_myAgg.a <= a2.topwiring_myAgg.a + | topwiring_a2_myAgg.b <= a2.topwiring_myAgg.b + | module A : + | output topwiring_myAgg: { a: UInt<1>, b: SInt<8> } + | wire myAgg: { a: UInt<1>, b: SInt<8> } + | myAgg.a <= UInt(0) + | myAgg.b <= SInt(-1) + | topwiring_myAgg.a <= myAgg.a + | topwiring_myAgg.b <= myAgg.b """.stripMargin - execute(input, check, topwiringannos) - } + execute(input, check, topwiringannos) + } } diff --git a/src/test/scala/loggertests/LoggerSpec.scala b/src/test/scala/loggertests/LoggerSpec.scala index c8aae949..553f4966 100644 --- a/src/test/scala/loggertests/LoggerSpec.scala +++ b/src/test/scala/loggertests/LoggerSpec.scala @@ -260,7 +260,6 @@ class LoggerSpec extends AnyFreeSpec with Matchers with OneInstancePerTest with val captor = new OutputCaptor Logger.setOutput(captor.printStream) - Logger.setLevel(LogLevel.Info) Logger.setLevel("loggertests.LogsInfo2", LogLevel.Error) @@ -302,47 +301,47 @@ class LoggerSpec extends AnyFreeSpec with Matchers with OneInstancePerTest with } val logText = captor.getOutputAsString - logText should include ("message 1") - logText should include ("message 2") + logText should include("message 1") + logText should include("message 2") } } "Show that nested makeScopes share same state" in { - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) Logger.makeScope() { Logger.setLevel(LogLevel.Info) - Logger.getGlobalLevel should be (LogLevel.Info) + Logger.getGlobalLevel should be(LogLevel.Info) Logger.makeScope() { - Logger.getGlobalLevel should be (LogLevel.Info) + Logger.getGlobalLevel should be(LogLevel.Info) } Logger.makeScope() { Logger.setLevel(LogLevel.Debug) - Logger.getGlobalLevel should be (LogLevel.Debug) + Logger.getGlobalLevel should be(LogLevel.Debug) } - Logger.getGlobalLevel should be (LogLevel.Debug) + Logger.getGlobalLevel should be(LogLevel.Debug) } - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) } "Show that first makeScope starts with fresh state" in { - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) Logger.setLevel(LogLevel.Warn) - Logger.getGlobalLevel should be (LogLevel.Warn) + Logger.getGlobalLevel should be(LogLevel.Warn) Logger.makeScope() { - Logger.getGlobalLevel should be (LoggerSpec.globalLevel) + Logger.getGlobalLevel should be(LoggerSpec.globalLevel) Logger.setLevel(LogLevel.Trace) - Logger.getGlobalLevel should be (LogLevel.Trace) + Logger.getGlobalLevel should be(LogLevel.Trace) } - Logger.getGlobalLevel should be (LogLevel.Warn) + Logger.getGlobalLevel should be(LogLevel.Warn) } } } |
