aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/AddDescriptionNodes.scala116
-rw-r--r--src/main/scala/firrtl/Compiler.scala175
-rw-r--r--src/main/scala/firrtl/DependencyAPIMigration.scala8
-rw-r--r--src/main/scala/firrtl/Driver.scala53
-rw-r--r--src/main/scala/firrtl/EmissionOption.scala31
-rw-r--r--src/main/scala/firrtl/Emitter.scala807
-rw-r--r--src/main/scala/firrtl/ExecutionOptionsManager.scala333
-rw-r--r--src/main/scala/firrtl/FileUtils.scala25
-rw-r--r--src/main/scala/firrtl/FirrtlException.scala7
-rw-r--r--src/main/scala/firrtl/Implicits.scala8
-rw-r--r--src/main/scala/firrtl/LexerHelper.scala19
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala8
-rw-r--r--src/main/scala/firrtl/Mappers.scala45
-rw-r--r--src/main/scala/firrtl/Namespace.scala11
-rw-r--r--src/main/scala/firrtl/Parser.scala5
-rw-r--r--src/main/scala/firrtl/PrimOps.scala218
-rw-r--r--src/main/scala/firrtl/RenameMap.scala273
-rw-r--r--src/main/scala/firrtl/Utils.scala480
-rw-r--r--src/main/scala/firrtl/Visitor.scala299
-rw-r--r--src/main/scala/firrtl/WIR.scala251
-rw-r--r--src/main/scala/firrtl/analyses/CircuitGraph.scala7
-rw-r--r--src/main/scala/firrtl/analyses/ConnectionGraph.scala146
-rw-r--r--src/main/scala/firrtl/analyses/IRLookup.scala189
-rw-r--r--src/main/scala/firrtl/analyses/InstanceGraph.scala30
-rw-r--r--src/main/scala/firrtl/analyses/InstanceKeyGraph.scala36
-rw-r--r--src/main/scala/firrtl/analyses/NodeCount.scala11
-rw-r--r--src/main/scala/firrtl/analyses/SymbolTable.scala27
-rw-r--r--src/main/scala/firrtl/annotations/Annotation.scala61
-rw-r--r--src/main/scala/firrtl/annotations/AnnotationUtils.scala63
-rw-r--r--src/main/scala/firrtl/annotations/JsonProtocol.scala331
-rw-r--r--src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala14
-rw-r--r--src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala20
-rw-r--r--src/main/scala/firrtl/annotations/PresetAnnotations.scala12
-rw-r--r--src/main/scala/firrtl/annotations/Target.scala291
-rw-r--r--src/main/scala/firrtl/annotations/TargetToken.scala31
-rw-r--r--src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala26
-rw-r--r--src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala24
-rw-r--r--src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala92
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala82
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala86
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala316
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala29
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala89
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala49
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala131
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala20
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala83
-rw-r--r--src/main/scala/firrtl/checks/CheckResets.scala31
-rw-r--r--src/main/scala/firrtl/constraint/Constraint.scala2
-rw-r--r--src/main/scala/firrtl/constraint/ConstraintSolver.scala120
-rw-r--r--src/main/scala/firrtl/constraint/Inequality.scala6
-rw-r--r--src/main/scala/firrtl/constraint/IsAdd.scala37
-rw-r--r--src/main/scala/firrtl/constraint/IsFloor.scala16
-rw-r--r--src/main/scala/firrtl/constraint/IsKnown.scala4
-rw-r--r--src/main/scala/firrtl/constraint/IsMax.scala30
-rw-r--r--src/main/scala/firrtl/constraint/IsMin.scala37
-rw-r--r--src/main/scala/firrtl/constraint/IsMul.scala25
-rw-r--r--src/main/scala/firrtl/constraint/IsNeg.scala12
-rw-r--r--src/main/scala/firrtl/constraint/IsPow.scala10
-rw-r--r--src/main/scala/firrtl/constraint/IsVar.scala3
-rw-r--r--src/main/scala/firrtl/features/LetterCaseTransform.scala11
-rw-r--r--src/main/scala/firrtl/graph/DiGraph.scala46
-rw-r--r--src/main/scala/firrtl/graph/EdgeData.scala5
-rw-r--r--src/main/scala/firrtl/graph/EulerTour.scala61
-rw-r--r--src/main/scala/firrtl/graph/RenderDiGraph.scala78
-rw-r--r--src/main/scala/firrtl/ir/IR.scala869
-rw-r--r--src/main/scala/firrtl/ir/Serializer.scala247
-rw-r--r--src/main/scala/firrtl/ir/StructuralHash.scala291
-rw-r--r--src/main/scala/firrtl/options/DependencyManager.scala181
-rw-r--r--src/main/scala/firrtl/options/ExitCodes.scala2
-rw-r--r--src/main/scala/firrtl/options/OptionParser.scala15
-rw-r--r--src/main/scala/firrtl/options/Phase.scala30
-rw-r--r--src/main/scala/firrtl/options/Registration.scala28
-rw-r--r--src/main/scala/firrtl/options/Shell.scala19
-rw-r--r--src/main/scala/firrtl/options/Stage.scala11
-rw-r--r--src/main/scala/firrtl/options/StageAnnotations.scala21
-rw-r--r--src/main/scala/firrtl/options/StageOptions.scala28
-rw-r--r--src/main/scala/firrtl/options/StageUtils.scala10
-rw-r--r--src/main/scala/firrtl/options/package.scala7
-rw-r--r--src/main/scala/firrtl/options/phases/AddDefaults.scala2
-rw-r--r--src/main/scala/firrtl/options/phases/Checks.scala15
-rw-r--r--src/main/scala/firrtl/options/phases/GetIncludes.scala5
-rw-r--r--src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala12
-rw-r--r--src/main/scala/firrtl/passes/CInferMDir.scala67
-rw-r--r--src/main/scala/firrtl/passes/CheckChirrtl.scala4
-rw-r--r--src/main/scala/firrtl/passes/CheckFlows.scala84
-rw-r--r--src/main/scala/firrtl/passes/CheckHighForm.scala227
-rw-r--r--src/main/scala/firrtl/passes/CheckInitialization.scala11
-rw-r--r--src/main/scala/firrtl/passes/CheckTypes.scala376
-rw-r--r--src/main/scala/firrtl/passes/CheckWidths.scala139
-rw-r--r--src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala37
-rw-r--r--src/main/scala/firrtl/passes/ConvertFixedToSInt.scala90
-rw-r--r--src/main/scala/firrtl/passes/ExpandConnects.scala66
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala173
-rw-r--r--src/main/scala/firrtl/passes/InferBinaryPoints.scala98
-rw-r--r--src/main/scala/firrtl/passes/InferTypes.scala76
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala190
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala246
-rw-r--r--src/main/scala/firrtl/passes/Legalize.scala31
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala300
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala46
-rw-r--r--src/main/scala/firrtl/passes/Pass.scala2
-rw-r--r--src/main/scala/firrtl/passes/PullMuxes.scala80
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala79
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala196
-rw-r--r--src/main/scala/firrtl/passes/RemoveEmpty.scala2
-rw-r--r--src/main/scala/firrtl/passes/RemoveIntervals.scala149
-rw-r--r--src/main/scala/firrtl/passes/RemoveValidIf.scala22
-rw-r--r--src/main/scala/firrtl/passes/ReplaceAccesses.scala17
-rw-r--r--src/main/scala/firrtl/passes/ResolveFlows.scala25
-rw-r--r--src/main/scala/firrtl/passes/ResolveKinds.scala16
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala100
-rw-r--r--src/main/scala/firrtl/passes/ToWorkingIR.scala2
-rw-r--r--src/main/scala/firrtl/passes/TrimIntervals.scala58
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala241
-rw-r--r--src/main/scala/firrtl/passes/VerilogModulusCleanup.scala81
-rw-r--r--src/main/scala/firrtl/passes/VerilogPrep.scala34
-rw-r--r--src/main/scala/firrtl/passes/ZeroLengthVecs.scala21
-rw-r--r--src/main/scala/firrtl/passes/ZeroWidth.scala137
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockList.scala26
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala19
-rw-r--r--src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala45
-rw-r--r--src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala20
-rw-r--r--src/main/scala/firrtl/passes/memlib/DecorateMems.scala5
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala101
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemConf.scala65
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala56
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemLibOptions.scala3
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala16
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemUtils.scala69
-rw-r--r--src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala43
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala118
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala50
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala36
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala12
-rw-r--r--src/main/scala/firrtl/passes/memlib/ToMemIR.scala9
-rw-r--r--src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala136
-rw-r--r--src/main/scala/firrtl/passes/memlib/YamlUtils.scala15
-rw-r--r--src/main/scala/firrtl/passes/wiring/Wiring.scala212
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringTransform.scala11
-rw-r--r--src/main/scala/firrtl/passes/wiring/WiringUtils.scala122
-rw-r--r--src/main/scala/firrtl/proto/FromProto.scala114
-rw-r--r--src/main/scala/firrtl/proto/ToProto.scala131
-rw-r--r--src/main/scala/firrtl/stage/FirrtlAnnotations.scala115
-rw-r--r--src/main/scala/firrtl/stage/FirrtlCli.scala22
-rw-r--r--src/main/scala/firrtl/stage/FirrtlOptions.scala22
-rw-r--r--src/main/scala/firrtl/stage/FirrtlStage.scala3
-rw-r--r--src/main/scala/firrtl/stage/FirrtlStageUtils.scala8
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala130
-rw-r--r--src/main/scala/firrtl/stage/TransformManager.scala8
-rw-r--r--src/main/scala/firrtl/stage/package.scala56
-rw-r--r--src/main/scala/firrtl/stage/phases/AddCircuit.scala11
-rw-r--r--src/main/scala/firrtl/stage/phases/AddDefaults.scala22
-rw-r--r--src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala19
-rw-r--r--src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala17
-rw-r--r--src/main/scala/firrtl/stage/phases/CatchExceptions.scala14
-rw-r--r--src/main/scala/firrtl/stage/phases/Checks.scala79
-rw-r--r--src/main/scala/firrtl/stage/phases/Compiler.scala85
-rw-r--r--src/main/scala/firrtl/stage/phases/DriverCompatibility.scala70
-rw-r--r--src/main/scala/firrtl/stage/phases/WriteEmitted.scala12
-rw-r--r--src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala3
-rw-r--r--src/main/scala/firrtl/stage/transforms/Compiler.scala11
-rw-r--r--src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala6
-rw-r--r--src/main/scala/firrtl/stage/transforms/TrackTransforms.scala12
-rw-r--r--src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala4
-rw-r--r--src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala55
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala93
-rw-r--r--src/main/scala/firrtl/transforms/CombineCats.scala33
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala463
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala130
-rw-r--r--src/main/scala/firrtl/transforms/Dedup.scala297
-rw-r--r--src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala25
-rw-r--r--src/main/scala/firrtl/transforms/Flatten.scala179
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala30
-rw-r--r--src/main/scala/firrtl/transforms/GroupComponents.scala123
-rw-r--r--src/main/scala/firrtl/transforms/InferResets.scala93
-rw-r--r--src/main/scala/firrtl/transforms/InlineBitExtractions.scala48
-rw-r--r--src/main/scala/firrtl/transforms/InlineCasts.scala43
-rw-r--r--src/main/scala/firrtl/transforms/LegalizeClocks.scala16
-rw-r--r--src/main/scala/firrtl/transforms/LegalizeReductions.scala8
-rw-r--r--src/main/scala/firrtl/transforms/ManipulateNames.scala287
-rw-r--r--src/main/scala/firrtl/transforms/OptimizationAnnotations.scala18
-rw-r--r--src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala168
-rw-r--r--src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala27
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala7
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala49
-rw-r--r--src/main/scala/firrtl/transforms/RenameModules.scala2
-rw-r--r--src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala3
-rw-r--r--src/main/scala/firrtl/transforms/SimplifyMems.scala11
-rw-r--r--src/main/scala/firrtl/transforms/TopWiring.scala269
-rw-r--r--src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala23
-rw-r--r--src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala8
-rw-r--r--src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala23
-rw-r--r--src/main/scala/firrtl/traversals/Foreachers.scala34
-rw-r--r--src/main/scala/firrtl/util/BackendCompilationUtilities.scala148
-rw-r--r--src/main/scala/firrtl/util/ClassUtils.scala14
-rw-r--r--src/main/scala/logger/Logger.scala105
-rw-r--r--src/main/scala/logger/LoggerAnnotations.scala35
-rw-r--r--src/main/scala/logger/LoggerOptions.scala20
-rw-r--r--src/main/scala/logger/phases/AddDefaults.scala10
-rw-r--r--src/main/scala/logger/phases/Checks.scala27
-rw-r--r--src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala24
-rw-r--r--src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala21
203 files changed, 9099 insertions, 7351 deletions
diff --git a/src/main/scala/firrtl/AddDescriptionNodes.scala b/src/main/scala/firrtl/AddDescriptionNodes.scala
index 5ff07314..7adb28af 100644
--- a/src/main/scala/firrtl/AddDescriptionNodes.scala
+++ b/src/main/scala/firrtl/AddDescriptionNodes.scala
@@ -12,7 +12,7 @@ import firrtl.options.Dependency
* Usually, we would like to emit these descriptions in some way.
*/
sealed trait DescriptionAnnotation extends Annotation {
- def target: Target
+ def target: Target
def description: String
}
@@ -24,7 +24,7 @@ sealed trait DescriptionAnnotation extends Annotation {
case class DocStringAnnotation(target: Target, description: String) extends DescriptionAnnotation {
def update(renames: RenameMap): Seq[DocStringAnnotation] = {
renames.get(target) match {
- case None => Seq(this)
+ case None => Seq(this)
case Some(seq) => seq.map(n => this.copy(target = n))
}
}
@@ -38,7 +38,7 @@ case class DocStringAnnotation(target: Target, description: String) extends Desc
case class AttributeAnnotation(target: Target, description: String) extends DescriptionAnnotation {
def update(renames: RenameMap): Seq[AttributeAnnotation] = {
renames.get(target) match {
- case None => Seq(this)
+ case None => Seq(this)
case Some(seq) => seq.map(n => this.copy(target = n))
}
}
@@ -78,18 +78,20 @@ case class Attribute(string: StringLit) extends Description {
* @param descriptions
* @param stmt the encapsulated statement
*/
-private case class DescribedStmt(descriptions: Seq[Description], stmt: Statement) extends Statement with HasDescription {
+private case class DescribedStmt(descriptions: Seq[Description], stmt: Statement)
+ extends Statement
+ with HasDescription {
override def serialize: String = s"${descriptions.map(_.serialize).mkString("\n")}\n${stmt.serialize}"
- def mapStmt(f: Statement => Statement): Statement = f(stmt)
- def mapExpr(f: Expression => Expression): Statement = this.copy(stmt = stmt.mapExpr(f))
- def mapType(f: Type => Type): Statement = this.copy(stmt = stmt.mapType(f))
- def mapString(f: String => String): Statement = this.copy(stmt = stmt.mapString(f))
- def mapInfo(f: Info => Info): Statement = this.copy(stmt = stmt.mapInfo(f))
- def foreachStmt(f: Statement => Unit): Unit = f(stmt)
- def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f)
- def foreachType(f: Type => Unit): Unit = stmt.foreachType(f)
- def foreachString(f: String => Unit): Unit = stmt.foreachString(f)
- def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f)
+ def mapStmt(f: Statement => Statement): Statement = f(stmt)
+ def mapExpr(f: Expression => Expression): Statement = this.copy(stmt = stmt.mapExpr(f))
+ def mapType(f: Type => Type): Statement = this.copy(stmt = stmt.mapType(f))
+ def mapString(f: String => String): Statement = this.copy(stmt = stmt.mapString(f))
+ def mapInfo(f: Info => Info): Statement = this.copy(stmt = stmt.mapInfo(f))
+ def foreachStmt(f: Statement => Unit): Unit = f(stmt)
+ def foreachExpr(f: Expression => Unit): Unit = stmt.foreachExpr(f)
+ def foreachType(f: Type => Unit): Unit = stmt.foreachType(f)
+ def foreachString(f: String => Unit): Unit = stmt.foreachString(f)
+ def foreachInfo(f: Info => Unit): Unit = stmt.foreachInfo(f)
}
/**
@@ -98,21 +100,24 @@ private case class DescribedStmt(descriptions: Seq[Description], stmt: Statement
* @param portDescriptions list of descriptions for the module's ports
* @param mod the encapsulated module
*/
-private case class DescribedMod(descriptions: Seq[Description],
+private case class DescribedMod(
+ descriptions: Seq[Description],
portDescriptions: Map[String, Seq[Description]],
- mod: DefModule) extends DefModule with HasDescription {
+ mod: DefModule)
+ extends DefModule
+ with HasDescription {
val info = mod.info
val name = mod.name
val ports = mod.ports
override def serialize: String = s"${descriptions.map(_.serialize).mkString("\n")}\n${mod.serialize}"
- def mapStmt(f: Statement => Statement): DefModule = this.copy(mod = mod.mapStmt(f))
- def mapPort(f: Port => Port): DefModule = this.copy(mod = mod.mapPort(f))
- def mapString(f: String => String): DefModule = this.copy(mod = mod.mapString(f))
- def mapInfo(f: Info => Info): DefModule = this.copy(mod = mod.mapInfo(f))
- def foreachStmt(f: Statement => Unit): Unit = mod.foreachStmt(f)
- def foreachPort(f: Port => Unit): Unit = mod.foreachPort(f)
- def foreachString(f: String => Unit): Unit = mod.foreachString(f)
- def foreachInfo(f: Info => Unit): Unit = mod.foreachInfo(f)
+ def mapStmt(f: Statement => Statement): DefModule = this.copy(mod = mod.mapStmt(f))
+ def mapPort(f: Port => Port): DefModule = this.copy(mod = mod.mapPort(f))
+ def mapString(f: String => String): DefModule = this.copy(mod = mod.mapString(f))
+ def mapInfo(f: Info => Info): DefModule = this.copy(mod = mod.mapInfo(f))
+ def foreachStmt(f: Statement => Unit): Unit = mod.foreachStmt(f)
+ def foreachPort(f: Port => Unit): Unit = mod.foreachPort(f)
+ def foreachString(f: String => Unit): Unit = mod.foreachString(f)
+ def foreachInfo(f: Info => Unit): Unit = mod.foreachInfo(f)
}
/** Wraps modules or statements with their respective described nodes. Descriptions come from [[DescriptionAnnotation]].
@@ -125,17 +130,19 @@ private case class DescribedMod(descriptions: Seq[Description],
class AddDescriptionNodes extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.PropagatePresetAnnotations],
- Dependency[firrtl.transforms.InlineCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename],
- Dependency(firrtl.passes.VerilogPrep) )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.PropagatePresetAnnotations],
+ Dependency[firrtl.transforms.InlineCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename],
+ Dependency(firrtl.passes.VerilogPrep)
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -149,18 +156,22 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration {
case d: IsDeclaration => Some(d.name)
case _ => None
}
- val descs = sname.flatMap({ case name =>
- compMap.get(name)
+ val descs = sname.flatMap({
+ case name =>
+ compMap.get(name)
})
(descs, s) match {
case (Some(d), DescribedStmt(prevDescs, ss)) => DescribedStmt(prevDescs ++ d, ss)
- case (Some(d), ss) => DescribedStmt(d, ss)
- case (None, _) => s
+ case (Some(d), ss) => DescribedStmt(d, ss)
+ case (None, _) => s
}
}
- def onModule(modMap: Map[String, Seq[Description]], compMaps: Map[String, Map[String, Seq[Description]]])
- (mod: DefModule): DefModule = {
+ def onModule(
+ modMap: Map[String, Seq[Description]],
+ compMaps: Map[String, Map[String, Seq[Description]]]
+ )(mod: DefModule
+ ): DefModule = {
val compMap = compMaps.getOrElse(mod.name, Map())
val newMod = mod.mapStmt(onStmt(compMap))
val portDesc = mod.ports.collect {
@@ -210,14 +221,18 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration {
rest ++ doc ++ attr
}
- def collectMaps(annos: Seq[Annotation]): (Map[String, Seq[Description]], Map[String, Map[String, Seq[Description]]]) = {
+ def collectMaps(
+ annos: Seq[Annotation]
+ ): (Map[String, Seq[Description]], Map[String, Map[String, Seq[Description]]]) = {
val modList = annos.collect {
case DocStringAnnotation(ModuleTarget(_, m), desc) => (m, DocString(StringLit.unescape(desc)))
case AttributeAnnotation(ModuleTarget(_, m), desc) => (m, Attribute(StringLit.unescape(desc)))
}
// map field 1 (module name) -> field 2 (a list of Descriptions)
- val modMap = modList.groupBy(_._1).mapValues(_.map(_._2))
+ val modMap = modList
+ .groupBy(_._1)
+ .mapValues(_.map(_._2))
// and then merge like descriptions (e.g. multiple docstrings into one big docstring)
.mapValues(mergeDescriptions)
@@ -229,11 +244,16 @@ class AddDescriptionNodes extends Transform with DependencyAPIMigration {
}
// map field 1 (name) -> a map that we build
- val compMap = compList.groupBy(_._1).mapValues(
- // map field 2 (component name) -> field 3 (a list of Descriptions)
- _.groupBy(_._2).mapValues(_.map(_._3))
- // and then merge like descriptions (e.g. multiple docstrings into one big docstring)
- .mapValues(mergeDescriptions).toMap)
+ val compMap = compList
+ .groupBy(_._1)
+ .mapValues(
+ // map field 2 (component name) -> field 3 (a list of Descriptions)
+ _.groupBy(_._2)
+ .mapValues(_.map(_._3))
+ // and then merge like descriptions (e.g. multiple docstrings into one big docstring)
+ .mapValues(mergeDescriptions)
+ .toMap
+ )
(modMap.toMap, compMap.toMap)
}
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index db4853a2..ec09cace 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -13,7 +13,7 @@ import firrtl.annotations._
import firrtl.ir.Circuit
import firrtl.Utils.throwInternalError
import firrtl.annotations.transforms.{EliminateTargetPaths, ResolvePaths}
-import firrtl.options.{DependencyAPI, Dependency, StageUtils, TransformLike}
+import firrtl.options.{Dependency, DependencyAPI, StageUtils, TransformLike}
import firrtl.stage.Forms
/** Container of all annotations for a Firrtl compiler */
@@ -34,19 +34,22 @@ object AnnotationSeq {
* Generally only a return value from [[Transform]]s
*/
case class CircuitState(
- circuit: Circuit,
- form: CircuitForm,
- annotations: AnnotationSeq,
- renames: Option[RenameMap]) {
+ circuit: Circuit,
+ form: CircuitForm,
+ annotations: AnnotationSeq,
+ renames: Option[RenameMap]) {
/** Helper for getting just an emitted circuit */
def emittedCircuitOption: Option[EmittedCircuit] =
- emittedComponents collectFirst { case x: EmittedCircuit => x }
+ emittedComponents.collectFirst { case x: EmittedCircuit => x }
+
/** Helper for getting an [[EmittedCircuit]] when it is known to exist */
def getEmittedCircuit: EmittedCircuit = emittedCircuitOption match {
case Some(emittedCircuit) => emittedCircuit
case None =>
- throw new FirrtlInternalException(s"No EmittedCircuit found! Did you delete any annotations?\n$deletedAnnotations")
+ throw new FirrtlInternalException(
+ s"No EmittedCircuit found! Did you delete any annotations?\n$deletedAnnotations"
+ )
}
/** Helper function for extracting emitted components from annotations */
@@ -64,7 +67,7 @@ case class CircuitState(
def resolvePaths(targets: Seq[CompleteTarget]): CircuitState = targets match {
case Nil => this
case _ =>
- val newCS = new EliminateTargetPaths().runTransform(this.copy(annotations = ResolvePaths(targets) +: annotations ))
+ val newCS = new EliminateTargetPaths().runTransform(this.copy(annotations = ResolvePaths(targets) +: annotations))
newCS.copy(form = form)
}
@@ -73,8 +76,8 @@ case class CircuitState(
* @return
*/
def resolvePathsOf(annoClasses: Class[_]*): CircuitState = {
- val targets = getAnnotationsOf(annoClasses:_*).flatMap(_.getTargets)
- if(targets.nonEmpty) resolvePaths(targets.flatMap{_.getComplete}) else this
+ val targets = getAnnotationsOf(annoClasses: _*).flatMap(_.getTargets)
+ if (targets.nonEmpty) resolvePaths(targets.flatMap { _.getComplete }) else this
}
/** Returns all annotations which are of a class in annoClasses
@@ -105,7 +108,8 @@ object CircuitState {
*/
@deprecated(
"Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+)
sealed abstract class CircuitForm(private val value: Int) extends Ordered[CircuitForm] {
// Note that value is used only to allow comparisons
def compare(that: CircuitForm): Int = this.value - that.value
@@ -125,7 +129,8 @@ sealed abstract class CircuitForm(private val value: Int) extends Ordered[Circui
*/
@deprecated(
"Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+)
final case object ChirrtlForm extends CircuitForm(value = 3) {
val outputSuffix: String = ".fir"
}
@@ -139,7 +144,8 @@ final case object ChirrtlForm extends CircuitForm(value = 3) {
*/
@deprecated(
"Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+)
final case object HighForm extends CircuitForm(2) {
val outputSuffix: String = ".hi.fir"
}
@@ -153,7 +159,8 @@ final case object HighForm extends CircuitForm(2) {
*/
@deprecated(
"Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+)
final case object MidForm extends CircuitForm(1) {
val outputSuffix: String = ".mid.fir"
}
@@ -166,7 +173,8 @@ final case object MidForm extends CircuitForm(1) {
*/
@deprecated(
"Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+)
final case object LowForm extends CircuitForm(0) {
val outputSuffix: String = ".lo.fir"
}
@@ -184,7 +192,8 @@ final case object LowForm extends CircuitForm(0) {
*/
@deprecated(
"Mix-in the DependencyAPIMigration trait into your Transform and specify its Dependency API dependencies. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+)
final case object UnknownForm extends CircuitForm(-1) {
override def compare(that: CircuitForm): Int = { sys.error("Illegal to compare UnknownForm"); 0 }
@@ -212,12 +221,15 @@ private[firrtl] object Transform {
logger.info(s"Form: ${after.form}")
logger.trace(s"Annotations:")
logger.trace {
- JsonProtocol.serializeTry(remappedAnnotations).recoverWith {
- case NonFatal(e) =>
- val msg = s"Exception thrown during Annotation serialization:\n " +
- e.toString.replaceAll("\n", "\n ")
- Try(msg)
- }.get
+ JsonProtocol
+ .serializeTry(remappedAnnotations)
+ .recoverWith {
+ case NonFatal(e) =>
+ val msg = s"Exception thrown during Annotation serialization:\n " +
+ e.toString.replaceAll("\n", "\n ")
+ Try(msg)
+ }
+ .get
}
logger.trace(s"Circuit:\n${after.circuit.serialize}")
@@ -234,17 +246,18 @@ private[firrtl] object Transform {
* @return the updated annotations
*/
def propagateAnnotations(
- name: String,
- logger: Logger,
- inAnno: AnnotationSeq,
- resAnno: AnnotationSeq,
- renameOpt: Option[RenameMap]): AnnotationSeq = {
+ name: String,
+ logger: Logger,
+ inAnno: AnnotationSeq,
+ resAnno: AnnotationSeq,
+ renameOpt: Option[RenameMap]
+ ): AnnotationSeq = {
val newAnnotations = {
val inSet = mutable.LinkedHashSet() ++ inAnno
val resSet = mutable.LinkedHashSet() ++ resAnno
val deleted = (inSet -- resSet).map {
case DeletedAnnotation(xFormName, delAnno) => DeletedAnnotation(s"$xFormName+$name", delAnno)
- case anno => DeletedAnnotation(name, anno)
+ case anno => DeletedAnnotation(name, anno)
}
val created = resSet -- inSet
val unchanged = resSet & inSet
@@ -260,7 +273,7 @@ private[firrtl] object Transform {
remappedAnnos.foreach { remapped =>
val set = remapped2original.getOrElseUpdate(remapped, mutable.LinkedHashSet.empty[Annotation])
set += anno
- if(set.size > 1) keysOfNote += remapped
+ if (set.size > 1) keysOfNote += remapped
}
remappedAnnos
}.toSeq
@@ -280,15 +293,11 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform
def name: String = this.getClass.getName
/** The [[firrtl.CircuitForm]] that this transform requires to operate on */
- @deprecated(
- "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3")
def inputForm: CircuitForm
/** The [[firrtl.CircuitForm]] that this transform outputs */
- @deprecated(
- "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3")
def outputForm: CircuitForm
/** Perform the transform, encode renaming with RenameMap, and can
@@ -324,8 +333,9 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform
Dependency[SystemVerilogEmitter] :: Nil
val emitters = inputForm match {
- case C => Dependency[ChirrtlEmitter] :: Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters
- case H => Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters
+ case C =>
+ Dependency[ChirrtlEmitter] :: Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters
+ case H => Dependency[HighFirrtlEmitter] :: Dependency[MiddleFirrtlEmitter] :: lowEmitters
case M => Dependency[MiddleFirrtlEmitter] :: lowEmitters
case L => lowEmitters
case U => Nil
@@ -334,9 +344,9 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform
val selfDep = Dependency.fromTransform(this)
inputForm match {
- case C => (fullCompilerSet ++ emitters - selfDep).toSeq
- case H => (fullCompilerSet -- Forms.Deduped ++ emitters - selfDep).toSeq
- case M => (fullCompilerSet -- Forms.MidForm ++ emitters - selfDep).toSeq
+ case C => (fullCompilerSet ++ emitters - selfDep).toSeq
+ case H => (fullCompilerSet -- Forms.Deduped ++ emitters - selfDep).toSeq
+ case M => (fullCompilerSet -- Forms.MidForm ++ emitters - selfDep).toSeq
case L => (fullCompilerSet -- Forms.LowFormOptimized ++ emitters - selfDep).toSeq
case U => Nil
}
@@ -347,9 +357,9 @@ trait Transform extends TransformLike[CircuitState] with DependencyAPI[Transform
override def invalidates(a: Transform): Boolean = {
(inputForm, outputForm) match {
- case (U, _) | (_, U) => true // invalidate everything
+ case (U, _) | (_, U) => true // invalidate everything
case (i, o) if i >= o => false // invalidate nothing
- case (_, C) => true // invalidate everything
+ case (_, C) => true // invalidate everything
case (_, H) => highOutputInvalidates(Dependency.fromTransform(a))
case (_, M) => midOutputInvalidates(Dependency.fromTransform(a))
case (_, L) => false // invalidate nothing
@@ -386,7 +396,7 @@ abstract class SeqTransform extends Transform with SeqTransformBased {
/*
require(state.form <= inputForm,
s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
- */
+ */
val ret = runTransforms(state)
CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames)
}
@@ -401,7 +411,7 @@ trait ResolvedAnnotationPaths {
val annotationClasses: Traversable[Class[_]]
override def prepare(state: CircuitState): CircuitState = {
- state.resolvePathsOf(annotationClasses.toSeq:_*)
+ state.resolvePathsOf(annotationClasses.toSeq: _*)
}
}
@@ -419,6 +429,7 @@ trait Emitter extends Transform {
@deprecated("This will be removed in 1.4", "FIRRTL 1.3")
object CompilerUtils extends LazyLogging {
+
/** Generates a sequence of [[Transform]]s to lower a Firrtl circuit
*
* @param inputForm [[CircuitForm]] to lower from
@@ -427,7 +438,8 @@ object CompilerUtils extends LazyLogging {
*/
@deprecated(
"Use a TransformManager requesting which transforms you want to run. This will be removed in 1.4.",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+ )
def getLoweringTransforms(inputForm: CircuitForm, outputForm: CircuitForm): Seq[Transform] = {
// If outputForm is equal-to or higher than inputForm, nothing to lower
if (outputForm >= inputForm) {
@@ -437,10 +449,15 @@ object CompilerUtils extends LazyLogging {
case ChirrtlForm =>
Seq(new ChirrtlToHighFirrtl) ++ getLoweringTransforms(HighForm, outputForm)
case HighForm =>
- Seq(new IRToWorkingIR, new ResolveAndCheck, new firrtl.transforms.DedupModules, new HighFirrtlToMiddleFirrtl) ++
+ Seq(
+ new IRToWorkingIR,
+ new ResolveAndCheck,
+ new firrtl.transforms.DedupModules,
+ new HighFirrtlToMiddleFirrtl
+ ) ++
getLoweringTransforms(MidForm, outputForm)
- case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm)
- case LowForm => throwInternalError("getLoweringTransforms - LowForm") // should be caught by if above
+ case MidForm => Seq(new MiddleFirrtlToLowFirrtl) ++ getLoweringTransforms(LowForm, outputForm)
+ case LowForm => throwInternalError("getLoweringTransforms - LowForm") // should be caught by if above
case UnknownForm => throwInternalError("getLoweringTransforms - UnknownForm") // should be caught by if above
}
}
@@ -479,28 +496,32 @@ object CompilerUtils extends LazyLogging {
*/
@deprecated(
"Use a TransformManager requesting which transforms you want to run. This will be removed in 1.4.",
- "FIRRTL 1.3")
+ "FIRRTL 1.3"
+ )
def mergeTransforms(lowering: Seq[Transform], custom: Seq[Transform]): Seq[Transform] = {
- custom
- .sortWith{
- case (a, b) => (a, b) match {
+ custom.sortWith {
+ case (a, b) =>
+ (a, b) match {
case (_: Emitter, _: Emitter) => false
- case (_, _: Emitter) => true
- case _ => false }}
- .foldLeft(lowering) { case (transforms, xform) =>
- val index = transforms lastIndexWhere (_.outputForm == xform.inputForm)
- assert(index >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front
- s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform")
- val (front, back) = transforms.splitAt(index + 1) // +1 because we want to be AFTER index
- front ++ List(xform) ++ getLoweringTransforms(xform.outputForm, xform.inputForm) ++ back
+ case (_, _: Emitter) => true
+ case _ => false
+ }
}
+ .foldLeft(lowering) {
+ case (transforms, xform) =>
+ val index = transforms.lastIndexWhere(_.outputForm == xform.inputForm)
+ assert(
+ index >= 0 || xform.inputForm == ChirrtlForm, // If ChirrtlForm just put at front
+ s"No transform in $lowering has outputForm ${xform.inputForm} as required by $xform"
+ )
+ val (front, back) = transforms.splitAt(index + 1) // +1 because we want to be AFTER index
+ front ++ List(xform) ++ getLoweringTransforms(xform.outputForm, xform.inputForm) ++ back
+ }
}
}
-@deprecated(
- "Migrate to firrtl.stage.transforms.Compiler. This will be removed in 1.4.",
- "FIRRTL 1.3")
+@deprecated("Migrate to firrtl.stage.transforms.Compiler. This will be removed in 1.4.", "FIRRTL 1.3")
trait Compiler extends Transform with DependencyAPIMigration {
def emitter: Emitter
@@ -511,15 +532,17 @@ trait Compiler extends Transform with DependencyAPIMigration {
def transforms: Seq[Transform]
final override def execute(state: CircuitState): CircuitState =
- new stage.transforms.Compiler (
+ new stage.transforms.Compiler(
targets = (transforms :+ emitter).map(Dependency.fromTransform),
currentState = prerequisites,
knownObjects = (transforms :+ emitter).toSet
).execute(state)
- require(transforms.size >= 1,
- s"Compiler transforms for '${this.getClass.getName}' must have at least ONE Transform! " +
- "Use IdentityTransform if you need an identity/no-op transform.")
+ require(
+ transforms.size >= 1,
+ s"Compiler transforms for '${this.getClass.getName}' must have at least ONE Transform! " +
+ "Use IdentityTransform if you need an identity/no-op transform."
+ )
/** Perform compilation
*
@@ -531,10 +554,9 @@ trait Compiler extends Transform with DependencyAPIMigration {
@deprecated(
"Migrate to '(new FirrtlStage).execute(args: Array[String], annotations: AnnotationSeq)'." +
"This will be removed in 1.4.",
- "FIRRTL 1.0")
- def compile(state: CircuitState,
- writer: Writer,
- customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
+ "FIRRTL 1.0"
+ )
+ def compile(state: CircuitState, writer: Writer, customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
val finalState = compileAndEmit(state, customTransforms)
writer.write(finalState.getEmittedCircuit.value)
finalState
@@ -555,9 +577,9 @@ trait Compiler extends Transform with DependencyAPIMigration {
@deprecated(
"Migrate to '(new FirrtlStage).execute(args: Array[String], annotations: AnnotationSeq)'." +
"This will be removed in 1.4.",
- "FIRRTL 1.3.3")
- def compileAndEmit(state: CircuitState,
- customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
+ "FIRRTL 1.3.3"
+ )
+ def compileAndEmit(state: CircuitState, customTransforms: Seq[Transform] = Seq.empty): CircuitState = {
val emitAnno = EmitCircuitAnnotation(emitter.getClass)
compile(state.copy(annotations = emitAnno +: state.annotations), emitter +: customTransforms)
}
@@ -574,9 +596,10 @@ trait Compiler extends Transform with DependencyAPIMigration {
@deprecated(
"Migrate to '(new FirrtlStage).execute(args: Array[String], annotations: AnnotationSeq)'." +
"This will be removed in 1.4.",
- "FIRRTL 1.3.3")
+ "FIRRTL 1.3.3"
+ )
def compile(state: CircuitState, customTransforms: Seq[Transform]): CircuitState = {
- val transformManager = new stage.transforms.Compiler (
+ val transformManager = new stage.transforms.Compiler(
targets = (emitter +: customTransforms ++: transforms).map(Dependency.fromTransform),
currentState = prerequisites,
knownObjects = (transforms :+ emitter).toSet
diff --git a/src/main/scala/firrtl/DependencyAPIMigration.scala b/src/main/scala/firrtl/DependencyAPIMigration.scala
index 6a5ff642..dc5957f2 100644
--- a/src/main/scala/firrtl/DependencyAPIMigration.scala
+++ b/src/main/scala/firrtl/DependencyAPIMigration.scala
@@ -17,14 +17,10 @@ import firrtl.stage.TransformManager.TransformDependency
*/
trait DependencyAPIMigration { this: Transform =>
- @deprecated(
- "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3")
final override def inputForm: CircuitForm = UnknownForm
- @deprecated(
- "Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre",
- "FIRRTL 1.3")
+ @deprecated("Use Dependency API methods for equivalent functionality. See: https://bit.ly/2Voppre", "FIRRTL 1.3")
final override def outputForm: CircuitForm = UnknownForm
override def prerequisites: Seq[TransformDependency] = Seq.empty
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala
index 2050b235..28eb2d6a 100644
--- a/src/main/scala/firrtl/Driver.scala
+++ b/src/main/scala/firrtl/Driver.scala
@@ -13,7 +13,6 @@ import firrtl.stage.phases.DriverCompatibility
import firrtl.options.{Dependency, Phase, PhaseManager, StageUtils, Viewer}
import firrtl.options.phases.DeletedWrapper
-
/**
* The driver provides methods to access the firrtl compiler.
* Invoke the compiler with either a FirrtlExecutionOption
@@ -37,6 +36,7 @@ import firrtl.options.phases.DeletedWrapper
*/
@deprecated("Use firrtl.stage.FirrtlStage", "1.2")
object Driver {
+
/** Print a warning message
*
* @param message error message
@@ -71,7 +71,7 @@ object Driver {
* @return Annotations read from files
*/
def getAnnotations(
- optionsManager: ExecutionOptionsManager with HasFirrtlOptions
+ optionsManager: ExecutionOptionsManager with HasFirrtlOptions
): Seq[Annotation] = {
val firrtlConfig = optionsManager.firrtlOptions
@@ -92,11 +92,11 @@ object Driver {
// Warnings to get people to change to drop old API
if (firrtlConfig.annotationFileNameOverride.nonEmpty) {
val msg = "annotationFileNameOverride has been removed, file will be ignored! " +
- "Use annotationFileNames"
+ "Use annotationFileNames"
dramaticError(msg)
} else if (usingImplicitAnnoFile) {
val msg = "Implicit .anno file from top-name has been removed, file will be ignored!\n" +
- (" "*9) + "Use explicit -faf option or annotationFileNames"
+ (" " * 9) + "Use explicit -faf option or annotationFileNames"
dramaticError(msg)
}
@@ -126,7 +126,7 @@ object Driver {
private def getFileExtension(filename: String): FileExtension =
filename.drop(filename.lastIndexOf('.')) match {
case ".pb" => ProtoBufFile
- case _ => FirrtlFile // Default to FIRRTL File
+ case _ => FirrtlFile // Default to FIRRTL File
}
// Useful for handling erros in the options
@@ -143,7 +143,8 @@ object Driver {
val circuitSources = Map(
"firrtlSource" -> firrtlConfig.firrtlSource.isDefined,
"firrtlCircuit" -> firrtlConfig.firrtlCircuit.isDefined,
- "inputFileNameOverride" -> firrtlConfig.inputFileNameOverride.nonEmpty)
+ "inputFileNameOverride" -> firrtlConfig.inputFileNameOverride.nonEmpty
+ )
if (circuitSources.values.count(x => x) > 1) {
val msg = circuitSources.collect { case (s, true) => s }.mkString(" and ") +
" are set, only 1 can be set at a time!"
@@ -157,8 +158,9 @@ object Driver {
}
if (
optionsManager.topName.isEmpty &&
- firrtlConfig.inputFileNameOverride.nonEmpty &&
- firrtlConfig.outputFileNameOverride.isEmpty) {
+ firrtlConfig.inputFileNameOverride.nonEmpty &&
+ firrtlConfig.outputFileNameOverride.isEmpty
+ ) {
val message = "inputFileName set but neither top-name or output-file-override is set"
throw new OptionsException(message)
}
@@ -167,10 +169,9 @@ object Driver {
// TODO What does InfoMode mean to ProtoBuf?
getFileExtension(inputFileName) match {
case ProtoBufFile => proto.FromProto.fromFile(inputFileName)
- case FirrtlFile => Parser.parseFile(inputFileName, firrtlConfig.infoMode)
+ case FirrtlFile => Parser.parseFile(inputFileName, firrtlConfig.infoMode)
}
- }
- catch {
+ } catch {
case _: FileNotFoundException =>
val message = s"Input file $inputFileName not found"
throw new OptionsException(message)
@@ -195,20 +196,23 @@ object Driver {
val phases: Seq[Phase] = {
import DriverCompatibility._
new PhaseManager(
- List( Dependency[AddImplicitFirrtlFile],
- Dependency[AddImplicitAnnotationFile],
- Dependency[AddImplicitOutputFile],
- Dependency[AddImplicitEmitter],
- Dependency[FirrtlStage] ))
- .transformOrder
+ List(
+ Dependency[AddImplicitFirrtlFile],
+ Dependency[AddImplicitAnnotationFile],
+ Dependency[AddImplicitOutputFile],
+ Dependency[AddImplicitEmitter],
+ Dependency[FirrtlStage]
+ )
+ ).transformOrder
.map(DeletedWrapper(_))
}
- val annosx = try {
- phases.foldLeft(annos)( (a, p) => p.transform(a) )
- } catch {
- case e: firrtl.options.OptionsException => return FirrtlExecutionFailure(e.message)
- }
+ val annosx =
+ try {
+ phases.foldLeft(annos)((a, p) => p.transform(a))
+ } catch {
+ case e: firrtl.options.OptionsException => return FirrtlExecutionFailure(e.message)
+ }
Viewer[FirrtlExecutionResult].view(annosx)
}
@@ -223,7 +227,7 @@ object Driver {
def execute(args: Array[String]): FirrtlExecutionResult = {
val optionsManager = new ExecutionOptionsManager("firrtl") with HasFirrtlOptions
- if(optionsManager.parse(args)) {
+ if (optionsManager.parse(args)) {
execute(optionsManager) match {
case success: FirrtlExecutionSuccess =>
success
@@ -233,8 +237,7 @@ object Driver {
case result =>
throwInternalError(s"Error: Unknown Firrtl Execution result $result")
}
- }
- else {
+ } else {
FirrtlExecutionFailure("Could not parser command line options")
}
}
diff --git a/src/main/scala/firrtl/EmissionOption.scala b/src/main/scala/firrtl/EmissionOption.scala
index 91db1f53..d097e14a 100644
--- a/src/main/scala/firrtl/EmissionOption.scala
+++ b/src/main/scala/firrtl/EmissionOption.scala
@@ -2,8 +2,8 @@
package firrtl
-/**
- * Base type for emission customization options
+/**
+ * Base type for emission customization options
* NOTE: all the following traits must be mixed with SingleTargetAnnotation[T <: Named]
* in order to be taken into account in the Emitter
*/
@@ -24,40 +24,37 @@ case object MemoryEmissionOptionDefault extends MemoryEmissionOption
/** Emission customization options for registers */
trait RegisterEmissionOption extends EmissionOption {
+
/** when true the reset init value will be used to emit a bitstream preset */
- def useInitAsPreset : Boolean = false
-
+ def useInitAsPreset: Boolean = false
+
/** when true the initial randomization is disabled for this register */
- def disableRandomization : Boolean = false
+ def disableRandomization: Boolean = false
}
/** default Emitter behavior for registers */
-case object RegisterEmissionOptionDefault extends RegisterEmissionOption
-
+case object RegisterEmissionOptionDefault extends RegisterEmissionOption
/** Emission customization options for IO ports */
-trait PortEmissionOption extends EmissionOption
+trait PortEmissionOption extends EmissionOption
/** default Emitter behavior for IO ports */
-case object PortEmissionOptionDefault extends PortEmissionOption
-
+case object PortEmissionOptionDefault extends PortEmissionOption
/** Emission customization options for wires */
-trait WireEmissionOption extends EmissionOption
+trait WireEmissionOption extends EmissionOption
/** default Emitter behavior for wires */
-case object WireEmissionOptionDefault extends WireEmissionOption
-
+case object WireEmissionOptionDefault extends WireEmissionOption
/** Emission customization options for nodes */
-trait NodeEmissionOption extends EmissionOption
+trait NodeEmissionOption extends EmissionOption
/** default Emitter behavior for nodes */
-case object NodeEmissionOptionDefault extends NodeEmissionOption
-
+case object NodeEmissionOptionDefault extends NodeEmissionOption
/** Emission customization options for connect */
trait ConnectEmissionOption extends EmissionOption
/** default Emitter behavior for connect */
-case object ConnectEmissionOptionDefault extends ConnectEmissionOption
+case object ConnectEmissionOptionDefault extends ConnectEmissionOption
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index ae9a7dad..843c76a4 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -37,28 +37,38 @@ object EmitCircuitAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[String](
longOption = "emit-circuit",
- toAnnotationSeq = (a: String) => a match {
- case "chirrtl" => Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter),
- EmitCircuitAnnotation(classOf[ChirrtlEmitter]))
- case "high" => Seq(RunFirrtlTransformAnnotation(new HighFirrtlEmitter),
- EmitCircuitAnnotation(classOf[HighFirrtlEmitter]))
- case "middle" => Seq(RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter),
- EmitCircuitAnnotation(classOf[MiddleFirrtlEmitter]))
- case "low" => Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter),
- EmitCircuitAnnotation(classOf[LowFirrtlEmitter]))
- case "verilog" | "mverilog" => Seq(RunFirrtlTransformAnnotation(new VerilogEmitter),
- EmitCircuitAnnotation(classOf[VerilogEmitter]))
- case "sverilog" => Seq(RunFirrtlTransformAnnotation(new SystemVerilogEmitter),
- EmitCircuitAnnotation(classOf[SystemVerilogEmitter]))
- case "experimental-btor2" => Seq(RunFirrtlTransformAnnotation(new Btor2Emitter),
- EmitCircuitAnnotation(classOf[Btor2Emitter]))
- case "experimental-smt2" => Seq(RunFirrtlTransformAnnotation(new SMTLibEmitter),
- EmitCircuitAnnotation(classOf[SMTLibEmitter]))
- case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") },
+ toAnnotationSeq = (a: String) =>
+ a match {
+ case "chirrtl" =>
+ Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter), EmitCircuitAnnotation(classOf[ChirrtlEmitter]))
+ case "high" =>
+ Seq(RunFirrtlTransformAnnotation(new HighFirrtlEmitter), EmitCircuitAnnotation(classOf[HighFirrtlEmitter]))
+ case "middle" =>
+ Seq(
+ RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter),
+ EmitCircuitAnnotation(classOf[MiddleFirrtlEmitter])
+ )
+ case "low" =>
+ Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitCircuitAnnotation(classOf[LowFirrtlEmitter]))
+ case "verilog" | "mverilog" =>
+ Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitCircuitAnnotation(classOf[VerilogEmitter]))
+ case "sverilog" =>
+ Seq(
+ RunFirrtlTransformAnnotation(new SystemVerilogEmitter),
+ EmitCircuitAnnotation(classOf[SystemVerilogEmitter])
+ )
+ case "experimental-btor2" =>
+ Seq(RunFirrtlTransformAnnotation(new Btor2Emitter), EmitCircuitAnnotation(classOf[Btor2Emitter]))
+ case "experimental-smt2" =>
+ Seq(RunFirrtlTransformAnnotation(new SMTLibEmitter), EmitCircuitAnnotation(classOf[SMTLibEmitter]))
+ case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)")
+ },
helpText = "Run the specified circuit emitter (all modules in one file)",
shortOption = Some("E"),
// the experimental options are intentionally excluded from the help message
- helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") ) )
+ helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>")
+ )
+ )
}
@@ -67,30 +77,43 @@ object EmitAllModulesAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[String](
longOption = "emit-modules",
- toAnnotationSeq = (a: String) => a match {
- case "chirrtl" => Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter),
- EmitAllModulesAnnotation(classOf[ChirrtlEmitter]))
- case "high" => Seq(RunFirrtlTransformAnnotation(new HighFirrtlEmitter),
- EmitAllModulesAnnotation(classOf[HighFirrtlEmitter]))
- case "middle" => Seq(RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter),
- EmitAllModulesAnnotation(classOf[MiddleFirrtlEmitter]))
- case "low" => Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter),
- EmitAllModulesAnnotation(classOf[LowFirrtlEmitter]))
- case "verilog" | "mverilog" => Seq(RunFirrtlTransformAnnotation(new VerilogEmitter),
- EmitAllModulesAnnotation(classOf[VerilogEmitter]))
- case "sverilog" => Seq(RunFirrtlTransformAnnotation(new SystemVerilogEmitter),
- EmitAllModulesAnnotation(classOf[SystemVerilogEmitter]))
- case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)") },
+ toAnnotationSeq = (a: String) =>
+ a match {
+ case "chirrtl" =>
+ Seq(RunFirrtlTransformAnnotation(new ChirrtlEmitter), EmitAllModulesAnnotation(classOf[ChirrtlEmitter]))
+ case "high" =>
+ Seq(
+ RunFirrtlTransformAnnotation(new HighFirrtlEmitter),
+ EmitAllModulesAnnotation(classOf[HighFirrtlEmitter])
+ )
+ case "middle" =>
+ Seq(
+ RunFirrtlTransformAnnotation(new MiddleFirrtlEmitter),
+ EmitAllModulesAnnotation(classOf[MiddleFirrtlEmitter])
+ )
+ case "low" =>
+ Seq(RunFirrtlTransformAnnotation(new LowFirrtlEmitter), EmitAllModulesAnnotation(classOf[LowFirrtlEmitter]))
+ case "verilog" | "mverilog" =>
+ Seq(RunFirrtlTransformAnnotation(new VerilogEmitter), EmitAllModulesAnnotation(classOf[VerilogEmitter]))
+ case "sverilog" =>
+ Seq(
+ RunFirrtlTransformAnnotation(new SystemVerilogEmitter),
+ EmitAllModulesAnnotation(classOf[SystemVerilogEmitter])
+ )
+ case _ => throw new PhaseException(s"Unknown emitter '$a'! (Did you misspell it?)")
+ },
helpText = "Run the specified module emitter (one file per module)",
shortOption = Some("e"),
- helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>") ) )
+ helpValueName = Some("<chirrtl|high|middle|low|verilog|mverilog|sverilog>")
+ )
+ )
}
// ***** Annotations for results of emission *****
sealed abstract class EmittedComponent {
- def name: String
- def value: String
+ def name: String
+ def value: String
def outputSuffix: String
}
sealed abstract class EmittedCircuit extends EmittedComponent
@@ -147,7 +170,7 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em
// Use list instead of set to maintain order
val modules = mutable.ArrayBuffer.empty[DefModule]
def onStmt(stmt: Statement): Unit = stmt match {
- case DefInstance(_, _, name, _) => modules += map(name)
+ case DefInstance(_, _, name, _) => modules += map(name)
case WDefInstance(_, _, name, _) => modules += map(name)
case _: WDefInstanceConnector => throwInternalError(s"unrecognized statement: $stmt")
case other => other.foreach(onStmt)
@@ -157,24 +180,28 @@ sealed abstract class FirrtlEmitter(form: CircuitForm) extends Transform with Em
}
val modMap = circuit.modules.map(m => m.name -> m).toMap
// Turn each module into it's own circuit with it as the top and all instantied modules as ExtModules
- circuit.modules collect { case m: Module =>
- val instModules = collectInstantiatedModules(m, modMap)
- val extModules = instModules map {
- case Module(info, name, ports, _) => ExtModule(info, name, ports, name, Seq.empty)
- case ext: ExtModule => ext
- }
- val newCircuit = Circuit(m.info, extModules :+ m, m.name)
- EmittedFirrtlModule(m.name, newCircuit.serialize, outputSuffix)
+ circuit.modules.collect {
+ case m: Module =>
+ val instModules = collectInstantiatedModules(m, modMap)
+ val extModules = instModules.map {
+ case Module(info, name, ports, _) => ExtModule(info, name, ports, name, Seq.empty)
+ case ext: ExtModule => ext
+ }
+ val newCircuit = Circuit(m.info, extModules :+ m, m.name)
+ EmittedFirrtlModule(m.name, newCircuit.serialize, outputSuffix)
}
}
override def execute(state: CircuitState): CircuitState = {
val newAnnos = state.annotations.flatMap {
case EmitCircuitAnnotation(a) if this.getClass == a =>
- Seq(EmittedFirrtlCircuitAnnotation(
- EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, outputSuffix)))
+ Seq(
+ EmittedFirrtlCircuitAnnotation(
+ EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, outputSuffix)
+ )
+ )
case EmitAllModulesAnnotation(a) if this.getClass == a =>
- emitAllModules(state.circuit) map (EmittedFirrtlModuleAnnotation(_))
+ emitAllModules(state.circuit).map(EmittedFirrtlModuleAnnotation(_))
case _ => Seq()
}
state.copy(annotations = newAnnos ++ state.annotations)
@@ -195,12 +222,12 @@ case class VRandom(width: BigInt) extends Expression {
def nWords = (width + 31) / 32
def realWidth = nWords * 32
override def serialize: String = "RANDOM"
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = ()
}
class VerilogEmitter extends SeqTransform with Emitter {
@@ -221,14 +248,16 @@ class VerilogEmitter extends SeqTransform with Emitter {
else if (e2 == we(one)) e1.e1
else DoPrim(And, Seq(e1.e1, e2.e1), Nil, UIntType(IntWidth(1)))
}
- def wref(n: String, t: Type) = WRef(n, t, ExpKind, UnknownFlow)
+ def wref(n: String, t: Type) = WRef(n, t, ExpKind, UnknownFlow)
def remove_root(ex: Expression): Expression = ex match {
- case ex: WSubField => ex.expr match {
- case (e: WSubField) => remove_root(e)
- case (_: WRef) => WRef(ex.name, ex.tpe, InstanceKind, UnknownFlow)
- }
+ case ex: WSubField =>
+ ex.expr match {
+ case (e: WSubField) => remove_root(e)
+ case (_: WRef) => WRef(ex.name, ex.tpe, InstanceKind, UnknownFlow)
+ }
case _ => throwInternalError(s"shouldn't be here: remove_root($ex)")
}
+
/** Turn Params into Verilog Strings */
def stringify(param: Param): String = param match {
case IntParam(name, value) =>
@@ -237,11 +266,11 @@ class VerilogEmitter extends SeqTransform with Emitter {
s"$value"
} else {
val blen = value.bitLength
- if (value > 0) s"$blen'd$value" else s"-${blen+1}'sd${value.abs}"
+ if (value > 0) s"$blen'd$value" else s"-${blen + 1}'sd${value.abs}"
}
s".$name($lit)"
- case DoubleParam(name, value) => s".$name($value)"
- case StringParam(name, value) => s".${name}(${value.verilogEscape})"
+ case DoubleParam(name, value) => s".$name($value)"
+ case StringParam(name, value) => s".${name}(${value.verilogEscape})"
case RawStringParam(name, value) => s".$name($value)"
}
def stringify(tpe: GroundType): String = tpe match {
@@ -249,16 +278,16 @@ class VerilogEmitter extends SeqTransform with Emitter {
val wx = bitWidth(tpe) - 1
if (wx > 0) s"[$wx:0]" else ""
case ClockType | AsyncResetType => ""
- case _ => throwInternalError(s"trying to write unsupported type in the Verilog Emitter: $tpe")
+ case _ => throwInternalError(s"trying to write unsupported type in the Verilog Emitter: $tpe")
}
def emit(x: Any)(implicit w: Writer): Unit = { emit(x, 0) }
def emit(x: Any, top: Int)(implicit w: Writer): Unit = {
def cast(e: Expression): Any = e.tpe match {
case (t: UIntType) => e
- case (t: SIntType) => Seq("$signed(",e,")")
- case ClockType => e
+ case (t: SIntType) => Seq("$signed(", e, ")")
+ case ClockType => e
case AnalogType(_) => e
- case _ => throwInternalError(s"unrecognized cast: $e")
+ case _ => throwInternalError(s"unrecognized cast: $e")
}
x match {
case (e: DoPrim) => emit(op_stream(e), top + 1)
@@ -269,186 +298,190 @@ class VerilogEmitter extends SeqTransform with Emitter {
if (e.tpe == AsyncResetType) {
throw EmitterException("Cannot emit async reset muxes directly")
}
- emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1)
+ emit(Seq(e.cond, " ? ", cast(e.tval), " : ", cast(e.fval)), top + 1)
}
- case (e: ValidIf) => emit(Seq(cast(e.value)),top + 1)
- case (e: WRef) => w write e.serialize
- case (e: WSubField) => w write LowerTypes.loweredName(e)
- case (e: WSubAccess) => w write s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]"
- case (e: WSubIndex) => w write e.serialize
- case (e: Literal) => v_print(e)
- case (e: VRandom) => w write s"{${e.nWords}{`RANDOM}}"
- case (t: GroundType) => w write stringify(t)
+ case (e: ValidIf) => emit(Seq(cast(e.value)), top + 1)
+ case (e: WRef) => w.write(e.serialize)
+ case (e: WSubField) => w.write(LowerTypes.loweredName(e))
+ case (e: WSubAccess) => w.write(s"${LowerTypes.loweredName(e.expr)}[${LowerTypes.loweredName(e.index)}]")
+ case (e: WSubIndex) => w.write(e.serialize)
+ case (e: Literal) => v_print(e)
+ case (e: VRandom) => w.write(s"{${e.nWords}{`RANDOM}}")
+ case (t: GroundType) => w.write(stringify(t))
case (t: VectorType) =>
emit(t.tpe, top + 1)
- w write s"[${t.size - 1}:0]"
- case (s: String) => w write s
- case (i: Int) => w write i.toString
- case (i: Long) => w write i.toString
- case (i: BigInt) => w write i.toString
- case (i: Info) => i match {
- case NoInfo => // Do nothing
- case f: FileInfo =>
- val escaped = FileInfo.escapedToVerilog(f.escaped)
- w.write(s" // @[$escaped]")
- case m: MultiInfo =>
- val escaped = FileInfo.escapedToVerilog(m.flatten.map(_.escaped).mkString(" "))
- w.write(s" // @[$escaped]")
- }
+ w.write(s"[${t.size - 1}:0]")
+ case (s: String) => w.write(s)
+ case (i: Int) => w.write(i.toString)
+ case (i: Long) => w.write(i.toString)
+ case (i: BigInt) => w.write(i.toString)
+ case (i: Info) =>
+ i match {
+ case NoInfo => // Do nothing
+ case f: FileInfo =>
+ val escaped = FileInfo.escapedToVerilog(f.escaped)
+ w.write(s" // @[$escaped]")
+ case m: MultiInfo =>
+ val escaped = FileInfo.escapedToVerilog(m.flatten.map(_.escaped).mkString(" "))
+ w.write(s" // @[$escaped]")
+ }
case (s: Seq[Any]) =>
- s foreach (emit(_, top + 1))
- if (top == 0) w write "\n"
+ s.foreach(emit(_, top + 1))
+ if (top == 0) w.write("\n")
case x => throwInternalError(s"trying to emit unsupported operator: $x")
}
}
- //;------------- PASS -----------------
- def v_print(e: Expression)(implicit w: Writer) = e match {
- case UIntLiteral(value, IntWidth(width)) =>
- w write s"$width'h${value.toString(16)}"
- case SIntLiteral(value, IntWidth(width)) =>
- val stringLiteral = value.toString(16)
- w write (stringLiteral.head match {
- case '-' if value == FixAddingNegativeLiterals.minNegValue(width) => s"$width'sh${stringLiteral.tail}"
- case '-' => s"-$width'sh${stringLiteral.tail}"
- case _ => s"$width'sh${stringLiteral}"
- })
- case _ => throwInternalError(s"attempt to print unrecognized expression: $e")
- }
-
- // NOTE: We emit SInts as regular Verilog unsigned wires/regs so the real type of any SInt
- // reference is actually unsigned in the emitted Verilog. Thus we must cast refs as necessary
- // to ensure Verilog operations are signed.
- def op_stream(doprim: DoPrim): Seq[Any] = {
- // Cast to SInt, don't cast multiple times
- def doCast(e: Expression): Any = e match {
- case DoPrim(AsSInt, Seq(arg), _,_) => doCast(arg)
- case slit: SIntLiteral => slit
- case other => Seq("$signed(", other, ")")
- }
- def castIf(e: Expression): Any = {
- if (doprim.args.exists(_.tpe.isInstanceOf[SIntType])) {
- e.tpe match {
- case _: SIntType => doCast(e)
- case _ => throwInternalError(s"Unexpected non-SInt type for $e in $doprim")
- }
- } else {
- e
- }
- }
- def cast(e: Expression): Any = doprim.tpe match {
- case _: UIntType => e
- case _: SIntType => doCast(e)
- case _ => throwInternalError(s"Unexpected type for $e in $doprim")
- }
- def castAs(e: Expression): Any = e.tpe match {
- case _: UIntType => e
- case _: SIntType => doCast(e)
- case _ => throwInternalError(s"Unexpected type for $e in $doprim")
- }
- def a0: Expression = doprim.args.head
- def a1: Expression = doprim.args(1)
- def c0: Int = doprim.consts.head.toInt
- def c1: Int = doprim.consts(1).toInt
-
- def checkArgumentLegality(e: Expression): Unit = e match {
- case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
- case DoPrim(Not, args, _,_) => args.foreach(checkArgumentLegality)
- case DoPrim(op, args, _,_) if isCast(op) => args.foreach(checkArgumentLegality)
- case DoPrim(op, args, _,_) if isBitExtract(op) => args.foreach(checkArgumentLegality)
- case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
- }
-
- def checkCatArgumentLegality(e: Expression): Unit = e match {
- case DoPrim(Cat, args, _, _) => args foreach(checkCatArgumentLegality)
- case _ => checkArgumentLegality(e)
- }
-
- def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = {
- val a0Seq = a0 match {
- case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
- case _ => Seq(cast(a0))
- }
- val a1Seq = a1 match {
- case cat@DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
- case _ => Seq(cast(a1))
- }
- a0Seq ++ Seq(",") ++ a1Seq
- }
-
- doprim.op match {
- case Cat => doprim.args foreach(checkCatArgumentLegality)
- case cast if isCast(cast) => // Casts are allowed to wrap any Expression
- case other => doprim.args foreach checkArgumentLegality
- }
- doprim.op match {
- case Add => Seq(castIf(a0), " + ", castIf(a1))
- case Addw => Seq(castIf(a0), " + ", castIf(a1))
- case Sub => Seq(castIf(a0), " - ", castIf(a1))
- case Subw => Seq(castIf(a0), " - ", castIf(a1))
- case Mul => Seq(castIf(a0), " * ", castIf(a1))
- case Div => Seq(castIf(a0), " / ", castIf(a1))
- case Rem => Seq(castIf(a0), " % ", castIf(a1))
- case Lt => Seq(castIf(a0), " < ", castIf(a1))
- case Leq => Seq(castIf(a0), " <= ", castIf(a1))
- case Gt => Seq(castIf(a0), " > ", castIf(a1))
- case Geq => Seq(castIf(a0), " >= ", castIf(a1))
- case Eq => Seq(castIf(a0), " == ", castIf(a1))
- case Neq => Seq(castIf(a0), " != ", castIf(a1))
- case Pad =>
- val w = bitWidth(a0.tpe)
- val diff = c0 - w
- if (w == BigInt(0) || diff <= 0) Seq(a0)
- else doprim.tpe match {
- // Either sign extend or zero extend.
- // If width == BigInt(1), don't extract bit
- case (_: SIntType) if w == BigInt(1) => Seq("{", c0, "{", a0, "}}")
- case (_: SIntType) => Seq("{{", diff, "{", a0, "[", w - 1, "]}},", a0, "}")
- case (_) => Seq("{{", diff, "'d0}, ", a0, "}")
- }
- // Because we don't support complex Expressions, all casts are ignored
- // This simplifies handling of assignment of a signed expression to an unsigned LHS value
- // which does not require a cast in Verilog
- case AsUInt | AsSInt | AsClock | AsAsyncReset => Seq(a0)
- case Dshlw => Seq(cast(a0), " << ", a1)
- case Dshl => Seq(cast(a0), " << ", a1)
- case Dshr => doprim.tpe match {
- case (_: SIntType) => Seq(cast(a0)," >>> ", a1)
- case (_) => Seq(cast(a0), " >> ", a1)
- }
- case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0))
- case Shr if c0 >= bitWidth(a0.tpe) =>
- error("Verilog emitter does not support SHIFT_RIGHT >= arg width")
- case Shr if c0 == (bitWidth(a0.tpe)-1) => Seq(a0,"[", bitWidth(a0.tpe) - 1, "]")
- case Shr => Seq(a0,"[", bitWidth(a0.tpe) - 1, ":", c0, "]")
- case Neg => Seq("-", cast(a0))
- case Cvt => a0.tpe match {
- case (_: UIntType) => Seq("{1'b0,", cast(a0), "}")
- case (_: SIntType) => Seq(cast(a0))
- }
- case Not => Seq("~", a0)
- case And => Seq(castAs(a0), " & ", castAs(a1))
- case Or => Seq(castAs(a0), " | ", castAs(a1))
- case Xor => Seq(castAs(a0), " ^ ", castAs(a1))
- case Andr => Seq("&", cast(a0))
- case Orr => Seq("|", cast(a0))
- case Xorr => Seq("^", cast(a0))
- case Cat => "{" +: (castCatArgs(a0, a1) :+ "}")
- // If selecting zeroth bit and single-bit wire, just emit the wire
- case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
- case Bits if c0 == c1 => Seq(a0, "[", c0, "]")
- case Bits => Seq(a0, "[", c0, ":", c1, "]")
- // If selecting zeroth bit and single-bit wire, just emit the wire
- case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
- case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe)-1, "]")
- case Head =>
- val msb = bitWidth(a0.tpe) - 1
- val lsb = bitWidth(a0.tpe) - c0
- Seq(a0, "[", msb, ":", lsb, "]")
- case Tail if c0 == (bitWidth(a0.tpe)-1) => Seq(a0, "[0]")
- case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]")
- }
- }
+ //;------------- PASS -----------------
+ def v_print(e: Expression)(implicit w: Writer) = e match {
+ case UIntLiteral(value, IntWidth(width)) =>
+ w.write(s"$width'h${value.toString(16)}")
+ case SIntLiteral(value, IntWidth(width)) =>
+ val stringLiteral = value.toString(16)
+ w.write(stringLiteral.head match {
+ case '-' if value == FixAddingNegativeLiterals.minNegValue(width) => s"$width'sh${stringLiteral.tail}"
+ case '-' => s"-$width'sh${stringLiteral.tail}"
+ case _ => s"$width'sh${stringLiteral}"
+ })
+ case _ => throwInternalError(s"attempt to print unrecognized expression: $e")
+ }
+
+ // NOTE: We emit SInts as regular Verilog unsigned wires/regs so the real type of any SInt
+ // reference is actually unsigned in the emitted Verilog. Thus we must cast refs as necessary
+ // to ensure Verilog operations are signed.
+ def op_stream(doprim: DoPrim): Seq[Any] = {
+ // Cast to SInt, don't cast multiple times
+ def doCast(e: Expression): Any = e match {
+ case DoPrim(AsSInt, Seq(arg), _, _) => doCast(arg)
+ case slit: SIntLiteral => slit
+ case other => Seq("$signed(", other, ")")
+ }
+ def castIf(e: Expression): Any = {
+ if (doprim.args.exists(_.tpe.isInstanceOf[SIntType])) {
+ e.tpe match {
+ case _: SIntType => doCast(e)
+ case _ => throwInternalError(s"Unexpected non-SInt type for $e in $doprim")
+ }
+ } else {
+ e
+ }
+ }
+ def cast(e: Expression): Any = doprim.tpe match {
+ case _: UIntType => e
+ case _: SIntType => doCast(e)
+ case _ => throwInternalError(s"Unexpected type for $e in $doprim")
+ }
+ def castAs(e: Expression): Any = e.tpe match {
+ case _: UIntType => e
+ case _: SIntType => doCast(e)
+ case _ => throwInternalError(s"Unexpected type for $e in $doprim")
+ }
+ def a0: Expression = doprim.args.head
+ def a1: Expression = doprim.args(1)
+ def c0: Int = doprim.consts.head.toInt
+ def c1: Int = doprim.consts(1).toInt
+
+ def checkArgumentLegality(e: Expression): Unit = e match {
+ case _: UIntLiteral | _: SIntLiteral | _: WRef | _: WSubField =>
+ case DoPrim(Not, args, _, _) => args.foreach(checkArgumentLegality)
+ case DoPrim(op, args, _, _) if isCast(op) => args.foreach(checkArgumentLegality)
+ case DoPrim(op, args, _, _) if isBitExtract(op) => args.foreach(checkArgumentLegality)
+ case _ => throw EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument")
+ }
+
+ def checkCatArgumentLegality(e: Expression): Unit = e match {
+ case DoPrim(Cat, args, _, _) => args.foreach(checkCatArgumentLegality)
+ case _ => checkArgumentLegality(e)
+ }
+
+ def castCatArgs(a0: Expression, a1: Expression): Seq[Any] = {
+ val a0Seq = a0 match {
+ case cat @ DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
+ case _ => Seq(cast(a0))
+ }
+ val a1Seq = a1 match {
+ case cat @ DoPrim(PrimOps.Cat, args, _, _) => castCatArgs(args.head, args(1))
+ case _ => Seq(cast(a1))
+ }
+ a0Seq ++ Seq(",") ++ a1Seq
+ }
+
+ doprim.op match {
+ case Cat => doprim.args.foreach(checkCatArgumentLegality)
+ case cast if isCast(cast) => // Casts are allowed to wrap any Expression
+ case other => doprim.args.foreach(checkArgumentLegality)
+ }
+ doprim.op match {
+ case Add => Seq(castIf(a0), " + ", castIf(a1))
+ case Addw => Seq(castIf(a0), " + ", castIf(a1))
+ case Sub => Seq(castIf(a0), " - ", castIf(a1))
+ case Subw => Seq(castIf(a0), " - ", castIf(a1))
+ case Mul => Seq(castIf(a0), " * ", castIf(a1))
+ case Div => Seq(castIf(a0), " / ", castIf(a1))
+ case Rem => Seq(castIf(a0), " % ", castIf(a1))
+ case Lt => Seq(castIf(a0), " < ", castIf(a1))
+ case Leq => Seq(castIf(a0), " <= ", castIf(a1))
+ case Gt => Seq(castIf(a0), " > ", castIf(a1))
+ case Geq => Seq(castIf(a0), " >= ", castIf(a1))
+ case Eq => Seq(castIf(a0), " == ", castIf(a1))
+ case Neq => Seq(castIf(a0), " != ", castIf(a1))
+ case Pad =>
+ val w = bitWidth(a0.tpe)
+ val diff = c0 - w
+ if (w == BigInt(0) || diff <= 0) Seq(a0)
+ else
+ doprim.tpe match {
+ // Either sign extend or zero extend.
+ // If width == BigInt(1), don't extract bit
+ case (_: SIntType) if w == BigInt(1) => Seq("{", c0, "{", a0, "}}")
+ case (_: SIntType) => Seq("{{", diff, "{", a0, "[", w - 1, "]}},", a0, "}")
+ case (_) => Seq("{{", diff, "'d0}, ", a0, "}")
+ }
+ // Because we don't support complex Expressions, all casts are ignored
+ // This simplifies handling of assignment of a signed expression to an unsigned LHS value
+ // which does not require a cast in Verilog
+ case AsUInt | AsSInt | AsClock | AsAsyncReset => Seq(a0)
+ case Dshlw => Seq(cast(a0), " << ", a1)
+ case Dshl => Seq(cast(a0), " << ", a1)
+ case Dshr =>
+ doprim.tpe match {
+ case (_: SIntType) => Seq(cast(a0), " >>> ", a1)
+ case (_) => Seq(cast(a0), " >> ", a1)
+ }
+ case Shl => if (c0 > 0) Seq("{", cast(a0), s", $c0'h0}") else Seq(cast(a0))
+ case Shr if c0 >= bitWidth(a0.tpe) =>
+ error("Verilog emitter does not support SHIFT_RIGHT >= arg width")
+ case Shr if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]")
+ case Shr => Seq(a0, "[", bitWidth(a0.tpe) - 1, ":", c0, "]")
+ case Neg => Seq("-", cast(a0))
+ case Cvt =>
+ a0.tpe match {
+ case (_: UIntType) => Seq("{1'b0,", cast(a0), "}")
+ case (_: SIntType) => Seq(cast(a0))
+ }
+ case Not => Seq("~", a0)
+ case And => Seq(castAs(a0), " & ", castAs(a1))
+ case Or => Seq(castAs(a0), " | ", castAs(a1))
+ case Xor => Seq(castAs(a0), " ^ ", castAs(a1))
+ case Andr => Seq("&", cast(a0))
+ case Orr => Seq("|", cast(a0))
+ case Xorr => Seq("^", cast(a0))
+ case Cat => "{" +: (castCatArgs(a0, a1) :+ "}")
+ // If selecting zeroth bit and single-bit wire, just emit the wire
+ case Bits if c0 == 0 && c1 == 0 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
+ case Bits if c0 == c1 => Seq(a0, "[", c0, "]")
+ case Bits => Seq(a0, "[", c0, ":", c1, "]")
+ // If selecting zeroth bit and single-bit wire, just emit the wire
+ case Head if c0 == 1 && bitWidth(a0.tpe) == BigInt(1) => Seq(a0)
+ case Head if c0 == 1 => Seq(a0, "[", bitWidth(a0.tpe) - 1, "]")
+ case Head =>
+ val msb = bitWidth(a0.tpe) - 1
+ val lsb = bitWidth(a0.tpe) - c0
+ Seq(a0, "[", msb, ":", lsb, "]")
+ case Tail if c0 == (bitWidth(a0.tpe) - 1) => Seq(a0, "[0]")
+ case Tail => Seq(a0, "[", bitWidth(a0.tpe) - c0 - 1, ":0]")
+ }
+ }
/**
* Gets a reference to a verilog renderer. This is used by the current standard verilog emission process
@@ -475,31 +508,43 @@ class VerilogEmitter extends SeqTransform with Emitter {
* @param writer where rendering will be placed
* @return the render reference
*/
- def getRenderer(descriptions: Seq[DescriptionAnnotation],
- m: Module,
- moduleMap: Map[String, DefModule])(implicit writer: Writer): VerilogRender = {
+ def getRenderer(
+ descriptions: Seq[DescriptionAnnotation],
+ m: Module,
+ moduleMap: Map[String, DefModule]
+ )(
+ implicit writer: Writer
+ ): VerilogRender = {
val newMod = new AddDescriptionNodes().executeModule(m, descriptions)
newMod match {
- case DescribedMod(d, pds, m: Module) => new VerilogRender(d, pds, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer)
+ case DescribedMod(d, pds, m: Module) =>
+ new VerilogRender(d, pds, m, moduleMap, "", new EmissionOptions(Seq.empty))(writer)
case m: Module => new VerilogRender(m, moduleMap)(writer)
}
}
- def addFormalStatement(formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]],
- clk: Expression, en: Expression,
- stmt: Seq[Any], info: Info, msg: StringLit): Unit = {
- throw EmitterException("Cannot emit verification statements in Verilog" +
- "(2001). Use the SystemVerilog emitter instead.")
+ def addFormalStatement(
+ formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]],
+ clk: Expression,
+ en: Expression,
+ stmt: Seq[Any],
+ info: Info,
+ msg: StringLit
+ ): Unit = {
+ throw EmitterException(
+ "Cannot emit verification statements in Verilog" +
+ "(2001). Use the SystemVerilog emitter instead."
+ )
}
/**
* Store Emission option per Target
* Guarantee only one emission option per Target
*/
- private[firrtl] class EmissionOptionMap[V <: EmissionOption](val df : V) {
+ private[firrtl] class EmissionOptionMap[V <: EmissionOption](val df: V) {
private val m = collection.mutable.HashMap[ReferenceTarget, V]().withDefaultValue(df)
- def +=(elem : (ReferenceTarget, V)) : EmissionOptionMap.this.type = {
+ def +=(elem: (ReferenceTarget, V)): EmissionOptionMap.this.type = {
if (m.contains(elem._1))
throw EmitterException(s"Multiple EmissionOption for the target ${elem._1} (${m(elem._1)} ; ${elem._2})")
m += (elem)
@@ -511,7 +556,6 @@ class VerilogEmitter extends SeqTransform with Emitter {
/** Provide API to retrieve EmissionOptions based on the provided [[AnnotationSeq]]
*
* @param annotations : AnnotationSeq to be searched for EmissionOptions
- *
*/
private[firrtl] class EmissionOptions(annotations: AnnotationSeq) {
// Private so that we can present an immutable API
@@ -540,16 +584,34 @@ class VerilogEmitter extends SeqTransform with Emitter {
def getConnectEmissionOption(target: ReferenceTarget): ConnectEmissionOption =
connectEmissionOption(target)
- private val emissionAnnos = annotations.collect{
- case m : SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m
+ private val emissionAnnos = annotations.collect {
+ case m: SingleTargetAnnotation[ReferenceTarget] @unchecked with EmissionOption => m
}
// using multiple foreach instead of a single partial function as an Annotation can gather multiple EmissionOptions for simplicity
- emissionAnnos.foreach { case a :MemoryEmissionOption => memoryEmissionOption += ((a.target,a)) case _ => }
- emissionAnnos.foreach { case a :RegisterEmissionOption => registerEmissionOption += ((a.target,a)) case _ => }
- emissionAnnos.foreach { case a :WireEmissionOption => wireEmissionOption += ((a.target,a)) case _ => }
- emissionAnnos.foreach { case a :PortEmissionOption => portEmissionOption += ((a.target,a)) case _ => }
- emissionAnnos.foreach { case a :NodeEmissionOption => nodeEmissionOption += ((a.target,a)) case _ => }
- emissionAnnos.foreach { case a :ConnectEmissionOption => connectEmissionOption += ((a.target,a)) case _ => }
+ emissionAnnos.foreach {
+ case a: MemoryEmissionOption => memoryEmissionOption += ((a.target, a))
+ case _ =>
+ }
+ emissionAnnos.foreach {
+ case a: RegisterEmissionOption => registerEmissionOption += ((a.target, a))
+ case _ =>
+ }
+ emissionAnnos.foreach {
+ case a: WireEmissionOption => wireEmissionOption += ((a.target, a))
+ case _ =>
+ }
+ emissionAnnos.foreach {
+ case a: PortEmissionOption => portEmissionOption += ((a.target, a))
+ case _ =>
+ }
+ emissionAnnos.foreach {
+ case a: NodeEmissionOption => nodeEmissionOption += ((a.target, a))
+ case _ =>
+ }
+ emissionAnnos.foreach {
+ case a: ConnectEmissionOption => connectEmissionOption += ((a.target, a))
+ case _ =>
+ }
}
/**
@@ -562,14 +624,24 @@ class VerilogEmitter extends SeqTransform with Emitter {
* @param moduleMap a map of modules so submodules can be discovered
* @param writer where rendered information is placed.
*/
- class VerilogRender(description: Seq[Description],
- portDescriptions: Map[String, Seq[Description]],
- m: Module,
- moduleMap: Map[String, DefModule],
- circuitName: String,
- emissionOptions: EmissionOptions)(implicit writer: Writer) {
-
- def this(m: Module, moduleMap: Map[String, DefModule], circuitName: String, emissionOptions: EmissionOptions)(implicit writer: Writer) = {
+ class VerilogRender(
+ description: Seq[Description],
+ portDescriptions: Map[String, Seq[Description]],
+ m: Module,
+ moduleMap: Map[String, DefModule],
+ circuitName: String,
+ emissionOptions: EmissionOptions
+ )(
+ implicit writer: Writer) {
+
+ def this(
+ m: Module,
+ moduleMap: Map[String, DefModule],
+ circuitName: String,
+ emissionOptions: EmissionOptions
+ )(
+ implicit writer: Writer
+ ) = {
this(Seq(), Map.empty, m, moduleMap, circuitName, emissionOptions)(writer)
}
def this(m: Module, moduleMap: Map[String, DefModule])(implicit writer: Writer) = {
@@ -582,7 +654,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
def build_netlist(s: Statement): Unit = {
s.foreach(build_netlist)
s match {
- case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr)
+ case sx: Connect => netlist(sx.loc) = InfoExpr(sx.info, sx.expr)
case sx: IsInvalid => error("Should have removed these!")
// TODO Since only register update and memories use the netlist anymore, I think nodes are
// unnecessary
@@ -642,7 +714,14 @@ class VerilogEmitter extends SeqTransform with Emitter {
if (bi.isValidInt) bi.toString else s"${bi.bitLength}'d$bi"
// declare vector type with no preset and optionally with an ifdef guard
- private def declareVectorType(b: String, n: String, tpe: Type, size: BigInt, info: Info, ifdefOpt: Option[String]): Unit = {
+ private def declareVectorType(
+ b: String,
+ n: String,
+ tpe: Type,
+ size: BigInt,
+ info: Info,
+ ifdefOpt: Option[String]
+ ): Unit = {
val decl = Seq(b, " ", tpe, " ", n, " [0:", bigIntToVLit(size - 1), "];", info)
if (ifdefOpt.isDefined) {
ifdefDeclares(ifdefOpt.get) += decl
@@ -675,7 +754,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case tx: VectorType =>
declareVectorType(b, n, tx.tpe, tx.size, info, ifdefOpt)
case tx =>
- val decl = Seq(b, " ", tx, " ", n,";",info)
+ val decl = Seq(b, " ", tx, " ", n, ";", info)
if (ifdefOpt.isDefined) {
ifdefDeclares(ifdefOpt.get) += decl
} else {
@@ -703,8 +782,18 @@ class VerilogEmitter extends SeqTransform with Emitter {
assigns += Seq("`ifndef RANDOMIZE_GARBAGE_ASSIGN")
assigns += Seq("assign ", e, " = ", syn, ";", info)
assigns += Seq("`else")
- assigns += Seq("assign ", e, " = ", garbageCond, " ? ", rand_string(syn.tpe, "RANDOMIZE_GARBAGE_ASSIGN"), " : ", syn,
- ";", info)
+ assigns += Seq(
+ "assign ",
+ e,
+ " = ",
+ garbageCond,
+ " ? ",
+ rand_string(syn.tpe, "RANDOMIZE_GARBAGE_ASSIGN"),
+ " : ",
+ syn,
+ ";",
+ info
+ )
assigns += Seq("`endif // RANDOMIZE_GARBAGE_ASSIGN")
}
@@ -721,12 +810,12 @@ class VerilogEmitter extends SeqTransform with Emitter {
if (m.tpe == AsyncResetType) throw EmitterException("Cannot emit async reset muxes directly")
val (eninfo, tinfo, finfo) = MultiInfo.demux(info)
- lazy val _if = Seq(tabs, "if (", m.cond, ") begin", eninfo)
- lazy val _else = Seq(tabs, "end else begin")
- lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo)
- lazy val _end = Seq(tabs, "end")
- lazy val _true = addUpdate(tinfo, m.tval, tabs + tab)
- lazy val _false = addUpdate(finfo, m.fval, tabs + tab)
+ lazy val _if = Seq(tabs, "if (", m.cond, ") begin", eninfo)
+ lazy val _else = Seq(tabs, "end else begin")
+ lazy val _ifNot = Seq(tabs, "if (!(", m.cond, ")) begin", eninfo)
+ lazy val _end = Seq(tabs, "end")
+ lazy val _true = addUpdate(tinfo, m.tval, tabs + tab)
+ lazy val _false = addUpdate(finfo, m.fval, tabs + tab)
lazy val _elseIfFalse = {
val _falsex = addUpdate(finfo, m.fval, tabs) // _false, but without an additional tab
Seq(tabs, "end else ", _falsex.head.tail) +: _falsex.tail
@@ -743,13 +832,14 @@ class VerilogEmitter extends SeqTransform with Emitter {
*/
(m.tval, m.fval) match {
case (t, f) if weq(t, r) && weq(f, r) => Nil
- case (t, _) if weq(t, r) => _ifNot +: _false :+ _end
- case (_, f) if weq(f, r) => m.cond.tpe match {
- case AsyncResetType => (_if +: _true :+ _else) ++ _true :+ _end
- case _ => _if +: _true :+ _end
- }
- case (_, _: Mux) => (_if +: _true) ++ _elseIfFalse
- case _ => (_if +: _true :+ _else) ++ _false :+ _end
+ case (t, _) if weq(t, r) => _ifNot +: _false :+ _end
+ case (_, f) if weq(f, r) =>
+ m.cond.tpe match {
+ case AsyncResetType => (_if +: _true :+ _else) ++ _true :+ _end
+ case _ => _if +: _true :+ _end
+ }
+ case (_, _: Mux) => (_if +: _true) ++ _elseIfFalse
+ case _ => (_if +: _true :+ _else) ++ _false :+ _end
}
case e => Seq(Seq(tabs, r, " <= ", e, ";", info))
}
@@ -816,35 +906,52 @@ class VerilogEmitter extends SeqTransform with Emitter {
val maxDataValue = (BigInt(1) << dataWidth.toInt) - 1
def checkValueRange(value: BigInt, at: String): Unit = {
- if(value < 0) throw EmitterException(s"Memory ${at} cannot be initialized with negative value: $value")
- if(value > maxDataValue) throw EmitterException(s"Memory ${at} cannot be initialized with value: $value. Too large (> $maxDataValue)!")
+ if (value < 0) throw EmitterException(s"Memory ${at} cannot be initialized with negative value: $value")
+ if (value > maxDataValue)
+ throw EmitterException(s"Memory ${at} cannot be initialized with value: $value. Too large (> $maxDataValue)!")
}
opt.initValue match {
case MemoryArrayInit(values) =>
- if(values.length != s.depth) throw EmitterException(
- s"Memory ${s.name} of depth ${s.depth} cannot be initialized with an array of length ${values.length}!"
- )
+ if (values.length != s.depth)
+ throw EmitterException(
+ s"Memory ${s.name} of depth ${s.depth} cannot be initialized with an array of length ${values.length}!"
+ )
val memName = LowerTypes.loweredName(wref(s.name, s.dataType))
- values.zipWithIndex.foreach { case (value, addr) =>
- checkValueRange(value, s"${s.name}[$addr]")
- val access = s"$memName[${bigIntToVLit(addr)}]"
- memoryInitials += Seq(access, " = ", bigIntToVLit(value), ";")
+ values.zipWithIndex.foreach {
+ case (value, addr) =>
+ checkValueRange(value, s"${s.name}[$addr]")
+ val access = s"$memName[${bigIntToVLit(addr)}]"
+ memoryInitials += Seq(access, " = ", bigIntToVLit(value), ";")
}
case MemoryScalarInit(value) =>
checkValueRange(value, s.name)
// note: s.dataType is the incorrect type for initvar, but it is ignored in the serialization
val index = wref("initvar", s.dataType)
memoryInitials += Seq("for (initvar = 0; initvar < ", bigIntToVLit(s.depth), "; initvar = initvar+1)")
- memoryInitials += Seq(tab, WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow),
- " = ", bigIntToVLit(value), ";")
+ memoryInitials += Seq(
+ tab,
+ WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow),
+ " = ",
+ bigIntToVLit(value),
+ ";"
+ )
case MemoryRandomInit =>
// note: s.dataType is the incorrect type for initvar, but it is ignored in the serialization
val index = wref("initvar", s.dataType)
val rstring = rand_string(s.dataType, "RANDOMIZE_MEM_INIT")
- ifdefInitials("RANDOMIZE_MEM_INIT") += Seq("for (initvar = 0; initvar < ", bigIntToVLit(s.depth), "; initvar = initvar+1)")
- ifdefInitials("RANDOMIZE_MEM_INIT") += Seq(tab, WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow),
- " = ", rstring, ";")
+ ifdefInitials("RANDOMIZE_MEM_INIT") += Seq(
+ "for (initvar = 0; initvar < ",
+ bigIntToVLit(s.depth),
+ "; initvar = initvar+1)"
+ )
+ ifdefInitials("RANDOMIZE_MEM_INIT") += Seq(
+ tab,
+ WSubAccess(wref(s.name, s.dataType), index, s.dataType, SinkFlow),
+ " = ",
+ rstring,
+ ";"
+ )
}
}
@@ -888,7 +995,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
if (lines.size > 1) {
val lineSeqs = lines.tail.map {
- case "" => Seq(" *")
+ case "" => Seq(" *")
case nonEmpty => Seq(" * ", nonEmpty)
}
Seq("/* ", lines.head) +: lineSeqs :+ Seq(" */")
@@ -905,19 +1012,20 @@ class VerilogEmitter extends SeqTransform with Emitter {
def build_ports(): Unit = {
def padToMax(strs: Seq[String]): Seq[String] = {
val len = if (strs.nonEmpty) strs.map(_.length).max else 0
- strs map (_.padTo(len, ' '))
+ strs.map(_.padTo(len, ' '))
}
// Turn directions into strings (and AnalogType into inout)
- val dirs = m.ports map { case Port(_, name, dir, tpe) =>
- (dir, tpe) match {
- case (_, AnalogType(_)) => "inout " // padded to length of output
- case (Input, _) => "input "
- case (Output, _) => "output"
- }
+ val dirs = m.ports.map {
+ case Port(_, name, dir, tpe) =>
+ (dir, tpe) match {
+ case (_, AnalogType(_)) => "inout " // padded to length of output
+ case (Input, _) => "input "
+ case (Output, _) => "output"
+ }
}
// Turn types into strings, all ports must be GroundTypes
- val tpes = m.ports map {
+ val tpes = m.ports.map {
case Port(_, _, _, tpe: GroundType) => stringify(tpe)
case port: Port => error(s"Trying to emit non-GroundType Port $port")
}
@@ -925,9 +1033,10 @@ class VerilogEmitter extends SeqTransform with Emitter {
// dirs are already padded
(dirs, padToMax(tpes), m.ports).zipped.toSeq.zipWithIndex.foreach {
case ((dir, tpe, Port(info, name, _, _)), i) =>
- portDescriptions.get(name).map { case d =>
- portdefs += Seq("")
- portdefs ++= build_description(d)
+ portDescriptions.get(name).map {
+ case d =>
+ portdefs += Seq("")
+ portdefs ++= build_description(d)
}
if (i != m.ports.size - 1) {
@@ -956,14 +1065,14 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
withoutDescription.foreach(build_streams)
withoutDescription match {
- case sx@Connect(info, loc@WRef(_, _, PortKind | WireKind | InstanceKind, _), expr) =>
+ case sx @ Connect(info, loc @ WRef(_, _, PortKind | WireKind | InstanceKind, _), expr) =>
assign(loc, expr, info)
case sx: DefWire =>
declare("wire", sx.name, sx.tpe, sx.info)
case sx: DefRegister =>
val options = emissionOptions.getRegisterEmissionOption(moduleTarget.ref(sx.name))
val e = wref(sx.name, sx.tpe)
- if (options.useInitAsPreset){
+ if (options.useInitAsPreset) {
declare("reg", sx.name, sx.tpe, sx.info, sx.init)
regUpdate(e, sx.clock, sx.reset, e)
} else {
@@ -997,11 +1106,11 @@ class VerilogEmitter extends SeqTransform with Emitter {
case sx: WDefInstanceConnector =>
val (module, params) = moduleMap(sx.module) match {
case DescribedMod(_, _, ExtModule(_, _, _, extname, params)) => (extname, params)
- case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty)
- case ExtModule(_, _, _, extname, params) => (extname, params)
- case Module(_, name, _, _) => (name, Seq.empty)
+ case DescribedMod(_, _, Module(_, name, _, _)) => (name, Seq.empty)
+ case ExtModule(_, _, _, extname, params) => (extname, params)
+ case Module(_, name, _, _) => (name, Seq.empty)
}
- val ps = if (params.nonEmpty) params map stringify mkString("#(", ", ", ") ") else ""
+ val ps = if (params.nonEmpty) params.map(stringify).mkString("#(", ", ", ") ") else ""
instdeclares += Seq(module, " ", ps, sx.name, " (", sx.info)
for (((port, ref), i) <- sx.portCons.zipWithIndex) {
val line = Seq(tab, ".", remove_root(port), "(", ref, ")")
@@ -1012,14 +1121,16 @@ class VerilogEmitter extends SeqTransform with Emitter {
case sx: DefMemory =>
val options = emissionOptions.getMemoryEmissionOption(moduleTarget.ref(sx.name))
val fullSize = sx.depth * (sx.dataType match {
- case GroundType(IntWidth(width)) => width
- })
+ case GroundType(IntWidth(width)) => width
+ })
val decl = if (fullSize > (1 << 29)) "reg /* sparse */" else "reg"
declareVectorType(decl, sx.name, sx.dataType, sx.depth, sx.info)
initialize_mem(sx, options)
if (sx.readLatency != 0 || sx.writeLatency != 1)
- throw EmitterException("All memories should be transformed into " +
- "blackboxes or combinational by previous passses")
+ throw EmitterException(
+ "All memories should be transformed into " +
+ "blackboxes or combinational by previous passses"
+ )
for (r <- sx.readers) {
val data = memPortField(sx, r, "data")
val addr = memPortField(sx, r, "addr")
@@ -1031,7 +1142,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
//; Read port
assign(addr, netlist(addr))
- // assign(en, netlist(en)) //;Connects value to m.r.en
+ // assign(en, netlist(en)) //;Connects value to m.r.en
val mem = WRef(sx.name, memType(sx), MemKind, UnknownFlow)
val memPort = WSubAccess(mem, addr, sx.dataType, UnknownFlow)
val depthValue = UIntLiteral(sx.depth, IntWidth(sx.depth.bitLength))
@@ -1069,8 +1180,10 @@ class VerilogEmitter extends SeqTransform with Emitter {
}
if (sx.readwriters.nonEmpty)
- throw EmitterException("All readwrite ports should be transformed into " +
- "read & write ports by previous passes")
+ throw EmitterException(
+ "All readwrite ports should be transformed into " +
+ "read & write ports by previous passes"
+ )
case _ =>
}
}
@@ -1081,10 +1194,11 @@ class VerilogEmitter extends SeqTransform with Emitter {
for (x <- portdefs) emit(Seq(tab, x))
emit(Seq(");"))
- ifdefDeclares.toSeq.sortWith(_._1 < _._1).foreach { case (ifdef, declares) =>
- emit(Seq("`ifdef " + ifdef))
- for (x <- declares) emit(Seq(tab, x))
- emit(Seq("`endif // " + ifdef))
+ ifdefDeclares.toSeq.sortWith(_._1 < _._1).foreach {
+ case (ifdef, declares) =>
+ emit(Seq("`ifdef " + ifdef))
+ for (x <- declares) emit(Seq(tab, x))
+ emit(Seq("`endif // " + ifdef))
}
for (x <- declares) emit(Seq(tab, x))
for (x <- instdeclares) emit(Seq(tab, x))
@@ -1093,7 +1207,12 @@ class VerilogEmitter extends SeqTransform with Emitter {
emit(Seq("`ifdef SYNTHESIS"))
for (x <- attachSynAssigns) emit(Seq(tab, x))
emit(Seq("`elsif verilator"))
- emit(Seq(tab, "`error \"Verilator does not support alias and thus cannot arbirarily connect bidirectional wires and ports\""))
+ emit(
+ Seq(
+ tab,
+ "`error \"Verilator does not support alias and thus cannot arbirarily connect bidirectional wires and ports\""
+ )
+ )
emit(Seq("`else"))
for (x <- attachAliases) emit(Seq(tab, x))
emit(Seq("`endif"))
@@ -1129,7 +1248,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
emit(Seq("`define RANDOM $random"))
emit(Seq("`endif"))
// the initvar is also used to initialize memories to constants
- if(memoryInitials.isEmpty) emit(Seq("`ifdef RANDOMIZE_MEM_INIT"))
+ if (memoryInitials.isEmpty) emit(Seq("`ifdef RANDOMIZE_MEM_INIT"))
// Since simulators don't actually support memories larger than 2^31 - 1, there is no reason
// to change Verilog emission in the common case. Instead, we only emit a larger initvar
// where necessary
@@ -1140,7 +1259,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
val width = maxMemSize.bitLength - 1 // minus one because [width-1:0] has a width of "width"
emit(Seq(s" reg [$width:0] initvar;"))
}
- if(memoryInitials.isEmpty) emit(Seq("`endif"))
+ if (memoryInitials.isEmpty) emit(Seq("`endif"))
emit(Seq("`ifndef SYNTHESIS"))
// User-defined macro of code to run before an initial block
emit(Seq("`ifdef FIRRTL_BEFORE_INITIAL"))
@@ -1162,15 +1281,16 @@ class VerilogEmitter extends SeqTransform with Emitter {
emit(Seq(" #0.002 begin end"))
emit(Seq(" `endif"))
emit(Seq(" `endif"))
- ifdefInitials.toSeq.sortWith(_._1 < _._1).foreach { case (ifdef, initials) =>
- emit(Seq("`ifdef " + ifdef))
- for (x <- initials) emit(Seq(tab, x))
- emit(Seq("`endif // " + ifdef))
+ ifdefInitials.toSeq.sortWith(_._1 < _._1).foreach {
+ case (ifdef, initials) =>
+ emit(Seq("`ifdef " + ifdef))
+ for (x <- initials) emit(Seq(tab, x))
+ emit(Seq("`endif // " + ifdef))
}
for (x <- initials) emit(Seq(tab, x))
for (x <- asyncInitials) emit(Seq(tab, x))
emit(Seq(" `endif // RANDOMIZE"))
- for(x <- memoryInitials) emit(Seq(tab, x))
+ for (x <- memoryInitials) emit(Seq(tab, x))
emit(Seq("end // initial"))
// User-defined macro of code to run after an initial block
emit(Seq("`ifdef FIRRTL_AFTER_INITIAL"))
@@ -1258,7 +1378,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
val emissionOptions = new EmissionOptions(cs.annotations)
val moduleMap = cs.circuit.modules.map(m => m.name -> m).toMap
- cs.circuit.modules flatMap {
+ cs.circuit.modules.flatMap {
case dm @ DescribedMod(d, pds, module: Module) =>
val writer = new java.io.StringWriter
val renderer = new VerilogRender(d, pds, module, moduleMap, cs.circuit.main, emissionOptions)(writer)
@@ -1282,8 +1402,8 @@ class MinimumVerilogEmitter extends VerilogEmitter with Emitter {
override def prerequisites = firrtl.stage.Forms.AssertsRemoved ++
firrtl.stage.Forms.LowFormMinimumOptimized
- override def transforms = new TransformManager(firrtl.stage.Forms.VerilogMinimumOptimized, prerequisites)
- .flattenedTransformOrder
+ override def transforms =
+ new TransformManager(firrtl.stage.Forms.VerilogMinimumOptimized, prerequisites).flattenedTransformOrder
}
@@ -1292,9 +1412,14 @@ class SystemVerilogEmitter extends VerilogEmitter {
override def prerequisites = firrtl.stage.Forms.LowFormOptimized
- override def addFormalStatement(formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]],
- clk: Expression, en: Expression,
- stmt: Seq[Any], info: Info, msg: StringLit): Unit = {
+ override def addFormalStatement(
+ formals: mutable.Map[Expression, ArrayBuffer[Seq[Any]]],
+ clk: Expression,
+ en: Expression,
+ stmt: Seq[Any],
+ info: Info,
+ msg: StringLit
+ ): Unit = {
val lines = formals.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]())
lines += Seq("// ", msg.serialize)
lines += Seq("if (", en, ") begin")
diff --git a/src/main/scala/firrtl/ExecutionOptionsManager.scala b/src/main/scala/firrtl/ExecutionOptionsManager.scala
index d21ccade..50fb30a6 100644
--- a/src/main/scala/firrtl/ExecutionOptionsManager.scala
+++ b/src/main/scala/firrtl/ExecutionOptionsManager.scala
@@ -5,15 +5,22 @@ package firrtl
import logger.LogLevel
import logger.{ClassLogLevelAnnotation, LogClassNamesAnnotation, LogFileAnnotation, LogLevelAnnotation}
import firrtl.annotations._
-import firrtl.Parser.{InfoMode, UseInfo, IgnoreInfo, GenInfo, AppendInfo}
+import firrtl.Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo}
import firrtl.ir.Circuit
import firrtl.passes.memlib.{InferReadWriteAnnotation, ReplSeqMemAnnotation}
import firrtl.passes.clocklist.ClockListAnnotation
import firrtl.transforms.NoCircuitDedupAnnotation
import scopt.OptionParser
-import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation, FirrtlFileAnnotation, FirrtlSourceAnnotation,
- InfoModeAnnotation, OutputFileAnnotation, RunFirrtlTransformAnnotation}
-import firrtl.stage.phases.DriverCompatibility.{TopNameAnnotation, EmitOneFilePerModuleAnnotation}
+import firrtl.stage.{
+ CompilerAnnotation,
+ FirrtlCircuitAnnotation,
+ FirrtlFileAnnotation,
+ FirrtlSourceAnnotation,
+ InfoModeAnnotation,
+ OutputFileAnnotation,
+ RunFirrtlTransformAnnotation
+}
+import firrtl.stage.phases.DriverCompatibility.{EmitOneFilePerModuleAnnotation, TopNameAnnotation}
import firrtl.options.{InputAnnotationFileAnnotation, OutputAnnotationFileAnnotation, ProgramArgsAnnotation, StageUtils}
import firrtl.transforms.{DontCheckCombLoopsAnnotation, NoDCEAnnotation}
@@ -33,7 +40,7 @@ abstract class HasParser(applicationName: String) {
final val parser = new OptionParser[Unit](applicationName) {
var terminateOnExit = true
override def terminate(exitState: Either[String, Unit]): Unit = {
- if(terminateOnExit) sys.exit(0)
+ if (terminateOnExit) sys.exit(0)
}
}
@@ -43,12 +50,14 @@ abstract class HasParser(applicationName: String) {
def doNotExitOnHelp(): Unit = {
parser.terminateOnExit = false
}
+
/**
* By default scopt calls sys.exit when --help is in options, this un-defeats doNotExitOnHelp
*/
def exitOnHelp(): Unit = {
parser.terminateOnExit = true
- }}
+ }
+}
/**
* Most of the chisel toolchain components require a topName which defines a circuit or a device under test.
@@ -59,20 +68,19 @@ abstract class HasParser(applicationName: String) {
*/
@deprecated("Use a FirrtlOptionsView, LoggerOptionsView, or construct your own view of an AnnotationSeq", "1.2")
case class CommonOptions(
- topName: String = "",
- targetDirName: String = ".",
- globalLogLevel: LogLevel.Value = LogLevel.None,
- logToFile: Boolean = false,
- logClassNames: Boolean = false,
- classLogLevels: Map[String, LogLevel.Value] = Map.empty,
- programArgs: Seq[String] = Seq.empty
-) extends ComposableOptions {
+ topName: String = "",
+ targetDirName: String = ".",
+ globalLogLevel: LogLevel.Value = LogLevel.None,
+ logToFile: Boolean = false,
+ logClassNames: Boolean = false,
+ classLogLevels: Map[String, LogLevel.Value] = Map.empty,
+ programArgs: Seq[String] = Seq.empty)
+ extends ComposableOptions {
def getLogFileName(optionsManager: ExecutionOptionsManager): String = {
- if(topName.isEmpty) {
+ if (topName.isEmpty) {
optionsManager.getBuildFileName("log", "firrtl")
- }
- else {
+ } else {
optionsManager.getBuildFileName("log")
}
}
@@ -80,10 +88,12 @@ case class CommonOptions(
def toAnnotations: AnnotationSeq = List() ++ (if (topName.nonEmpty) Seq(TopNameAnnotation(topName)) else Seq()) ++
(if (targetDirName != ".") Some(TargetDirAnnotation(targetDirName)) else None) ++
Some(LogLevelAnnotation(globalLogLevel)) ++
- (if (logToFile) { Some(LogFileAnnotation(None)) } else { None }) ++
- (if (logClassNames) { Some(LogClassNamesAnnotation) } else { None }) ++
- classLogLevels.map{ case (c, v) => ClassLogLevelAnnotation(c, v) } ++
- programArgs.map( a => ProgramArgsAnnotation(a) )
+ (if (logToFile) { Some(LogFileAnnotation(None)) }
+ else { None }) ++
+ (if (logClassNames) { Some(LogClassNamesAnnotation) }
+ else { None }) ++
+ classLogLevels.map { case (c, v) => ClassLogLevelAnnotation(c, v) } ++
+ programArgs.map(a => ProgramArgsAnnotation(a))
}
@deprecated("Specify command line arguments in an Annotation mixing in HasScoptOptions", "1.2")
@@ -93,7 +103,8 @@ trait HasCommonOptions {
parser.note("common options")
- parser.opt[String]("top-name")
+ parser
+ .opt[String]("top-name")
.abbr("tn")
.valueName("<top-level-circuit-name>")
.foreach { x =>
@@ -101,15 +112,19 @@ trait HasCommonOptions {
}
.text("This options defines the top level circuit, defaults to dut when possible")
- parser.opt[String]("target-dir")
- .abbr("td").valueName("<target-directory>")
+ parser
+ .opt[String]("target-dir")
+ .abbr("td")
+ .valueName("<target-directory>")
.foreach { x =>
commonOptions = commonOptions.copy(targetDirName = x)
}
.text(s"This options defines a work directory for intermediate files, default is ${commonOptions.targetDirName}")
- parser.opt[String]("log-level")
- .abbr("ll").valueName("<error|warn|info|debug|trace>")
+ parser
+ .opt[String]("log-level")
+ .abbr("ll")
+ .valueName("<error|warn|info|debug|trace>")
.foreach { x =>
val level = x.toLowerCase match {
case "error" => LogLevel.Error
@@ -126,16 +141,18 @@ trait HasCommonOptions {
}
.text(s"This options defines global log level, default is ${commonOptions.globalLogLevel}")
- parser.opt[Seq[String]]("class-log-level")
- .abbr("cll").valueName("<FullClassName:[error|warn|info|debug|trace]>[,...]")
+ parser
+ .opt[Seq[String]]("class-log-level")
+ .abbr("cll")
+ .valueName("<FullClassName:[error|warn|info|debug|trace]>[,...]")
.foreach { x =>
val logAssignments = x.map { y =>
val className :: levelName :: _ = y.split(":").toList
val level = levelName.toLowerCase match {
case "error" => LogLevel.Error
- case "warn" => LogLevel.Warn
- case "info" => LogLevel.Info
+ case "warn" => LogLevel.Warn
+ case "info" => LogLevel.Info
case "debug" => LogLevel.Debug
case "trace" => LogLevel.Trace
case _ =>
@@ -149,14 +166,16 @@ trait HasCommonOptions {
}
.text(s"This options defines class log level, default is ${commonOptions.classLogLevels}")
- parser.opt[Unit]("log-to-file")
+ parser
+ .opt[Unit]("log-to-file")
.abbr("ltf")
.foreach { _ =>
commonOptions = commonOptions.copy(logToFile = true)
}
.text(s"default logs to stdout, this flags writes to topName.log or firrtl.log if no topName")
- parser.opt[Unit]("log-class-names")
+ parser
+ .opt[Unit]("log-class-names")
.abbr("lcn")
.foreach { _ =>
commonOptions = commonOptions.copy(logClassNames = true)
@@ -165,8 +184,12 @@ trait HasCommonOptions {
parser.help("help").text("prints this usage text")
- parser.arg[String]("<arg>...").unbounded().optional().action( (x, c) =>
- commonOptions = commonOptions.copy(programArgs = commonOptions.programArgs :+ x) ).text("optional unbounded args")
+ parser
+ .arg[String]("<arg>...")
+ .unbounded()
+ .optional()
+ .action((x, c) => commonOptions = commonOptions.copy(programArgs = commonOptions.programArgs :+ x))
+ .text("optional unbounded args")
}
@@ -189,46 +212,47 @@ final case class OneFilePerModule(targetDir: String) extends OutputConfig
*/
@deprecated("Use a FirrtlOptionsView or construct your own view of an AnnotationSeq", "1.2")
case class FirrtlExecutionOptions(
- inputFileNameOverride: String = "",
- outputFileNameOverride: String = "",
- compilerName: String = "verilog",
- infoModeName: String = "append",
- inferRW: Seq[String] = Seq.empty,
- firrtlSource: Option[String] = None,
- customTransforms: Seq[Transform] = List.empty,
- annotations: List[Annotation] = List.empty,
- annotationFileNameOverride: String = "",
- outputAnnotationFileName: String = "",
- emitOneFilePerModule: Boolean = false,
- dontCheckCombLoops: Boolean = false,
- noDCE: Boolean = false,
- annotationFileNames: List[String] = List.empty,
- firrtlCircuit: Option[Circuit] = None
-)
-extends ComposableOptions {
-
- require(!(emitOneFilePerModule && outputFileNameOverride.nonEmpty),
- "Cannot both specify the output filename and emit one file per module!!!")
+ inputFileNameOverride: String = "",
+ outputFileNameOverride: String = "",
+ compilerName: String = "verilog",
+ infoModeName: String = "append",
+ inferRW: Seq[String] = Seq.empty,
+ firrtlSource: Option[String] = None,
+ customTransforms: Seq[Transform] = List.empty,
+ annotations: List[Annotation] = List.empty,
+ annotationFileNameOverride: String = "",
+ outputAnnotationFileName: String = "",
+ emitOneFilePerModule: Boolean = false,
+ dontCheckCombLoops: Boolean = false,
+ noDCE: Boolean = false,
+ annotationFileNames: List[String] = List.empty,
+ firrtlCircuit: Option[Circuit] = None)
+ extends ComposableOptions {
+
+ require(
+ !(emitOneFilePerModule && outputFileNameOverride.nonEmpty),
+ "Cannot both specify the output filename and emit one file per module!!!"
+ )
def infoMode: InfoMode = {
infoModeName match {
- case "use" => UseInfo
+ case "use" => UseInfo
case "ignore" => IgnoreInfo
- case "gen" => GenInfo(inputFileNameOverride)
+ case "gen" => GenInfo(inputFileNameOverride)
case "append" => AppendInfo(inputFileNameOverride)
- case other => UseInfo
+ case other => UseInfo
}
}
def compiler: Compiler = {
compilerName match {
- case "none" => new NoneCompiler()
- case "high" => new HighFirrtlCompiler()
- case "low" => new LowFirrtlCompiler()
- case "middle" => new MiddleFirrtlCompiler()
- case "verilog" => new VerilogCompiler()
- case "mverilog" => new MinimumVerilogCompiler()
- case "sverilog" => new SystemVerilogCompiler()
+ case "none" => new NoneCompiler()
+ case "high" => new HighFirrtlCompiler()
+ case "low" => new LowFirrtlCompiler()
+ case "middle" => new MiddleFirrtlCompiler()
+ case "verilog" => new VerilogCompiler()
+ case "mverilog" => new MinimumVerilogCompiler()
+ case "sverilog" => new SystemVerilogCompiler()
}
}
@@ -255,6 +279,7 @@ extends ComposableOptions {
if (inputFileNameOverride.nonEmpty) inputFileNameOverride
else optionsManager.getBuildFileName("fir", inputFileNameOverride)
}
+
/** Get the user-specified [[OutputConfig]]
*
* @param optionsManager this is needed to access build function and its common options
@@ -264,6 +289,7 @@ extends ComposableOptions {
if (emitOneFilePerModule) OneFilePerModule(optionsManager.targetDirName)
else SingleFile(optionsManager.getBuildFileName(outputSuffix, outputFileNameOverride))
}
+
/** Get the user-specified targetFile assuming [[OutputConfig]] is [[SingleFile]]
*
* @param optionsManager this is needed to access build function and its common options
@@ -272,9 +298,10 @@ extends ComposableOptions {
def getTargetFile(optionsManager: ExecutionOptionsManager): String = {
getOutputConfig(optionsManager) match {
case SingleFile(targetFile) => targetFile
- case other => throw new Exception("OutputConfig is not SingleFile!")
+ case other => throw new Exception("OutputConfig is not SingleFile!")
}
}
+
/** Gives annotations based on the output configuration
*
* @param optionsManager this is needed to access build function and its common options
@@ -283,19 +310,20 @@ extends ComposableOptions {
def getEmitterAnnos(optionsManager: ExecutionOptionsManager): Seq[Annotation] = {
// TODO should this be a public function?
val emitter = compilerName match {
- case "none" => classOf[ChirrtlEmitter]
- case "high" => classOf[HighFirrtlEmitter]
- case "middle" => classOf[MiddleFirrtlEmitter]
- case "low" => classOf[LowFirrtlEmitter]
- case "verilog" => classOf[VerilogEmitter]
+ case "none" => classOf[ChirrtlEmitter]
+ case "high" => classOf[HighFirrtlEmitter]
+ case "middle" => classOf[MiddleFirrtlEmitter]
+ case "low" => classOf[LowFirrtlEmitter]
+ case "verilog" => classOf[VerilogEmitter]
case "mverilog" => classOf[MinimumVerilogEmitter]
case "sverilog" => classOf[VerilogEmitter]
}
getOutputConfig(optionsManager) match {
- case SingleFile(_) => Seq(EmitCircuitAnnotation(emitter))
+ case SingleFile(_) => Seq(EmitCircuitAnnotation(emitter))
case OneFilePerModule(_) => Seq(EmitAllModulesAnnotation(emitter))
}
}
+
/**
* build the annotation file name, taking overriding parameters
*
@@ -313,23 +341,28 @@ extends ComposableOptions {
}
List() ++ (if (inputFileNameOverride.nonEmpty) Seq(FirrtlFileAnnotation(inputFileNameOverride)) else Seq()) ++
- (if (outputFileNameOverride.nonEmpty) { Some(OutputFileAnnotation(outputFileNameOverride)) } else { None }) ++
+ (if (outputFileNameOverride.nonEmpty) { Some(OutputFileAnnotation(outputFileNameOverride)) }
+ else { None }) ++
Some(CompilerAnnotation(compilerName)) ++
Some(InfoModeAnnotation(infoModeName)) ++
firrtlSource.map(FirrtlSourceAnnotation(_)) ++
customTransforms.map(t => RunFirrtlTransformAnnotation(t)) ++
annotations ++
- (if (annotationFileNameOverride.nonEmpty) { Some(InputAnnotationFileAnnotation(annotationFileNameOverride)) } else { None }) ++
- (if (outputAnnotationFileName.nonEmpty) { Some(OutputAnnotationFileAnnotation(outputAnnotationFileName)) } else { None }) ++
- (if (emitOneFilePerModule) { Some(EmitOneFilePerModuleAnnotation) } else { None }) ++
- (if (dontCheckCombLoops) { Some(DontCheckCombLoopsAnnotation) } else { None }) ++
- (if (noDCE) { Some(NoDCEAnnotation) } else { None }) ++
+ (if (annotationFileNameOverride.nonEmpty) { Some(InputAnnotationFileAnnotation(annotationFileNameOverride)) }
+ else { None }) ++
+ (if (outputAnnotationFileName.nonEmpty) { Some(OutputAnnotationFileAnnotation(outputAnnotationFileName)) }
+ else { None }) ++
+ (if (emitOneFilePerModule) { Some(EmitOneFilePerModuleAnnotation) }
+ else { None }) ++
+ (if (dontCheckCombLoops) { Some(DontCheckCombLoopsAnnotation) }
+ else { None }) ++
+ (if (noDCE) { Some(NoDCEAnnotation) }
+ else { None }) ++
annotationFileNames.map(InputAnnotationFileAnnotation(_)) ++
firrtlCircuit.map(FirrtlCircuitAnnotation(_))
}
}
-
@deprecated("Specify command line arguments in an Annotation mixing in HasScoptOptions", "1.2")
trait HasFirrtlOptions {
self: ExecutionOptionsManager =>
@@ -337,16 +370,19 @@ trait HasFirrtlOptions {
parser.note("firrtl options")
- parser.opt[String]("input-file")
+ parser
+ .opt[String]("input-file")
.abbr("i")
- .valueName ("<firrtl-source>")
+ .valueName("<firrtl-source>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(inputFileNameOverride = x)
- }.text {
+ }
+ .text {
"use this to override the default input file name , default is empty"
}
- parser.opt[String]("output-file")
+ parser
+ .opt[String]("output-file")
.abbr("o")
.valueName("<output>")
.validate { x =>
@@ -356,40 +392,47 @@ trait HasFirrtlOptions {
}
.foreach { x =>
firrtlOptions = firrtlOptions.copy(outputFileNameOverride = x)
- }.text {
- "use this to override the default output file name, default is empty"
- }
+ }
+ .text {
+ "use this to override the default output file name, default is empty"
+ }
- parser.opt[String]("annotation-file")
+ parser
+ .opt[String]("annotation-file")
.abbr("faf")
.unbounded()
.valueName("<input-anno-file>")
.foreach { x =>
val annoFiles = x +: firrtlOptions.annotationFileNames
firrtlOptions = firrtlOptions.copy(annotationFileNames = annoFiles)
- }.text("Used to specify annotation files (can appear multiple times)")
+ }
+ .text("Used to specify annotation files (can appear multiple times)")
- parser.opt[Unit]("force-append-anno-file")
+ parser
+ .opt[Unit]("force-append-anno-file")
.abbr("ffaaf")
.hidden()
.foreach { _ =>
val msg = "force-append-anno-file is deprecated and will soon be removed\n" +
- (" "*9) + "(It does not do anything anymore)"
+ (" " * 9) + "(It does not do anything anymore)"
StageUtils.dramaticWarning(msg)
}
- parser.opt[String]("output-annotation-file")
+ parser
+ .opt[String]("output-annotation-file")
.abbr("foaf")
- .valueName ("<output-anno-file>")
+ .valueName("<output-anno-file>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(outputAnnotationFileName = x)
- }.text {
- "use this to set the annotation output file"
- }
+ }
+ .text {
+ "use this to set the annotation output file"
+ }
- parser.opt[String]("compiler")
+ parser
+ .opt[String]("compiler")
.abbr("X")
- .valueName ("<high|middle|low|verilog|mverilog|sverilog|none>")
+ .valueName("<high|middle|low|verilog|mverilog|sverilog|none>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(compilerName = x)
}
@@ -399,12 +442,14 @@ trait HasFirrtlOptions {
} else {
parser.failure(s"$x not a legal compiler")
}
- }.text {
+ }
+ .text {
s"compiler to use, default is ${firrtlOptions.compilerName}"
}
- parser.opt[String]("info-mode")
- .valueName ("<ignore|use|gen|append>")
+ parser
+ .opt[String]("info-mode")
+ .valueName("<ignore|use|gen|append>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(infoModeName = x.toLowerCase)
}
@@ -416,13 +461,14 @@ trait HasFirrtlOptions {
s"specifies the source info handling, default is ${firrtlOptions.infoModeName}"
}
- parser.opt[Seq[String]]("custom-transforms")
+ parser
+ .opt[Seq[String]]("custom-transforms")
.abbr("fct")
- .valueName ("<package>.<class>")
+ .valueName("<package>.<class>")
.foreach { customTransforms: Seq[String] =>
firrtlOptions = firrtlOptions.copy(
customTransforms = firrtlOptions.customTransforms ++
- (customTransforms map { x: String =>
+ (customTransforms.map { x: String =>
Class.forName(x).asInstanceOf[Class[_ <: Transform]].newInstance()
})
)
@@ -431,10 +477,10 @@ trait HasFirrtlOptions {
"""runs these custom transforms during compilation."""
}
-
- parser.opt[Seq[String]]("inline")
+ parser
+ .opt[Seq[String]]("inline")
.abbr("fil")
- .valueName ("<circuit>[.<module>[.<instance>]][,..],")
+ .valueName("<circuit>[.<module>[.<instance>]][,..],")
.foreach { x =>
val newAnnotations = x.map { value =>
value.split('.') match {
@@ -455,20 +501,23 @@ trait HasFirrtlOptions {
"""Inline one or more module (comma separated, no spaces) module looks like "MyModule" or "MyModule.myinstance"""
}
- parser.opt[Unit]("infer-rw")
+ parser
+ .opt[Unit]("infer-rw")
.abbr("firw")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(
annotations = firrtlOptions.annotations :+ InferReadWriteAnnotation,
customTransforms = firrtlOptions.customTransforms :+ new passes.memlib.InferReadWrite
)
- }.text {
+ }
+ .text {
"Enable readwrite port inference for the target circuit"
}
- parser.opt[String]("repl-seq-mem")
+ parser
+ .opt[String]("repl-seq-mem")
.abbr("frsq")
- .valueName ("-c:<circuit>:-i:<filename>:-o:<filename>")
+ .valueName("-c:<circuit>:-i:<filename>:-o:<filename>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(
annotations = firrtlOptions.annotations :+ ReplSeqMemAnnotation.parse(x),
@@ -479,9 +528,10 @@ trait HasFirrtlOptions {
"Replace sequential memories with blackboxes + configuration file"
}
- parser.opt[String]("list-clocks")
+ parser
+ .opt[String]("list-clocks")
.abbr("clks")
- .valueName ("-c:<circuit>:-m:<module>:-o:<filename>")
+ .valueName("-c:<circuit>:-m:<module>:-o:<filename>")
.foreach { x =>
firrtlOptions = firrtlOptions.copy(
annotations = firrtlOptions.annotations :+ ClockListAnnotation.parse(x),
@@ -492,7 +542,8 @@ trait HasFirrtlOptions {
"List which signal drives each clock of every descendent of specified module"
}
- parser.opt[Unit]("split-modules")
+ parser
+ .opt[Unit]("split-modules")
.abbr("fsm")
.validate { x =>
if (firrtlOptions.outputFileNameOverride.nonEmpty)
@@ -501,32 +552,39 @@ trait HasFirrtlOptions {
}
.foreach { _ =>
firrtlOptions = firrtlOptions.copy(emitOneFilePerModule = true)
- }.text {
+ }
+ .text {
"Emit each module to its own file in the target directory."
}
- parser.opt[Unit]("no-check-comb-loops")
+ parser
+ .opt[Unit]("no-check-comb-loops")
.foreach { _ =>
firrtlOptions = firrtlOptions.copy(dontCheckCombLoops = true)
- }.text {
+ }
+ .text {
"Do NOT check for combinational loops (not recommended)"
}
- parser.opt[Unit]("no-dce")
+ parser
+ .opt[Unit]("no-dce")
.foreach { _ =>
firrtlOptions = firrtlOptions.copy(noDCE = true)
- }.text {
+ }
+ .text {
"Do NOT run dead code elimination"
}
- parser.opt[Unit]("no-dedup")
+ parser
+ .opt[Unit]("no-dedup")
.foreach { _ =>
firrtlOptions = firrtlOptions.copy(
annotations = firrtlOptions.annotations :+ NoCircuitDedupAnnotation
)
- }.text {
- "Do NOT dedup modules"
- }
+ }
+ .text {
+ "Do NOT dedup modules"
+ }
parser.note("")
}
@@ -537,16 +595,16 @@ sealed trait FirrtlExecutionResult
@deprecated("Use FirrtlStage and examine the output AnnotationSeq directly", "1.2")
object FirrtlExecutionSuccess {
def apply(
- emitType : String,
- emitted : String,
+ emitType: String,
+ emitted: String,
circuitState: CircuitState
): FirrtlExecutionSuccess = new FirrtlExecutionSuccess(emitType, emitted, circuitState)
-
def unapply(arg: FirrtlExecutionSuccess): Option[(String, String)] = {
Some((arg.emitType, arg.emitted))
}
}
+
/**
* Indicates a successful execution of the firrtl compiler, returning the compiled result and
* the type of compile
@@ -557,10 +615,10 @@ object FirrtlExecutionSuccess {
*/
@deprecated("Use FirrtlStage and examine the output AnnotationSeq directly", "1.2")
class FirrtlExecutionSuccess(
- val emitType: String,
- val emitted : String,
- val circuitState: CircuitState
-) extends FirrtlExecutionResult
+ val emitType: String,
+ val emitted: String,
+ val circuitState: CircuitState)
+ extends FirrtlExecutionResult
/**
* The firrtl compilation failed.
@@ -571,7 +629,6 @@ class FirrtlExecutionSuccess(
case class FirrtlExecutionFailure(message: String) extends FirrtlExecutionResult
/**
- *
* @param applicationName The name shown in the usage
*/
@deprecated("Use new FirrtlStage infrastructure", "1.2")
@@ -607,7 +664,7 @@ class ExecutionOptionsManager(val applicationName: String) extends HasParser(app
commonOptions = commonOptions.copy(topName = newTopName)
}
def setTopNameIfNotSet(newTopName: String): Unit = {
- if(commonOptions.topName.isEmpty) {
+ if (commonOptions.topName.isEmpty) {
setTopName(newTopName)
}
}
@@ -627,21 +684,19 @@ class ExecutionOptionsManager(val applicationName: String) extends HasParser(app
def getBuildFileName(suffix: String, fileNameOverride: String = ""): String = {
makeTargetDir()
- val baseName = if(fileNameOverride.nonEmpty) fileNameOverride else topName
+ val baseName = if (fileNameOverride.nonEmpty) fileNameOverride else topName
val directoryName = {
- if(fileNameOverride.nonEmpty) {
+ if (fileNameOverride.nonEmpty) {
""
- }
- else if(baseName.startsWith("./") || baseName.startsWith("/")) {
+ } else if (baseName.startsWith("./") || baseName.startsWith("/")) {
""
- }
- else {
- if(targetDirName.endsWith("/")) targetDirName else targetDirName + "/"
+ } else {
+ if (targetDirName.endsWith("/")) targetDirName else targetDirName + "/"
}
}
val normalizedSuffix = {
- val dottedSuffix = if(suffix.startsWith(".")) suffix else s".$suffix"
- if(baseName.endsWith(dottedSuffix)) "" else dottedSuffix
+ val dottedSuffix = if (suffix.startsWith(".")) suffix else s".$suffix"
+ if (baseName.endsWith(dottedSuffix)) "" else dottedSuffix
}
val path = directoryName + baseName.split("/").dropRight(1).mkString("/")
FileUtils.makeDirectory(path)
diff --git a/src/main/scala/firrtl/FileUtils.scala b/src/main/scala/firrtl/FileUtils.scala
index 8e73b4f9..3db86b7c 100644
--- a/src/main/scala/firrtl/FileUtils.scala
+++ b/src/main/scala/firrtl/FileUtils.scala
@@ -7,7 +7,7 @@ import java.io.File
import firrtl.options.StageUtils
import scala.collection.Seq
-import scala.sys.process.{BasicIO, ProcessLogger, stringSeqToProcess}
+import scala.sys.process.{stringSeqToProcess, BasicIO, ProcessLogger}
object FileUtils {
@@ -17,7 +17,7 @@ object FileUtils {
*/
def makeDirectory(directoryName: String): Boolean = {
val dirFile = new File(directoryName)
- if(dirFile.exists()) {
+ if (dirFile.exists()) {
dirFile.isDirectory
} else {
dirFile.mkdirs()
@@ -33,6 +33,7 @@ object FileUtils {
def deleteDirectoryHierarchy(directoryPathName: String): Boolean = {
deleteDirectoryHierarchy(new File(directoryPathName))
}
+
/**
* recursively delete all directories in a relative path
* DO NOT DELETE absolute paths
@@ -40,18 +41,18 @@ object FileUtils {
* @param file: a directory hierarchy to delete
*/
def deleteDirectoryHierarchy(file: File, atTop: Boolean = true): Boolean = {
- if(file.getPath.split("/").last.isEmpty ||
+ if (
+ file.getPath.split("/").last.isEmpty ||
file.getAbsolutePath == "/" ||
- file.getPath.startsWith("/")) {
+ file.getPath.startsWith("/")
+ ) {
StageUtils.dramaticError(s"delete directory ${file.getPath} will not delete absolute paths")
false
- }
- else {
+ } else {
val result = {
- if(file.isDirectory) {
- file.listFiles().forall( f => deleteDirectoryHierarchy(f)) && file.delete()
- }
- else {
+ if (file.isDirectory) {
+ file.listFiles().forall(f => deleteDirectoryHierarchy(f)) && file.delete()
+ } else {
file.delete()
}
}
@@ -81,7 +82,7 @@ object FileUtils {
* @param cmd the command/executable (without any arguments).
* @return true if ```cmd``` returns a 0 exit status.
*/
- def isCommandAvailable(cmd:String): Boolean = {
+ def isCommandAvailable(cmd: String): Boolean = {
isCommandAvailable(Seq(cmd))
}
@@ -90,7 +91,7 @@ object FileUtils {
* Instead we try to run the executable itself (with innocuous arguments) and interpret any errors/exceptions
* as an indication that the executable is unavailable.
*/
- lazy val isVCSAvailable: Boolean = isCommandAvailable(Seq("vcs", "-platform"))
+ lazy val isVCSAvailable: Boolean = isCommandAvailable(Seq("vcs", "-platform"))
/** Read a text file and return it as a Seq of strings
* Closes the file after read to avoid dangling file handles
diff --git a/src/main/scala/firrtl/FirrtlException.scala b/src/main/scala/firrtl/FirrtlException.scala
index 20d984a1..6f98fda3 100644
--- a/src/main/scala/firrtl/FirrtlException.scala
+++ b/src/main/scala/firrtl/FirrtlException.scala
@@ -18,7 +18,7 @@ object FIRRTLException {
}
@deprecated("External users should use either FirrtlUserException or their own hierarchy", "1.2")
class FIRRTLException(val str: String, cause: Throwable = null)
- extends RuntimeException(FIRRTLException.defaultMessage(str, cause), cause)
+ extends RuntimeException(FIRRTLException.defaultMessage(str, cause), cause)
/** Exception indicating user error
*
@@ -26,7 +26,8 @@ class FIRRTLException(val str: String, cause: Throwable = null)
* This can be extended by custom transform writers.
*/
class FirrtlUserException(message: String, cause: Throwable = null)
- extends RuntimeException(message, cause) with NoStackTrace
+ extends RuntimeException(message, cause)
+ with NoStackTrace
/** Wraps exceptions from CustomTransforms so they can be reported appropriately */
case class CustomTransformException(cause: Throwable) extends Exception("", cause)
@@ -40,4 +41,4 @@ case class CustomTransformException(cause: Throwable) extends Exception("", caus
* transforms are treated differently and should thus have their own structure
*/
private[firrtl] class FirrtlInternalException(message: String, cause: Throwable = null)
- extends Exception(message, cause)
+ extends Exception(message, cause)
diff --git a/src/main/scala/firrtl/Implicits.scala b/src/main/scala/firrtl/Implicits.scala
index ec1cf3d6..fd732917 100644
--- a/src/main/scala/firrtl/Implicits.scala
+++ b/src/main/scala/firrtl/Implicits.scala
@@ -7,19 +7,19 @@ import Utils.trim
import firrtl.constraint.Constraint
object Implicits {
- implicit def int2WInt(i: Int): WrappedInt = WrappedInt(BigInt(i))
- implicit def bigint2WInt(i: BigInt): WrappedInt = WrappedInt(i)
+ implicit def int2WInt(i: Int): WrappedInt = WrappedInt(BigInt(i))
+ implicit def bigint2WInt(i: BigInt): WrappedInt = WrappedInt(i)
implicit def constraint2bound(c: Constraint): Bound = c match {
case x: Bound => x
case x => CalcBound(x)
}
implicit def constraint2width(c: Constraint): Width = c match {
case Closed(x) if trim(x).isWhole => IntWidth(x.toBigInt)
- case x => CalcWidth(x)
+ case x => CalcWidth(x)
}
implicit def width2constraint(w: Width): Constraint = w match {
case CalcWidth(x: Constraint) => x
- case IntWidth(x) => Closed(BigDecimal(x))
+ case IntWidth(x) => Closed(BigDecimal(x))
case UnknownWidth => UnknownBound
case v: Constraint => v
}
diff --git a/src/main/scala/firrtl/LexerHelper.scala b/src/main/scala/firrtl/LexerHelper.scala
index cc17ac46..3ddfc5b9 100644
--- a/src/main/scala/firrtl/LexerHelper.scala
+++ b/src/main/scala/firrtl/LexerHelper.scala
@@ -15,7 +15,7 @@ import firrtl.antlr.FIRRTLParser
abstract class LexerHelper {
- import FIRRTLParser.{NEWLINE, INDENT, DEDENT}
+ import FIRRTLParser.{DEDENT, INDENT, NEWLINE}
private val tokenBuffer = mutable.Queue.empty[Token]
private val indentations = mutable.Stack[Int]()
@@ -58,9 +58,9 @@ abstract class LexerHelper {
def handleNewlineToken(token: Token): Token = {
@tailrec
- def nonNewline(token: Token) : (Token, Token) = {
+ def nonNewline(token: Token): (Token, Token) = {
val nextNext = pullToken()
- if(nextNext.getType == NEWLINE)
+ if (nextNext.getType == NEWLINE)
nonNewline(nextNext)
else
(token, nextNext)
@@ -94,10 +94,11 @@ abstract class LexerHelper {
}
}
- val t = if (tokenBuffer.isEmpty)
- pullToken()
- else
- tokenBuffer.dequeue
+ val t =
+ if (tokenBuffer.isEmpty)
+ pullToken()
+ else
+ tokenBuffer.dequeue
if (reachedEof)
t
@@ -117,8 +118,8 @@ abstract class LexerHelper {
setType(tokenType)
tokenType match {
case `NEWLINE` => setText("<NEWLINE>")
- case `INDENT` => setText("<INDENT>")
- case `DEDENT` => setText("<DEDENT>")
+ case `INDENT` => setText("<INDENT>")
+ case `DEDENT` => setText("<DEDENT>")
}
}
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 19e7d8c6..90881a57 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -91,10 +91,10 @@ class LowFirrtlOptimization extends CoreTransform {
}
/** Runs runs only the optimization passes needed for Verilog emission */
- @deprecated(
- "Use 'new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm)'. This will be removed in 1.4.",
- "FIRRTL 1.3"
- )
+@deprecated(
+ "Use 'new TransformManager(Forms.LowFormMinimumOptimized, Forms.LowForm)'. This will be removed in 1.4.",
+ "FIRRTL 1.3"
+)
class MinimumLowFirrtlOptimization extends CoreTransform {
def inputForm = LowForm
def outputForm = LowForm
diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala
index 3bf89885..e9a698ae 100644
--- a/src/main/scala/firrtl/Mappers.scala
+++ b/src/main/scala/firrtl/Mappers.scala
@@ -12,36 +12,35 @@ object Mappers {
}
private object PortMagnet {
implicit def forType(f: Type => Type): PortMagnet = new PortMagnet {
- override def map(port: Port): Port = port mapType f
+ override def map(port: Port): Port = port.mapType(f)
}
implicit def forString(f: String => String): PortMagnet = new PortMagnet {
- override def map(port: Port): Port = port mapString f
+ override def map(port: Port): Port = port.mapString(f)
}
}
implicit class PortMap(val _port: Port) extends AnyVal {
def map[T](f: T => T)(implicit magnet: (T => T) => PortMagnet): Port = magnet(f).map(_port)
}
-
// ********** Stmt Mappers **********
private trait StmtMagnet {
def map(stmt: Statement): Statement
}
private object StmtMagnet {
implicit def forStmt(f: Statement => Statement): StmtMagnet = new StmtMagnet {
- override def map(stmt: Statement): Statement = stmt mapStmt f
+ override def map(stmt: Statement): Statement = stmt.mapStmt(f)
}
implicit def forExp(f: Expression => Expression): StmtMagnet = new StmtMagnet {
- override def map(stmt: Statement): Statement = stmt mapExpr f
+ override def map(stmt: Statement): Statement = stmt.mapExpr(f)
}
implicit def forType(f: Type => Type): StmtMagnet = new StmtMagnet {
- override def map(stmt: Statement) : Statement = stmt mapType f
+ override def map(stmt: Statement): Statement = stmt.mapType(f)
}
implicit def forString(f: String => String): StmtMagnet = new StmtMagnet {
- override def map(stmt: Statement): Statement = stmt mapString f
+ override def map(stmt: Statement): Statement = stmt.mapString(f)
}
implicit def forInfo(f: Info => Info): StmtMagnet = new StmtMagnet {
- override def map(stmt: Statement): Statement = stmt mapInfo f
+ override def map(stmt: Statement): Statement = stmt.mapInfo(f)
}
}
implicit class StmtMap(val _stmt: Statement) extends AnyVal {
@@ -55,13 +54,13 @@ object Mappers {
}
private object ExprMagnet {
implicit def forExpr(f: Expression => Expression): ExprMagnet = new ExprMagnet {
- override def map(expr: Expression): Expression = expr mapExpr f
+ override def map(expr: Expression): Expression = expr.mapExpr(f)
}
implicit def forType(f: Type => Type): ExprMagnet = new ExprMagnet {
- override def map(expr: Expression): Expression = expr mapType f
+ override def map(expr: Expression): Expression = expr.mapType(f)
}
implicit def forWidth(f: Width => Width): ExprMagnet = new ExprMagnet {
- override def map(expr: Expression): Expression = expr mapWidth f
+ override def map(expr: Expression): Expression = expr.mapWidth(f)
}
}
implicit class ExprMap(val _expr: Expression) extends AnyVal {
@@ -74,10 +73,10 @@ object Mappers {
}
private object TypeMagnet {
implicit def forType(f: Type => Type): TypeMagnet = new TypeMagnet {
- override def map(tpe: Type): Type = tpe mapType f
+ override def map(tpe: Type): Type = tpe.mapType(f)
}
implicit def forWidth(f: Width => Width): TypeMagnet = new TypeMagnet {
- override def map(tpe: Type): Type = tpe mapWidth f
+ override def map(tpe: Type): Type = tpe.mapWidth(f)
}
}
implicit class TypeMap(val _tpe: Type) extends AnyVal {
@@ -91,7 +90,7 @@ object Mappers {
private object WidthMagnet {
implicit def forWidth(f: Width => Width): WidthMagnet = new WidthMagnet {
override def map(width: Width): Width = width match {
- case mapable: HasMapWidth => mapable mapWidth f // WIR
+ case mapable: HasMapWidth => mapable.mapWidth(f) // WIR
case other => other // Standard IR nodes
}
}
@@ -106,21 +105,21 @@ object Mappers {
}
private object ModuleMagnet {
implicit def forStmt(f: Statement => Statement): ModuleMagnet = new ModuleMagnet {
- override def map(module: DefModule): DefModule = module mapStmt f
+ override def map(module: DefModule): DefModule = module.mapStmt(f)
}
implicit def forPorts(f: Port => Port): ModuleMagnet = new ModuleMagnet {
- override def map(module: DefModule): DefModule = module mapPort f
+ override def map(module: DefModule): DefModule = module.mapPort(f)
}
implicit def forString(f: String => String): ModuleMagnet = new ModuleMagnet {
- override def map(module: DefModule): DefModule = module mapString f
+ override def map(module: DefModule): DefModule = module.mapString(f)
}
implicit def forInfo(f: Info => Info): ModuleMagnet = new ModuleMagnet {
- override def map(module: DefModule): DefModule = module mapInfo f
+ override def map(module: DefModule): DefModule = module.mapInfo(f)
}
}
implicit class ModuleMap(val _module: DefModule) extends AnyVal {
def map[T](f: T => T)(implicit magnet: (T => T) => ModuleMagnet): DefModule = magnet(f).map(_module)
- }
+ }
// ********** Circuit Mappers **********
private trait CircuitMagnet {
@@ -128,16 +127,16 @@ object Mappers {
}
private object CircuitMagnet {
implicit def forModules(f: DefModule => DefModule): CircuitMagnet = new CircuitMagnet {
- override def map(circuit: Circuit): Circuit = circuit mapModule f
+ override def map(circuit: Circuit): Circuit = circuit.mapModule(f)
}
implicit def forString(f: String => String): CircuitMagnet = new CircuitMagnet {
- override def map(circuit: Circuit): Circuit = circuit mapString f
+ override def map(circuit: Circuit): Circuit = circuit.mapString(f)
}
implicit def forInfo(f: Info => Info): CircuitMagnet = new CircuitMagnet {
- override def map(circuit: Circuit): Circuit = circuit mapInfo f
+ override def map(circuit: Circuit): Circuit = circuit.mapInfo(f)
}
}
implicit class CircuitMap(val _circuit: Circuit) extends AnyVal {
def map[T](f: T => T)(implicit magnet: (T => T) => CircuitMagnet): Circuit = magnet(f).map(_circuit)
- }
+ }
}
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala
index bb358be6..196539c8 100644
--- a/src/main/scala/firrtl/Namespace.scala
+++ b/src/main/scala/firrtl/Namespace.scala
@@ -29,8 +29,7 @@ class Namespace private {
do {
str = s"${value}_$idx"
idx += 1
- }
- while (!(tryName(str)))
+ } while (!(tryName(str)))
indices(value) = idx
str
}
@@ -55,10 +54,10 @@ object Namespace {
def buildNamespaceStmt(s: Statement): Seq[String] = s match {
case s: IsDeclaration => Seq(s.name)
case s: Conditionally => buildNamespaceStmt(s.conseq) ++ buildNamespaceStmt(s.alt)
- case s: Block => s.stmts flatMap buildNamespaceStmt
+ case s: Block => s.stmts.flatMap(buildNamespaceStmt)
case _ => Nil
}
- namespace.namespace ++= m.ports map (_.name)
+ namespace.namespace ++= m.ports.map(_.name)
m match {
case in: Module =>
namespace.namespace ++= buildNamespaceStmt(in.body)
@@ -71,11 +70,11 @@ object Namespace {
/** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */
def apply(c: Circuit): Namespace = {
val namespace = new Namespace
- namespace.namespace ++= c.modules map (_.name)
+ namespace.namespace ++= c.modules.map(_.name)
namespace
}
- /** Initializes a [[Namespace]] from arbitrary strings **/
+ /** Initializes a [[Namespace]] from arbitrary strings * */
def apply(names: Seq[String] = Nil): Namespace = {
val namespace = new Namespace
namespace.namespace ++= names
diff --git a/src/main/scala/firrtl/Parser.scala b/src/main/scala/firrtl/Parser.scala
index d3075cbb..40eaa88f 100644
--- a/src/main/scala/firrtl/Parser.scala
+++ b/src/main/scala/firrtl/Parser.scala
@@ -17,7 +17,6 @@ case class InvalidStringLitException(message: String) extends ParserException(me
case class InvalidEscapeCharException(message: String) extends ParserException(message)
case class SyntaxErrorsException(message: String) extends ParserException(message)
-
object Parser extends LazyLogging {
/** Parses a file in a given filename and returns a parsed [[firrtl.ir.Circuit Circuit]] */
@@ -57,13 +56,13 @@ object Parser extends LazyLogging {
ast
}
+
/** Takes Iterator over lines of FIRRTL, returns FirrtlNode (root node is Circuit) */
def parse(lines: Iterator[String], infoMode: InfoMode = UseInfo): Circuit =
parseString(lines.mkString("\n"), infoMode)
def parse(lines: Seq[String]): Circuit = parseString(lines.mkString("\n"), UseInfo)
-
/** Parse the concrete syntax of a FIRRTL [[firrtl.ir.Circuit]], e.g.
* {{{
* """circuit Top:
@@ -106,7 +105,7 @@ object Parser extends LazyLogging {
def parse(lines: Seq[String], infoMode: InfoMode): Circuit = parse(lines.iterator, infoMode)
- def parse(text: String, infoMode: InfoMode): Circuit = parse(text split "\n", infoMode)
+ def parse(text: String, infoMode: InfoMode): Circuit = parse(text.split("\n"), infoMode)
/** Parse the concrete syntax of a FIRRTL [[firrtl.ir.Expression]], e.g.
* "add(x, y)" becomes:
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index 883692c8..baa8638a 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -15,14 +15,14 @@ object PrimOps extends LazyLogging {
def w1(e: DoPrim): Width = getWidth(t1(e))
def w2(e: DoPrim): Width = getWidth(t2(e))
def p1(e: DoPrim): Width = t1(e) match {
- case FixedType(w, p) => p
+ case FixedType(w, p) => p
case IntervalType(min, max, p) => p
- case _ => sys.error(s"Cannot get binary point from ${t1(e)}")
+ case _ => sys.error(s"Cannot get binary point from ${t1(e)}")
}
def p2(e: DoPrim): Width = t2(e) match {
- case FixedType(w, p) => p
+ case FixedType(w, p) => p
case IntervalType(min, max, p) => p
- case _ => sys.error(s"Cannot get binary point from ${t1(e)}")
+ case _ => sys.error(s"Cannot get binary point from ${t1(e)}")
}
def c1(e: DoPrim) = IntWidth(e.consts.head)
def c2(e: DoPrim) = IntWidth(e.consts(1))
@@ -37,8 +37,16 @@ object PrimOps extends LazyLogging {
(t1(e), t2(e)) match {
case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
- case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))), IntWidth(1)), IsMax(p1(e), p2(e)))
- case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, l2), IsAdd(u1, u2), IsMax(p1, p2))
+ case (_: FixedType, _: FixedType) =>
+ FixedType(
+ IsAdd(
+ IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))),
+ IntWidth(1)
+ ),
+ IsMax(p1(e), p2(e))
+ )
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ IntervalType(IsAdd(l1, l2), IsAdd(u1, u2), IsMax(p1, p2))
case _ => UnknownType
}
}
@@ -49,8 +57,13 @@ object PrimOps extends LazyLogging {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
case (_: UIntType, _: UIntType) => UIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
case (_: SIntType, _: SIntType) => SIntType(IsAdd(IsMax(w1(e), w2(e)), IntWidth(1)))
- case (_: FixedType, _: FixedType) => FixedType(IsAdd(IsAdd(IsMax(p1(e), p2(e)),IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))),IntWidth(1)), IsMax(p1(e), p2(e)))
- case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) => IntervalType(IsAdd(l1, IsNeg(u2)), IsAdd(u1, IsNeg(l2)), IsMax(p1, p2))
+ case (_: FixedType, _: FixedType) =>
+ FixedType(
+ IsAdd(IsAdd(IsMax(p1(e), p2(e)), IsMax(IsAdd(w1(e), IsNeg(p1(e))), IsAdd(w2(e), IsNeg(p2(e))))), IntWidth(1)),
+ IsMax(p1(e), p2(e))
+ )
+ case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
+ IntervalType(IsAdd(l1, IsNeg(u2)), IsAdd(u1, IsNeg(l2)), IsMax(p1, p2))
case _ => UnknownType
}
override def toString = "sub"
@@ -70,7 +83,8 @@ object PrimOps extends LazyLogging {
)
case _ => UnknownType
}
- override def toString = "mul" }
+ override def toString = "mul"
+ }
/** Division */
case object Div extends PrimOp {
@@ -79,7 +93,8 @@ object PrimOps extends LazyLogging {
case (_: SIntType, _: SIntType) => SIntType(IsAdd(w1(e), IntWidth(1)))
case _ => UnknownType
}
- override def toString = "div" }
+ override def toString = "div"
+ }
/** Remainder */
case object Rem extends PrimOp {
@@ -88,7 +103,9 @@ object PrimOps extends LazyLogging {
case (_: SIntType, _: SIntType) => SIntType(MIN(w1(e), w2(e)))
case _ => UnknownType
}
- override def toString = "rem" }
+ override def toString = "rem"
+ }
+
/** Less Than */
case object Lt extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -98,7 +115,9 @@ object PrimOps extends LazyLogging {
case (_: IntervalType, _: IntervalType) => Utils.BoolType
case _ => UnknownType
}
- override def toString = "lt" }
+ override def toString = "lt"
+ }
+
/** Less Than Or Equal To */
case object Leq extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -108,7 +127,9 @@ object PrimOps extends LazyLogging {
case (_: IntervalType, _: IntervalType) => Utils.BoolType
case _ => UnknownType
}
- override def toString = "leq" }
+ override def toString = "leq"
+ }
+
/** Greater Than */
case object Gt extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -118,7 +139,9 @@ object PrimOps extends LazyLogging {
case (_: IntervalType, _: IntervalType) => Utils.BoolType
case _ => UnknownType
}
- override def toString = "gt" }
+ override def toString = "gt"
+ }
+
/** Greater Than Or Equal To */
case object Geq extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -128,7 +151,9 @@ object PrimOps extends LazyLogging {
case (_: IntervalType, _: IntervalType) => Utils.BoolType
case _ => UnknownType
}
- override def toString = "geq" }
+ override def toString = "geq"
+ }
+
/** Equal To */
case object Eq extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -138,7 +163,9 @@ object PrimOps extends LazyLogging {
case (_: IntervalType, _: IntervalType) => Utils.BoolType
case _ => UnknownType
}
- override def toString = "eq" }
+ override def toString = "eq"
+ }
+
/** Not Equal To */
case object Neq extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -148,31 +175,42 @@ object PrimOps extends LazyLogging {
case (_: IntervalType, _: IntervalType) => Utils.BoolType
case _ => UnknownType
}
- override def toString = "neq" }
+ override def toString = "neq"
+ }
+
/** Padding */
case object Pad extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => UIntType(IsMax(w1(e), c1(e)))
- case _: SIntType => SIntType(IsMax(w1(e), c1(e)))
+ case _: UIntType => UIntType(IsMax(w1(e), c1(e)))
+ case _: SIntType => SIntType(IsMax(w1(e), c1(e)))
case _: FixedType => FixedType(IsMax(w1(e), c1(e)), p1(e))
case _ => UnknownType
}
- override def toString = "pad" }
+ override def toString = "pad"
+ }
+
/** Static Shift Left */
case object Shl extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => UIntType(IsAdd(w1(e), c1(e)))
- case _: SIntType => SIntType(IsAdd(w1(e), c1(e)))
- case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), p1(e))
- case IntervalType(l, u, p) => IntervalType(IsMul(l, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), IsMul(u, Closed(BigDecimal(BigInt(1) << o1(e).toInt))), p)
+ case _: UIntType => UIntType(IsAdd(w1(e), c1(e)))
+ case _: SIntType => SIntType(IsAdd(w1(e), c1(e)))
+ case _: FixedType => FixedType(IsAdd(w1(e), c1(e)), p1(e))
+ case IntervalType(l, u, p) =>
+ IntervalType(
+ IsMul(l, Closed(BigDecimal(BigInt(1) << o1(e).toInt))),
+ IsMul(u, Closed(BigDecimal(BigInt(1) << o1(e).toInt))),
+ p
+ )
case _ => UnknownType
}
- override def toString = "shl" }
+ override def toString = "shl"
+ }
+
/** Static Shift Right */
case object Shr extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => UIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)))
- case _: SIntType => SIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)))
+ case _: UIntType => UIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)))
+ case _: SIntType => SIntType(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)))
case _: FixedType => FixedType(IsMax(IsMax(IsAdd(w1(e), IsNeg(c1(e))), IntWidth(1)), p1(e)), p1(e))
case IntervalType(l, u, IntWidth(p)) =>
val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt))
@@ -187,11 +225,12 @@ object PrimOps extends LazyLogging {
}
override def toString = "shr"
}
+
/** Dynamic Shift Left */
case object Dshl extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => UIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))))
- case _: SIntType => SIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))))
+ case _: UIntType => UIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))))
+ case _: SIntType => SIntType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))))
case _: FixedType => FixedType(IsAdd(w1(e), IsAdd(IsPow(w2(e)), Closed(-1))), p1(e))
case IntervalType(l, u, p) =>
val maxShiftAmt = IsAdd(IsPow(w2(e)), Closed(-1))
@@ -206,18 +245,20 @@ object PrimOps extends LazyLogging {
}
override def toString = "dshl"
}
+
/** Dynamic Shift Right */
case object Dshr extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => UIntType(w1(e))
- case _: SIntType => SIntType(w1(e))
+ case _: UIntType => UIntType(w1(e))
+ case _: SIntType => SIntType(w1(e))
case _: FixedType => FixedType(w1(e), p1(e))
// Decreasing magnitude -- don't need more bits
case IntervalType(l, u, p) => IntervalType(l, u, p)
- case _ => UnknownType
+ case _ => UnknownType
}
override def toString = "dshr"
}
+
/** Arithmetic Convert to Signed */
case object Cvt extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -227,6 +268,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "cvt"
}
+
/** Negate */
case object Neg extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -236,6 +278,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "neg"
}
+
/** Bitwise Complement */
case object Not extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -245,6 +288,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "not"
}
+
/** Bitwise And */
case object And extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -253,6 +297,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "and"
}
+
/** Bitwise Or */
case object Or extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -261,6 +306,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "or"
}
+
/** Bitwise Exclusive Or */
case object Xor extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
@@ -269,6 +315,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "xor"
}
+
/** Bitwise And Reduce */
case object Andr extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -277,6 +324,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "andr"
}
+
/** Bitwise Or Reduce */
case object Orr extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -285,6 +333,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "orr"
}
+
/** Bitwise Exclusive Or Reduce */
case object Xorr extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -293,22 +342,30 @@ object PrimOps extends LazyLogging {
}
override def toString = "xorr"
}
+
/** Concatenate */
case object Cat extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
- case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType, _: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(w1(e), w2(e)))
+ case (
+ _: UIntType | _: SIntType | _: FixedType | _: IntervalType,
+ _: UIntType | _: SIntType | _: FixedType | _: IntervalType
+ ) =>
+ UIntType(IsAdd(w1(e), w2(e)))
case (t1, t2) => UnknownType
}
override def toString = "cat"
}
+
/** Bit Extraction */
case object Bits extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) => UIntType(IsAdd(IsAdd(c1(e), IsNeg(c2(e))), IntWidth(1)))
+ case (_: UIntType | _: SIntType | _: FixedType | _: IntervalType) =>
+ UIntType(IsAdd(IsAdd(c1(e), IsNeg(c2(e))), IntWidth(1)))
case _ => UnknownType
}
override def toString = "bits"
}
+
/** Head */
case object Head extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -317,6 +374,7 @@ object PrimOps extends LazyLogging {
}
override def toString = "head"
}
+
/** Tail */
case object Tail extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
@@ -325,20 +383,22 @@ object PrimOps extends LazyLogging {
}
override def toString = "tail"
}
- /** Increase Precision **/
+
+ /** Increase Precision * */
case object IncP extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: FixedType => FixedType(IsAdd(w1(e),c1(e)), IsAdd(p1(e), c1(e)))
+ case _: FixedType => FixedType(IsAdd(w1(e), c1(e)), IsAdd(p1(e), c1(e)))
// Keeps the same exact value, but adds more precision for the future i.e. aaa.bbb -> aaa.bbb00
case IntervalType(l, u, p) => IntervalType(l, u, IsAdd(p, c1(e)))
- case _ => UnknownType
+ case _ => UnknownType
}
override def toString = "incp"
}
- /** Decrease Precision **/
+
+ /** Decrease Precision * */
case object DecP extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: FixedType => FixedType(IsAdd(w1(e),IsNeg(c1(e))), IsAdd(p1(e), IsNeg(c1(e))))
+ case _: FixedType => FixedType(IsAdd(w1(e), IsNeg(c1(e))), IsAdd(p1(e), IsNeg(c1(e))))
case IntervalType(l, u, IntWidth(p)) =>
val shiftMul = Closed(BigDecimal(1) / BigDecimal(BigInt(1) << o1(e).toInt))
// BP is inferred at this point
@@ -355,7 +415,8 @@ object PrimOps extends LazyLogging {
}
override def toString = "decp"
}
- /** Set Precision **/
+
+ /** Set Precision * */
case object SetP extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
case _: FixedType => FixedType(IsAdd(c1(e), IsAdd(w1(e), IsNeg(p1(e)))), c1(e))
@@ -369,84 +430,98 @@ object PrimOps extends LazyLogging {
}
override def toString = "setp"
}
+
/** Interpret As UInt */
case object AsUInt extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => UIntType(w1(e))
- case _: SIntType => UIntType(w1(e))
+ case _: UIntType => UIntType(w1(e))
+ case _: SIntType => UIntType(w1(e))
case _: FixedType => UIntType(w1(e))
- case ClockType => UIntType(IntWidth(1))
+ case ClockType => UIntType(IntWidth(1))
case AsyncResetType => UIntType(IntWidth(1))
- case ResetType => UIntType(IntWidth(1))
- case AnalogType(w) => UIntType(w1(e))
+ case ResetType => UIntType(IntWidth(1))
+ case AnalogType(w) => UIntType(w1(e))
case _: IntervalType => UIntType(w1(e))
case _ => UnknownType
}
override def toString = "asUInt"
}
+
/** Interpret As SInt */
case object AsSInt extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => SIntType(w1(e))
- case _: SIntType => SIntType(w1(e))
+ case _: UIntType => SIntType(w1(e))
+ case _: SIntType => SIntType(w1(e))
case _: FixedType => SIntType(w1(e))
- case ClockType => SIntType(IntWidth(1))
+ case ClockType => SIntType(IntWidth(1))
case AsyncResetType => SIntType(IntWidth(1))
- case ResetType => SIntType(IntWidth(1))
- case _: AnalogType => SIntType(w1(e))
+ case ResetType => SIntType(IntWidth(1))
+ case _: AnalogType => SIntType(w1(e))
case _: IntervalType => SIntType(w1(e))
case _ => UnknownType
}
override def toString = "asSInt"
}
+
/** Interpret As Clock */
case object AsClock extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
case _: UIntType => ClockType
case _: SIntType => ClockType
- case ClockType => ClockType
+ case ClockType => ClockType
case AsyncResetType => ClockType
- case ResetType => ClockType
- case _: AnalogType => ClockType
+ case ResetType => ClockType
+ case _: AnalogType => ClockType
case _: IntervalType => ClockType
case _ => UnknownType
}
override def toString = "asClock"
}
+
/** Interpret As AsyncReset */
case object AsAsyncReset extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType | _: IntervalType | _: FixedType => AsyncResetType
+ case _: UIntType | _: SIntType | _: AnalogType | ClockType | AsyncResetType | ResetType | _: IntervalType |
+ _: FixedType =>
+ AsyncResetType
case _ => UnknownType
}
override def toString = "asAsyncReset"
}
- /** Interpret as Fixed Point **/
+
+ /** Interpret as Fixed Point * */
case object AsFixedPoint extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
- case _: UIntType => FixedType(w1(e), c1(e))
- case _: SIntType => FixedType(w1(e), c1(e))
+ case _: UIntType => FixedType(w1(e), c1(e))
+ case _: SIntType => FixedType(w1(e), c1(e))
case _: FixedType => FixedType(w1(e), c1(e))
case ClockType => FixedType(IntWidth(1), c1(e))
case _: AnalogType => FixedType(w1(e), c1(e))
case AsyncResetType => FixedType(IntWidth(1), c1(e))
- case ResetType => FixedType(IntWidth(1), c1(e))
+ case ResetType => FixedType(IntWidth(1), c1(e))
case _: IntervalType => FixedType(w1(e), c1(e))
case _ => UnknownType
}
override def toString = "asFixedPoint"
}
- /** Interpret as Interval (closed lower bound, closed upper bound, binary point) **/
+
+ /** Interpret as Interval (closed lower bound, closed upper bound, binary point) * */
case object AsInterval extends PrimOp {
override def propagateType(e: DoPrim): Type = t1(e) match {
// Chisel shifts up and rounds first.
- case _: UIntType | _: SIntType | _: FixedType | ClockType | AsyncResetType | ResetType | _: AnalogType | _: IntervalType =>
- IntervalType(Closed(BigDecimal(o1(e))/BigDecimal(BigInt(1) << o3(e).toInt)), Closed(BigDecimal(o2(e))/BigDecimal(BigInt(1) << o3(e).toInt)), IntWidth(o3(e)))
+ case _: UIntType | _: SIntType | _: FixedType | ClockType | AsyncResetType | ResetType | _: AnalogType |
+ _: IntervalType =>
+ IntervalType(
+ Closed(BigDecimal(o1(e)) / BigDecimal(BigInt(1) << o3(e).toInt)),
+ Closed(BigDecimal(o2(e)) / BigDecimal(BigInt(1) << o3(e).toInt)),
+ IntWidth(o3(e))
+ )
case _ => UnknownType
}
override def toString = "asInterval"
}
- /** Try to fit the first argument into the type of the smaller argument **/
+
+ /** Try to fit the first argument into the type of the smaller argument * */
case object Squeeze extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) =>
@@ -457,15 +532,17 @@ object PrimOps extends LazyLogging {
}
override def toString = "squz"
}
- /** Wrap First Operand Around Range/Width of Second Operand **/
+
+ /** Wrap First Operand Around Range/Width of Second Operand * */
case object Wrap extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) => IntervalType(l2, u2, p1)
- case _ => UnknownType
+ case _ => UnknownType
}
override def toString = "wrap"
}
- /** Clip First Operand At Range/Width of Second Operand **/
+
+ /** Clip First Operand At Range/Width of Second Operand * */
case object Clip extends PrimOp {
override def propagateType(e: DoPrim): Type = (t1(e), t2(e)) match {
case (IntervalType(l1, u1, p1), IntervalType(l2, u2, _)) =>
@@ -485,19 +562,20 @@ object PrimOps extends LazyLogging {
)
// format: on
private lazy val strToPrimOp: Map[String, PrimOp] = {
- builtinPrimOps.map { case op : PrimOp=> op.toString -> op }.toMap
+ builtinPrimOps.map { case op: PrimOp => op.toString -> op }.toMap
}
/** Seq of String representations of [[ir.PrimOp]]s */
- lazy val listing: Seq[String] = builtinPrimOps map (_.toString)
+ lazy val listing: Seq[String] = builtinPrimOps.map(_.toString)
+
/** Gets the corresponding [[ir.PrimOp]] from its String representation */
def fromString(op: String): PrimOp = strToPrimOp(op)
// Width Constraint Functions
- def PLUS(w1: Width, w2: Width): Constraint = IsAdd(w1, w2)
- def MAX(w1: Width, w2: Width): Constraint = IsMax(w1, w2)
+ def PLUS(w1: Width, w2: Width): Constraint = IsAdd(w1, w2)
+ def MAX(w1: Width, w2: Width): Constraint = IsMax(w1, w2)
def MINUS(w1: Width, w2: Width): Constraint = IsAdd(w1, IsNeg(w2))
- def MIN(w1: Width, w2: Width): Constraint = IsMin(w1, w2)
+ def MIN(w1: Width, w2: Width): Constraint = IsMin(w1, w2)
def set_primop_type(e: DoPrim): DoPrim = DoPrim(e.op, e.args, e.consts, e.op.propagateType(e))
}
diff --git a/src/main/scala/firrtl/RenameMap.scala b/src/main/scala/firrtl/RenameMap.scala
index 9c848bca..d85998b5 100644
--- a/src/main/scala/firrtl/RenameMap.scala
+++ b/src/main/scala/firrtl/RenameMap.scala
@@ -38,9 +38,9 @@ object RenameMap {
*/
// TODO This should probably be refactored into immutable and mutable versions
final class RenameMap private (
- val underlying: mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]](),
- val chained: Option[RenameMap] = None
-) {
+ val underlying: mutable.HashMap[CompleteTarget, Seq[CompleteTarget]] =
+ mutable.HashMap[CompleteTarget, Seq[CompleteTarget]](),
+ val chained: Option[RenameMap] = None) {
/** Chain a [[RenameMap]] with this [[RenameMap]]
* @param next the map to chain with this map
@@ -100,7 +100,7 @@ final class RenameMap private (
* $noteDistinct
*/
def recordAll(map: collection.Map[CompleteTarget, Seq[CompleteTarget]]): Unit =
- map.foreach{
+ map.foreach {
case (from: IsComponent, tos: Seq[_]) => completeRename(from, tos)
case (from: IsModule, tos: Seq[_]) => completeRename(from, tos)
case (from: CircuitTarget, tos: Seq[_]) => completeRename(from, tos)
@@ -128,7 +128,7 @@ final class RenameMap private (
* @param key Target referencing the original circuit
* @return Optionally return sequence of targets that key remaps to
*/
- def get(key: CircuitTarget): Option[Seq[CircuitTarget]] = completeGet(key).map( _.map { case x: CircuitTarget => x } )
+ def get(key: CircuitTarget): Option[Seq[CircuitTarget]] = completeGet(key).map(_.map { case x: CircuitTarget => x })
/** Get renames of a [[firrtl.annotations.IsMember IsMember]]
* @param key Target referencing the original member of the circuit
@@ -136,12 +136,11 @@ final class RenameMap private (
*/
def get(key: IsMember): Option[Seq[IsMember]] = completeGet(key).map { _.map { case x: IsMember => x } }
-
/** Create new [[RenameMap]] that merges this and renameMap
* @param renameMap
* @return
*/
- def ++ (renameMap: RenameMap): RenameMap = {
+ def ++(renameMap: RenameMap): RenameMap = {
val newChained = if (chained.nonEmpty && renameMap.chained.nonEmpty) {
Some(chained.get ++ renameMap.chained.get)
} else {
@@ -168,7 +167,7 @@ final class RenameMap private (
def getReverseRenameMap: RenameMap = {
val reverseMap = mutable.HashMap[CompleteTarget, Seq[CompleteTarget]]()
- underlying.keysIterator.foreach{ key =>
+ underlying.keysIterator.foreach { key =>
apply(key).foreach { v =>
reverseMap(v) = key +: reverseMap.getOrElse(v, Nil)
}
@@ -181,8 +180,9 @@ final class RenameMap private (
/** Serialize the underlying remapping of keys to new targets
* @return
*/
- def serialize: String = underlying.map { case (k, v) =>
- k.serialize + "=>" + v.map(_.serialize).mkString(", ")
+ def serialize: String = underlying.map {
+ case (k, v) =>
+ k.serialize + "=>" + v.map(_.serialize).mkString(", ")
}.mkString("\n")
/** Records which local InstanceTargets will require modification.
@@ -229,7 +229,8 @@ final class RenameMap private (
val hereRet = (chainedRet.flatMap { target =>
hereCompleteGet(target).getOrElse(Seq(target))
}).distinct
- if (hereRet.size == 1 && hereRet.head == key) { None } else { Some(hereRet) }
+ if (hereRet.size == 1 && hereRet.head == key) { None }
+ else { Some(hereRet) }
}
} else {
hereCompleteGet(key)
@@ -238,10 +239,11 @@ final class RenameMap private (
private def hereCompleteGet(key: CompleteTarget): Option[Seq[CompleteTarget]] = {
val errors = mutable.ArrayBuffer[String]()
- val ret = if(hasChanges) {
+ val ret = if (hasChanges) {
val ret = recursiveGet(errors)(key)
- if(errors.nonEmpty) { throw IllegalRenameException(errors.mkString("\n")) }
- if(ret.size == 1 && ret.head == key) { None } else { Some(ret) }
+ if (errors.nonEmpty) { throw IllegalRenameException(errors.mkString("\n")) }
+ if (ret.size == 1 && ret.head == key) { None }
+ else { Some(ret) }
} else { None }
ret
}
@@ -266,50 +268,54 @@ final class RenameMap private (
* @return Renamed targets if a match is found, otherwise None
*/
private def referenceGet(errors: mutable.ArrayBuffer[String])(key: ReferenceTarget): Option[Seq[IsComponent]] = {
- def traverseTokens(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseTokensCache.getOrElseUpdate(key, {
- if (underlying.contains(key)) {
- Some(underlying(key).flatMap {
- case comp: IsComponent => Some(comp)
- case other =>
- errors += s"reference ${key.targetParent} cannot be renamed to a non-component ${other}"
- None
- })
- } else {
- key match {
- case t: ReferenceTarget if t.component.nonEmpty =>
- val last = t.component.last
- val parent = t.copy(component = t.component.dropRight(1))
- traverseTokens(parent).map(_.flatMap { x =>
- (x, last) match {
- case (t2: InstanceTarget, Field(f)) => Some(t2.ref(f))
- case (t2: ReferenceTarget, Field(f)) => Some(t2.field(f))
- case (t2: ReferenceTarget, Index(i)) => Some(t2.index(i))
- case other =>
- errors += s"Illegal rename: ${key.targetParent} cannot be renamed to ${other._1} - must rename $key directly"
- None
- }
- })
- case t: ReferenceTarget => None
+ def traverseTokens(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseTokensCache.getOrElseUpdate(
+ key, {
+ if (underlying.contains(key)) {
+ Some(underlying(key).flatMap {
+ case comp: IsComponent => Some(comp)
+ case other =>
+ errors += s"reference ${key.targetParent} cannot be renamed to a non-component ${other}"
+ None
+ })
+ } else {
+ key match {
+ case t: ReferenceTarget if t.component.nonEmpty =>
+ val last = t.component.last
+ val parent = t.copy(component = t.component.dropRight(1))
+ traverseTokens(parent).map(_.flatMap { x =>
+ (x, last) match {
+ case (t2: InstanceTarget, Field(f)) => Some(t2.ref(f))
+ case (t2: ReferenceTarget, Field(f)) => Some(t2.field(f))
+ case (t2: ReferenceTarget, Index(i)) => Some(t2.index(i))
+ case other =>
+ errors += s"Illegal rename: ${key.targetParent} cannot be renamed to ${other._1} - must rename $key directly"
+ None
+ }
+ })
+ case t: ReferenceTarget => None
+ }
}
}
- })
-
- def traverseHierarchy(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseHierarchyCache.getOrElseUpdate(key, {
- val tokenRenamed = traverseTokens(key)
- if (tokenRenamed.nonEmpty) {
- tokenRenamed
- } else {
- key match {
- case t: ReferenceTarget if t.isLocal => None
- case t: ReferenceTarget =>
- val encapsulatingInstance = t.path.head._1.value
- val stripped = t.stripHierarchy(1)
- traverseHierarchy(stripped).map(_.map {
- _.addHierarchy(t.module, encapsulatingInstance)
- })
+ )
+
+ def traverseHierarchy(key: ReferenceTarget): Option[Seq[IsComponent]] = traverseHierarchyCache.getOrElseUpdate(
+ key, {
+ val tokenRenamed = traverseTokens(key)
+ if (tokenRenamed.nonEmpty) {
+ tokenRenamed
+ } else {
+ key match {
+ case t: ReferenceTarget if t.isLocal => None
+ case t: ReferenceTarget =>
+ val encapsulatingInstance = t.path.head._1.value
+ val stripped = t.stripHierarchy(1)
+ traverseHierarchy(stripped).map(_.map {
+ _.addHierarchy(t.module, encapsulatingInstance)
+ })
+ }
}
}
- })
+ )
traverseHierarchy(key)
}
@@ -335,64 +341,73 @@ final class RenameMap private (
* @return Renamed targets if a match is found, otherwise None
*/
private def instanceGet(errors: mutable.ArrayBuffer[String])(key: InstanceTarget): Option[Seq[IsModule]] = {
- def traverseLeft(key: InstanceTarget): Option[Seq[IsModule]] = traverseLeftCache.getOrElseUpdate(key, {
- val getOpt = underlying.get(key)
-
- if (getOpt.nonEmpty) {
- getOpt.map(_.flatMap {
- case isMod: IsModule => Some(isMod)
- case other =>
- errors += s"IsModule: $key cannot be renamed to non-IsModule $other"
- None
- })
- } else {
- key match {
- case t: InstanceTarget if t.isLocal => None
- case t: InstanceTarget =>
- val (Instance(outerInst), OfModule(outerMod)) = t.path.head
- val stripped = t.copy(path = t.path.tail, module = outerMod)
- traverseLeft(stripped).map(_.map {
- case absolute if absolute.circuit == absolute.module => absolute
- case relative => relative.addHierarchy(t.module, outerInst)
- })
+ def traverseLeft(key: InstanceTarget): Option[Seq[IsModule]] = traverseLeftCache.getOrElseUpdate(
+ key, {
+ val getOpt = underlying.get(key)
+
+ if (getOpt.nonEmpty) {
+ getOpt.map(_.flatMap {
+ case isMod: IsModule => Some(isMod)
+ case other =>
+ errors += s"IsModule: $key cannot be renamed to non-IsModule $other"
+ None
+ })
+ } else {
+ key match {
+ case t: InstanceTarget if t.isLocal => None
+ case t: InstanceTarget =>
+ val (Instance(outerInst), OfModule(outerMod)) = t.path.head
+ val stripped = t.copy(path = t.path.tail, module = outerMod)
+ traverseLeft(stripped).map(_.map {
+ case absolute if absolute.circuit == absolute.module => absolute
+ case relative => relative.addHierarchy(t.module, outerInst)
+ })
+ }
}
}
- })
-
- def traverseRight(key: InstanceTarget): Option[Seq[IsModule]] = traverseRightCache.getOrElseUpdate(key, {
- val findLeft = traverseLeft(key)
- if (findLeft.isDefined) {
- findLeft
- } else {
- key match {
- case t: InstanceTarget if t.isLocal => None
- case t: InstanceTarget =>
- val (Instance(i), OfModule(m)) = t.path.last
- val parent = t.copy(path = t.path.dropRight(1), instance = i, ofModule = m)
- traverseRight(parent).map(_.map(_.instOf(t.instance, t.ofModule)))
+ )
+
+ def traverseRight(key: InstanceTarget): Option[Seq[IsModule]] = traverseRightCache.getOrElseUpdate(
+ key, {
+ val findLeft = traverseLeft(key)
+ if (findLeft.isDefined) {
+ findLeft
+ } else {
+ key match {
+ case t: InstanceTarget if t.isLocal => None
+ case t: InstanceTarget =>
+ val (Instance(i), OfModule(m)) = t.path.last
+ val parent = t.copy(path = t.path.dropRight(1), instance = i, ofModule = m)
+ traverseRight(parent).map(_.map(_.instOf(t.instance, t.ofModule)))
+ }
}
}
- })
+ )
traverseRight(key)
}
private def circuitGet(errors: mutable.ArrayBuffer[String])(key: CircuitTarget): Seq[CircuitTarget] = {
- underlying.get(key).map(_.flatMap {
- case c: CircuitTarget => Some(c)
- case other =>
- errors += s"Illegal rename: $key cannot be renamed to non-circuit target: $other"
- None
- }).getOrElse(Seq(key))
+ underlying
+ .get(key)
+ .map(_.flatMap {
+ case c: CircuitTarget => Some(c)
+ case other =>
+ errors += s"Illegal rename: $key cannot be renamed to non-circuit target: $other"
+ None
+ })
+ .getOrElse(Seq(key))
}
private def moduleGet(errors: mutable.ArrayBuffer[String])(key: ModuleTarget): Option[Seq[IsModule]] = {
- underlying.get(key).map(_.flatMap {
- case mod: IsModule => Some(mod)
- case other =>
- errors += s"Illegal rename: $key cannot be renamed to non-module target: $other"
- None
- })
+ underlying
+ .get(key)
+ .map(_.flatMap {
+ case mod: IsModule => Some(mod)
+ case other =>
+ errors += s"Illegal rename: $key cannot be renamed to non-module target: $other"
+ None
+ })
}
// the possible results returned by ofModuleGet
@@ -438,10 +453,11 @@ final class RenameMap private (
private def ofModuleGet(errors: mutable.ArrayBuffer[String])(key: IsComponent): OfModuleRenameResult = {
val circuit = key.circuit
def renameOfModules(
- path: Seq[(Instance, OfModule)],
- foundRename: Boolean,
+ path: Seq[(Instance, OfModule)],
+ foundRename: Boolean,
newCircuitOpt: Option[String],
- children: Seq[(Instance, OfModule)]): OfModuleRenameResult = {
+ children: Seq[(Instance, OfModule)]
+ ): OfModuleRenameResult = {
if (path.isEmpty && foundRename) {
RenamedOfModules(children)
} else if (path.isEmpty) {
@@ -489,15 +505,15 @@ final class RenameMap private (
* @return Renamed targets
*/
private def recursiveGet(errors: mutable.ArrayBuffer[String])(key: CompleteTarget): Seq[CompleteTarget] = {
- if(getCache.contains(key)) {
+ if (getCache.contains(key)) {
getCache(key)
} else {
// rename just the component portion; path/ref/component for ReferenceTargets or path/instance for InstanceTargets
val componentRename = key match {
- case t: CircuitTarget => None
- case t: ModuleTarget => None
- case t: InstanceTarget => instanceGet(errors)(t)
+ case t: CircuitTarget => None
+ case t: ModuleTarget => None
+ case t: InstanceTarget => instanceGet(errors)(t)
case ref: ReferenceTarget if ref.isLocal => referenceGet(errors)(ref)
case ref @ ReferenceTarget(c, m, p, r, t) =>
val (Instance(inst), OfModule(ofMod)) = p.last
@@ -510,7 +526,6 @@ final class RenameMap private (
}
}
-
// if no component rename was found, look for Module renames; root module/OfModules in path
val moduleRename = if (componentRename.isDefined) {
componentRename
@@ -522,7 +537,8 @@ final class RenameMap private (
ofModuleGet(errors)(t) match {
case AbsoluteOfModule(absolute) =>
t match {
- case ref: ReferenceTarget => Some(Seq(ref.copy(circuit = absolute.circuit, module = absolute.module, path = absolute.asPath)))
+ case ref: ReferenceTarget =>
+ Some(Seq(ref.copy(circuit = absolute.circuit, module = absolute.module, path = absolute.asPath)))
case inst: InstanceTarget => Some(Seq(absolute))
}
case RenamedOfModules(children) =>
@@ -532,14 +548,16 @@ final class RenameMap private (
val newPath = mod.asPath ++ children
t match {
- case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath)
+ case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath)
case inst: InstanceTarget =>
val (Instance(newInst), OfModule(newOfMod)) = newPath.last
- inst.copy(circuit = mod.circuit,
+ inst.copy(
+ circuit = mod.circuit,
module = mod.module,
path = newPath.dropRight(1),
instance = newInst,
- ofModule = newOfMod)
+ ofModule = newOfMod
+ )
}
}
Some(result)
@@ -551,14 +569,16 @@ final class RenameMap private (
val newPath = mod.asPath ++ children
t match {
- case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath)
+ case ref: ReferenceTarget => ref.copy(circuit = mod.circuit, module = mod.module, path = newPath)
case inst: InstanceTarget =>
val (Instance(newInst), OfModule(newOfMod)) = newPath.last
- inst.copy(circuit = mod.circuit,
+ inst.copy(
+ circuit = mod.circuit,
module = mod.module,
path = newPath.dropRight(1),
instance = newInst,
- ofModule = newOfMod)
+ ofModule = newOfMod
+ )
}
})
}
@@ -579,8 +599,8 @@ final class RenameMap private (
circuitGet(errors)(CircuitTarget(t.circuit)).map {
case CircuitTarget(c) =>
t match {
- case ref: ReferenceTarget => ref.copy(circuit = c)
- case inst: InstanceTarget => inst.copy(circuit = c)
+ case ref: ReferenceTarget => ref.copy(circuit = c)
+ case inst: InstanceTarget => inst.copy(circuit = c)
}
}
}
@@ -597,7 +617,7 @@ final class RenameMap private (
* @param tos
*/
private def completeRename(from: CompleteTarget, tos: Seq[CompleteTarget]): Unit = {
- tos.foreach{recordSensitivity(from, _)}
+ tos.foreach { recordSensitivity(from, _) }
val existing = underlying.getOrElse(from, Vector.empty)
val updated = (existing ++ tos).distinct
underlying(from) = updated
@@ -625,29 +645,30 @@ final class RenameMap private (
def delete(name: ComponentName): Unit = underlying(name) = Seq.empty
def addMap(map: collection.Map[Named, Seq[Named]]): Unit =
- recordAll(map.map { case (key, values) => (Target.convertNamed2Target(key), values.map(Target.convertNamed2Target)) })
+ recordAll(map.map {
+ case (key, values) => (Target.convertNamed2Target(key), values.map(Target.convertNamed2Target))
+ })
def get(key: CircuitName): Option[Seq[CircuitName]] = {
- get(Target.convertCircuitName2CircuitTarget(key)).map(_.collect{ case c: CircuitTarget => c.toNamed })
+ get(Target.convertCircuitName2CircuitTarget(key)).map(_.collect { case c: CircuitTarget => c.toNamed })
}
def get(key: ModuleName): Option[Seq[ModuleName]] = {
- get(Target.convertModuleName2ModuleTarget(key)).map(_.collect{ case m: ModuleTarget => m.toNamed })
+ get(Target.convertModuleName2ModuleTarget(key)).map(_.collect { case m: ModuleTarget => m.toNamed })
}
def get(key: ComponentName): Option[Seq[ComponentName]] = {
- get(Target.convertComponentName2ReferenceTarget(key)).map(_.collect{ case c: IsComponent => c.toNamed })
+ get(Target.convertComponentName2ReferenceTarget(key)).map(_.collect { case c: IsComponent => c.toNamed })
}
def get(key: Named): Option[Seq[Named]] = key match {
case t: CompleteTarget => get(t)
- case other => get(key.toTarget).map(_.collect{ case c: IsComponent => c.toNamed })
+ case other => get(key.toTarget).map(_.collect { case c: IsComponent => c.toNamed })
}
-
// Mutable helpers - APIs that set these are deprecated!
private var circuitName: String = ""
- private var moduleName: String = ""
+ private var moduleName: String = ""
/** Sets mutable state to record current module we are visiting
* @param module
@@ -673,7 +694,7 @@ final class RenameMap private (
def rename(from: String, tos: Seq[String]): Unit = {
val mn = ModuleName(moduleName, CircuitName(circuitName))
val fromName = ComponentName(from, mn).toTarget
- val tosName = tos map { to => ComponentName(to, mn).toTarget }
+ val tosName = tos.map { to => ComponentName(to, mn).toTarget }
record(fromName, tosName)
}
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index e9af3365..bc285ef3 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -21,24 +21,22 @@ object seqCat {
case 1 => args.head
case 2 => DoPrim(PrimOps.Cat, args, Nil, UIntType(UnknownWidth))
case _ =>
- val (high, low) = args splitAt (args.length / 2)
+ val (high, low) = args.splitAt(args.length / 2)
DoPrim(PrimOps.Cat, Seq(seqCat(high), seqCat(low)), Nil, UIntType(UnknownWidth))
}
}
/** Given an expression, return an expression consisting of all sub-expressions
- * concatenated (or flattened).
- */
+ * concatenated (or flattened).
+ */
object toBits {
def apply(e: Expression): Expression = e match {
case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex)
case t => Utils.error(s"Invalid operand expression for toBits: $e")
}
private def hiercat(e: Expression): Expression = e.tpe match {
- case t: VectorType => seqCat((0 until t.size).reverse map (i =>
- hiercat(WSubIndex(e, i, t.tpe, UnknownFlow))))
- case t: BundleType => seqCat(t.fields map (f =>
- hiercat(WSubField(e, f.name, f.tpe, UnknownFlow))))
+ case t: VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UnknownFlow))))
+ case t: BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UnknownFlow))))
case t: GroundType => DoPrim(AsUInt, Seq(e), Seq.empty, UnknownType)
case t => Utils.error(s"Unknown type encountered in toBits: $e")
}
@@ -53,12 +51,12 @@ object getWidth {
}
object bitWidth {
- def apply(dt: Type): BigInt = widthOf(dt)
+ def apply(dt: Type): BigInt = widthOf(dt)
private def widthOf(dt: Type): BigInt = dt match {
case t: VectorType => t.size * bitWidth(t.tpe)
- case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_)
+ case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_ + _)
case GroundType(IntWidth(width)) => width
- case t => Utils.error(s"Unknown type encountered in bitWidth: $dt")
+ case t => Utils.error(s"Unknown type encountered in bitWidth: $dt")
}
}
@@ -88,32 +86,28 @@ object fromBits {
}
Block(fbits._2)
}
- private def getPartGround(lhs: Expression,
- lhst: Type,
- rhs: Expression,
- offset: BigInt): (BigInt, Seq[Statement]) = {
+ private def getPartGround(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = {
val intWidth = bitWidth(lhst)
val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset + intWidth - 1, offset), UnknownType)
val rhsConnect = castRhs(lhst, sel)
(offset + intWidth, Seq(Connect(NoInfo, lhs, rhsConnect)))
}
- private def getPart(lhs: Expression,
- lhst: Type,
- rhs: Expression,
- offset: BigInt): (BigInt, Seq[Statement]) =
+ private def getPart(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) =
lhst match {
- case t: VectorType => (0 until t.size foldLeft( (offset, Seq[Statement]()) )) {
- case ((curOffset, stmts), i) =>
- val subidx = WSubIndex(lhs, i, t.tpe, UnknownFlow)
- val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset)
- (tmpOffset, stmts ++ substmts)
- }
- case t: BundleType => (t.fields foldRight( (offset, Seq[Statement]()) )) {
- case (f, (curOffset, stmts)) =>
- val subfield = WSubField(lhs, f.name, f.tpe, UnknownFlow)
- val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset)
- (tmpOffset, stmts ++ substmts)
- }
+ case t: VectorType =>
+ ((0 until t.size).foldLeft((offset, Seq[Statement]()))) {
+ case ((curOffset, stmts), i) =>
+ val subidx = WSubIndex(lhs, i, t.tpe, UnknownFlow)
+ val (tmpOffset, substmts) = getPart(subidx, t.tpe, rhs, curOffset)
+ (tmpOffset, stmts ++ substmts)
+ }
+ case t: BundleType =>
+ (t.fields.foldRight((offset, Seq[Statement]()))) {
+ case (f, (curOffset, stmts)) =>
+ val subfield = WSubField(lhs, f.name, f.tpe, UnknownFlow)
+ val (tmpOffset, substmts) = getPart(subfield, f.tpe, rhs, curOffset)
+ (tmpOffset, stmts ++ substmts)
+ }
case t: GroundType => getPartGround(lhs, t, rhs, offset)
case t => Utils.error(s"Unknown type encountered in fromBits: $lhst")
}
@@ -129,6 +123,7 @@ object flattenType {
}
object Utils extends LazyLogging {
+
/** Unwind the causal chain until we hit the initial exception (which may be the first).
*
* @param maybeException - possible exception triggering the error,
@@ -157,13 +152,16 @@ object Utils extends LazyLogging {
*
* @param message - possible string to emit,
* @param exception - possible exception triggering the error.
- */
+ */
def throwInternalError(message: String = "", exception: Option[Exception] = None) = {
// We'll get the first exception in the chain, keeping it intact.
val first = true
val throwable = getThrowable(exception, true)
val string = if (message.nonEmpty) message + "\n" else message
- error("Internal Error! %sPlease file an issue at https://github.com/ucb-bar/firrtl/issues".format(string), throwable)
+ error(
+ "Internal Error! %sPlease file an issue at https://github.com/ucb-bar/firrtl/issues".format(string),
+ throwable
+ )
}
def time[R](block: => R): (Double, R) = {
@@ -177,9 +175,9 @@ object Utils extends LazyLogging {
/** Removes all [[firrtl.ir.EmptyStmt]] statements and condenses
* [[firrtl.ir.Block]] statements.
*/
- def squashEmpty(s: Statement): Statement = s map squashEmpty match {
+ def squashEmpty(s: Statement): Statement = s.map(squashEmpty) match {
case Block(stmts) =>
- val newStmts = stmts filter (_ != EmptyStmt)
+ val newStmts = stmts.filter(_ != EmptyStmt)
newStmts.size match {
case 0 => EmptyStmt
case 1 => newStmts.head
@@ -191,43 +189,46 @@ object Utils extends LazyLogging {
/** Returns true if PrimOp is a cast, false otherwise */
def isCast(op: PrimOp): Boolean = op match {
case AsUInt | AsSInt | AsClock | AsAsyncReset | AsFixedPoint => true
- case _ => false
+ case _ => false
}
+
/** Returns true if Expression is a casting PrimOp, false otherwise */
def isCast(expr: Expression): Boolean = expr match {
- case DoPrim(op, _,_,_) if isCast(op) => true
- case _ => false
+ case DoPrim(op, _, _, _) if isCast(op) => true
+ case _ => false
}
/** Returns true if PrimOp is a BitExtraction, false otherwise */
def isBitExtract(op: PrimOp): Boolean = op match {
case Bits | Head | Tail | Shr => true
- case _ => false
+ case _ => false
}
+
/** Returns true if Expression is a Bits PrimOp, false otherwise */
def isBitExtract(expr: Expression): Boolean = expr match {
- case DoPrim(op, _,_, UIntType(_)) if isBitExtract(op) => true
- case _ => false
+ case DoPrim(op, _, _, UIntType(_)) if isBitExtract(op) => true
+ case _ => false
}
- /** Provide a nice name to create a temporary **/
+ /** Provide a nice name to create a temporary * */
def niceName(e: Expression): String = niceName(1)(e)
def niceName(depth: Int)(e: Expression): String = {
e match {
case Reference(name, _, _, _) if name(0) == '_' => name
- case Reference(name, _, _, _) => "_" + name
+ case Reference(name, _, _, _) => "_" + name
case SubAccess(expr, index, _, _) if depth <= 0 => niceName(depth)(expr)
- case SubAccess(expr, index, _, _) => niceName(depth)(expr) + niceName(depth - 1)(index)
- case SubField(expr, field, _, _) => niceName(depth)(expr) + "_" + field
- case SubIndex(expr, index, _, _) => niceName(depth)(expr) + "_" + index
- case DoPrim(op, args, consts, _) if depth <= 0 => "_" + op
- case DoPrim(op, args, consts, _) => "_" + op + (args.map(niceName(depth - 1)) ++ consts.map("_" + _)).mkString("")
- case Mux(cond, tval, fval, _) if depth <= 0 => "_mux"
- case Mux(cond, tval, fval, _) => "_mux" + Seq(cond, tval, fval).map(niceName(depth - 1)).mkString("")
- case UIntLiteral(value, _) => "_" + value
- case SIntLiteral(value, _) => "_" + value
+ case SubAccess(expr, index, _, _) => niceName(depth)(expr) + niceName(depth - 1)(index)
+ case SubField(expr, field, _, _) => niceName(depth)(expr) + "_" + field
+ case SubIndex(expr, index, _, _) => niceName(depth)(expr) + "_" + index
+ case DoPrim(op, args, consts, _) if depth <= 0 => "_" + op
+ case DoPrim(op, args, consts, _) => "_" + op + (args.map(niceName(depth - 1)) ++ consts.map("_" + _)).mkString("")
+ case Mux(cond, tval, fval, _) if depth <= 0 => "_mux"
+ case Mux(cond, tval, fval, _) => "_mux" + Seq(cond, tval, fval).map(niceName(depth - 1)).mkString("")
+ case UIntLiteral(value, _) => "_" + value
+ case SIntLiteral(value, _) => "_" + value
}
}
+
/** Maps node name to value */
type NodeMap = mutable.HashMap[String, Expression]
@@ -235,18 +236,18 @@ object Utils extends LazyLogging {
/** Indent the results of [[ir.FirrtlNode.serialize]] */
@deprecated("Use ther new firrt.ir.Serializer instead.", "FIRRTL 1.4")
- def indent(str: String) = str replaceAllLiterally ("\n", "\n ")
-
- implicit def toWrappedExpression (x:Expression): WrappedExpression = new WrappedExpression(x)
- def getSIntWidth(s: BigInt): Int = s.bitLength + 1
- def getUIntWidth(u: BigInt): Int = u.bitLength
- def dec2string(v: BigDecimal): String = v.underlying().stripTrailingZeros().toPlainString
- def trim(v: BigDecimal): BigDecimal = BigDecimal(dec2string(v))
- def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b
- def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a
- def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1
+ def indent(str: String) = str.replaceAllLiterally("\n", "\n ")
+
+ implicit def toWrappedExpression(x: Expression): WrappedExpression = new WrappedExpression(x)
+ def getSIntWidth(s: BigInt): Int = s.bitLength + 1
+ def getUIntWidth(u: BigInt): Int = u.bitLength
+ def dec2string(v: BigDecimal): String = v.underlying().stripTrailingZeros().toPlainString
+ def trim(v: BigDecimal): BigDecimal = BigDecimal(dec2string(v))
+ def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b
+ def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a
+ def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1
val BoolType = UIntType(IntWidth(1))
- val one = UIntLiteral(1)
+ val one = UIntLiteral(1)
val zero = UIntLiteral(0)
def create_exps(n: String, t: Type): Seq[Expression] =
@@ -255,16 +256,18 @@ object Utils extends LazyLogging {
case ex: Mux =>
val e1s = create_exps(ex.tval)
val e2s = create_exps(ex.fval)
- e1s zip e2s map {case (e1, e2) =>
- Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2))
+ e1s.zip(e2s).map {
+ case (e1, e2) =>
+ Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2))
+ }
+ case ex: ValidIf => create_exps(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe))
+ case ex =>
+ ex.tpe match {
+ case (_: GroundType) => Seq(ex)
+ case t: BundleType =>
+ t.fields.flatMap(f => create_exps(WSubField(ex, f.name, f.tpe, times(flow(ex), f.flip))))
+ case t: VectorType => (0 until t.size).flatMap(i => create_exps(WSubIndex(ex, i, t.tpe, flow(ex))))
}
- case ex: ValidIf => create_exps(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe))
- case ex => ex.tpe match {
- case (_: GroundType) => Seq(ex)
- case t: BundleType =>
- t.fields.flatMap(f => create_exps(WSubField(ex, f.name, f.tpe,times(flow(ex), f.flip))))
- case t: VectorType => (0 until t.size).flatMap(i => create_exps(WSubIndex(ex, i, t.tpe,flow(ex))))
- }
}
/** Like create_exps, but returns intermediate Expressions as well
@@ -275,26 +278,28 @@ object Utils extends LazyLogging {
case ex: Mux =>
val e1s = expandRef(ex.tval)
val e2s = expandRef(ex.fval)
- e1s zip e2s map {case (e1, e2) =>
- Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2))
+ e1s.zip(e2s).map {
+ case (e1, e2) =>
+ Mux(ex.cond, e1, e2, mux_type_and_widths(e1, e2))
+ }
+ case ex: ValidIf => expandRef(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe))
+ case ex =>
+ ex.tpe match {
+ case (_: GroundType) => Seq(ex)
+ case (t: BundleType) =>
+ ex +: t.fields.flatMap(f => expandRef(WSubField(ex, f.name, f.tpe, times(flow(ex), f.flip))))
+ case (t: VectorType) =>
+ ex +: (0 until t.size).flatMap(i => expandRef(WSubIndex(ex, i, t.tpe, flow(ex))))
}
- case ex: ValidIf => expandRef(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe))
- case ex => ex.tpe match {
- case (_: GroundType) => Seq(ex)
- case (t: BundleType) =>
- ex +: t.fields.flatMap(f => expandRef(WSubField(ex, f.name, f.tpe, times(flow(ex), f.flip))))
- case (t: VectorType) =>
- ex +: (0 until t.size).flatMap(i => expandRef(WSubIndex(ex, i, t.tpe, flow(ex))))
- }
}
def toTarget(main: String, module: String)(expression: Expression): ReferenceTarget = {
val tokens = mutable.ArrayBuffer[TargetToken]()
var ref = "???"
def onExp(expr: Expression): Expression = {
- expr map onExp match {
+ expr.map(onExp) match {
case e: Reference => ref = e.name
- case e: SubField => tokens += TargetToken.Field(e.name)
- case e: SubIndex => tokens += TargetToken.Index(e.value)
+ case e: SubField => tokens += TargetToken.Field(e.name)
+ case e: SubIndex => tokens += TargetToken.Index(e.value)
case other => throwInternalError("Cannot call Utils.toTarget on non-referencing expression")
}
expr
@@ -302,39 +307,42 @@ object Utils extends LazyLogging {
onExp(expression)
ReferenceTarget(main, module, Nil, ref, tokens.toSeq)
}
- @deprecated("get_flip is fundamentally slow, use to_flip(flow(expr))", "1.2")
- def get_flip(t: Type, i: Int, f: Orientation): Orientation = {
- if (i >= get_size(t)) throwInternalError(s"get_flip: shouldn't be here - $i >= get_size($t)")
- t match {
- case (_: GroundType) => f
- case (tx: BundleType) =>
- val (_, flip) = tx.fields.foldLeft( (i, None: Option[Orientation]) ) {
- case ((n, ret), x) if n < get_size(x.tpe) => ret match {
- case None => (n, Some(get_flip(x.tpe, n, times(x.flip, f))))
- case Some(_) => (n, ret)
- }
- case ((n, ret), x) => (n - get_size(x.tpe), ret)
- }
- flip.get
- case (tx: VectorType) =>
- val (_, flip) = (0 until tx.size).foldLeft( (i, None: Option[Orientation]) ) {
- case ((n, ret), x) if n < get_size(tx.tpe) => ret match {
- case None => (n, Some(get_flip(tx.tpe, n, f)))
- case Some(_) => (n, ret)
- }
- case ((n, ret), x) => (n - get_size(tx.tpe), ret)
- }
- flip.get
- }
- }
-
- def get_point (e:Expression) : Int = e match {
- case (e: WRef) => 0
- case (e: WSubField) => e.expr.tpe match {case b: BundleType =>
- (b.fields takeWhile (_.name != e.name) foldLeft 0)(
- (point, f) => point + get_size(f.tpe))
+ @deprecated("get_flip is fundamentally slow, use to_flip(flow(expr))", "1.2")
+ def get_flip(t: Type, i: Int, f: Orientation): Orientation = {
+ if (i >= get_size(t)) throwInternalError(s"get_flip: shouldn't be here - $i >= get_size($t)")
+ t match {
+ case (_: GroundType) => f
+ case (tx: BundleType) =>
+ val (_, flip) = tx.fields.foldLeft((i, None: Option[Orientation])) {
+ case ((n, ret), x) if n < get_size(x.tpe) =>
+ ret match {
+ case None => (n, Some(get_flip(x.tpe, n, times(x.flip, f))))
+ case Some(_) => (n, ret)
+ }
+ case ((n, ret), x) => (n - get_size(x.tpe), ret)
+ }
+ flip.get
+ case (tx: VectorType) =>
+ val (_, flip) = (0 until tx.size).foldLeft((i, None: Option[Orientation])) {
+ case ((n, ret), x) if n < get_size(tx.tpe) =>
+ ret match {
+ case None => (n, Some(get_flip(tx.tpe, n, f)))
+ case Some(_) => (n, ret)
+ }
+ case ((n, ret), x) => (n - get_size(tx.tpe), ret)
+ }
+ flip.get
}
- case (e: WSubIndex) => e.value * get_size(e.tpe)
+ }
+
+ def get_point(e: Expression): Int = e match {
+ case (e: WRef) => 0
+ case (e: WSubField) =>
+ e.expr.tpe match {
+ case b: BundleType =>
+ (b.fields.takeWhile(_.name != e.name).foldLeft(0))((point, f) => point + get_size(f.tpe))
+ }
+ case (e: WSubIndex) => e.value * get_size(e.tpe)
case (e: WSubAccess) => get_point(e.expr)
}
@@ -345,8 +353,8 @@ object Utils extends LazyLogging {
*/
def hasFlip(t: Type): Boolean = t match {
case t: BundleType =>
- (t.fields exists (_.flip == Flip)) ||
- (t.fields exists (f => hasFlip(f.tpe)))
+ (t.fields.exists(_.flip == Flip)) ||
+ (t.fields.exists(f => hasFlip(f.tpe)))
case t: VectorType => hasFlip(t.tpe)
case _ => false
}
@@ -358,17 +366,17 @@ object Utils extends LazyLogging {
kids += e
e
}
- e map addKids
+ e.map(addKids)
kids.toSeq
}
/** Walks two expression trees and returns a sequence of tuples of where they differ */
def diff(e1: Expression, e2: Expression): Seq[(Expression, Expression)] = {
- if(weq(e1, e2)) Nil
+ if (weq(e1, e2)) Nil
else {
val (e1Kids, e2Kids) = (getKids(e1), getKids(e2))
- if(e1Kids == Nil || e2Kids == Nil || e1Kids.size != e2Kids.size) Seq((e1, e2))
+ if (e1Kids == Nil || e2Kids == Nil || e1Kids.size != e2Kids.size) Seq((e1, e2))
else {
e1Kids.zip(e2Kids).flatMap { case (e1k, e2k) => diff(e1k, e2k) }
}
@@ -378,65 +386,67 @@ object Utils extends LazyLogging {
/** Returns an inlined expression (replacing node references with values),
* stopping on a stopping condition or until the reference is not a node
*/
- def inline(nodeMap: NodeMap, stop: String => Boolean = {x: String => false})(e: Expression): Expression = {
- def onExp(e: Expression): Expression = e map onExp match {
+ def inline(nodeMap: NodeMap, stop: String => Boolean = { x: String => false })(e: Expression): Expression = {
+ def onExp(e: Expression): Expression = e.map(onExp) match {
case Reference(name, _, _, _) if nodeMap.contains(name) && !stop(name) => onExp(nodeMap(name))
- case other => other
+ case other => other
}
onExp(e)
}
def mux_type(e1: Expression, e2: Expression): Type = mux_type(e1.tpe, e2.tpe)
- def mux_type(t1: Type, t2: Type): Type = (t1, t2) match {
- case (ClockType, ClockType) => ClockType
+ def mux_type(t1: Type, t2: Type): Type = (t1, t2) match {
+ case (ClockType, ClockType) => ClockType
case (AsyncResetType, AsyncResetType) => AsyncResetType
case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth)
case (t1: SIntType, t2: SIntType) => SIntType(UnknownWidth)
case (t1: FixedType, t2: FixedType) => FixedType(UnknownWidth, UnknownWidth)
case (t1: IntervalType, t2: IntervalType) => IntervalType(UnknownBound, UnknownBound, UnknownWidth)
case (t1: VectorType, t2: VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size)
- case (t1: BundleType, t2: BundleType) => BundleType(t1.fields zip t2.fields map {
- case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe))
- })
+ case (t1: BundleType, t2: BundleType) =>
+ BundleType(t1.fields.zip(t2.fields).map {
+ case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe))
+ })
case _ => UnknownType
}
- def mux_type_and_widths(e1: Expression,e2: Expression): Type =
+ def mux_type_and_widths(e1: Expression, e2: Expression): Type =
mux_type_and_widths(e1.tpe, e2.tpe)
def mux_type_and_widths(t1: Type, t2: Type): Type = {
def wmax(w1: Width, w2: Width): Width = (w1, w2) match {
- case (w1x: IntWidth, w2x: IntWidth) => IntWidth(w1x.width max w2x.width)
+ case (w1x: IntWidth, w2x: IntWidth) => IntWidth(w1x.width.max(w2x.width))
case (w1x, w2x) => IsMax(w1x, w2x)
}
(t1, t2) match {
- case (ClockType, ClockType) => ClockType
+ case (ClockType, ClockType) => ClockType
case (AsyncResetType, AsyncResetType) => AsyncResetType
case (t1x: UIntType, t2x: UIntType) => UIntType(IsMax(t1x.width, t2x.width))
case (t1x: SIntType, t2x: SIntType) => SIntType(IsMax(t1x.width, t2x.width))
case (FixedType(w1, p1), FixedType(w2, p2)) =>
- FixedType(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2))
+ FixedType(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2))
case (IntervalType(l1, u1, p1), IntervalType(l2, u2, p2)) =>
IntervalType(IsMin(l1, l2), constraint.IsMax(u1, u2), MAX(p1, p2))
- case (t1x: VectorType, t2x: VectorType) => VectorType(
- mux_type_and_widths(t1x.tpe, t2x.tpe), t1x.size)
- case (t1x: BundleType, t2x: BundleType) => BundleType(t1x.fields zip t2x.fields map {
- case (f1, f2) => Field(f1.name, f1.flip, mux_type_and_widths(f1.tpe, f2.tpe))
- })
+ case (t1x: VectorType, t2x: VectorType) => VectorType(mux_type_and_widths(t1x.tpe, t2x.tpe), t1x.size)
+ case (t1x: BundleType, t2x: BundleType) =>
+ BundleType(t1x.fields.zip(t2x.fields).map {
+ case (f1, f2) => Field(f1.name, f1.flip, mux_type_and_widths(f1.tpe, f2.tpe))
+ })
case _ => UnknownType
}
}
- def module_type(m: DefModule): BundleType = BundleType(m.ports map {
+ def module_type(m: DefModule): BundleType = BundleType(m.ports.map {
case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
})
def sub_type(v: Type): Type = v match {
case vx: VectorType => vx.tpe
case vx => UnknownType
}
- def field_type(v: Type, s: String) : Type = v match {
- case vx: BundleType => vx.fields find (_.name == s) match {
- case Some(f) => f.tpe
- case None => UnknownType
- }
+ def field_type(v: Type, s: String): Type = v match {
+ case vx: BundleType =>
+ vx.fields.find(_.name == s) match {
+ case Some(f) => f.tpe
+ case None => UnknownType
+ }
case vx => UnknownType
}
@@ -445,13 +455,12 @@ object Utils extends LazyLogging {
//// =============== EXPANSION FUNCTIONS ================
def get_size(t: Type): Int = t match {
- case tx: BundleType => (tx.fields foldLeft 0)(
- (sum, f) => sum + get_size(f.tpe))
+ case tx: BundleType => (tx.fields.foldLeft(0))((sum, f) => sum + get_size(f.tpe))
case tx: VectorType => tx.size * get_size(tx.tpe)
case tx => 1
}
- def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int,Int)] = {
+ def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int, Int)] = {
import passes.CheckTypes.legalResetType
//;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
(t1, t2) match {
@@ -461,27 +470,39 @@ object Utils extends LazyLogging {
case (_: AnalogType, _: AnalogType) => if (flip1 == flip2) Seq((0, 0)) else Nil
case (t1x: BundleType, t2x: BundleType) =>
def emptyMap = Map[String, (Type, Orientation, Int)]()
- val t1_fields = t1x.fields.foldLeft( (emptyMap, 0) ) { case ((map, ilen), f1) =>
- (map + (f1.name ->( (f1.tpe, f1.flip, ilen) )), ilen + get_size(f1.tpe))
- }._1
- t2x.fields.foldLeft( (Seq[(Int, Int)](), 0) ) { case ((points, jlen), f2) =>
- t1_fields get f2.name match {
- case None => (points, jlen + get_size(f2.tpe))
- case Some((f1_tpe, f1_flip, ilen)) =>
- val f1_times = times(flip1, f1_flip)
- val f2_times = times(flip2, f2.flip)
- val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times)
- (points ++ (ls map { case (x, y) => (x + ilen, y + jlen) }), jlen + get_size(f2.tpe))
+ val t1_fields = t1x.fields
+ .foldLeft((emptyMap, 0)) {
+ case ((map, ilen), f1) =>
+ (map + (f1.name -> ((f1.tpe, f1.flip, ilen))), ilen + get_size(f1.tpe))
+ }
+ ._1
+ t2x.fields
+ .foldLeft((Seq[(Int, Int)](), 0)) {
+ case ((points, jlen), f2) =>
+ t1_fields.get(f2.name) match {
+ case None => (points, jlen + get_size(f2.tpe))
+ case Some((f1_tpe, f1_flip, ilen)) =>
+ val f1_times = times(flip1, f1_flip)
+ val f2_times = times(flip2, f2.flip)
+ val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times)
+ (points ++ (ls.map { case (x, y) => (x + ilen, y + jlen) }), jlen + get_size(f2.tpe))
+ }
}
- }._1
+ ._1
case (t1x: VectorType, t2x: VectorType) =>
val size = math.min(t1x.size, t2x.size)
- (0 until size).foldLeft( (Seq[(Int, Int)](), 0, 0) ) { case ((points, ilen, jlen), _) =>
- val ls = get_valid_points(t1x.tpe, t2x.tpe, flip1, flip2)
- (points ++ (ls map { case (x, y) => (x + ilen, y + jlen) }),
- ilen + get_size(t1x.tpe), jlen + get_size(t2x.tpe))
- }._1
- case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil
+ (0 until size)
+ .foldLeft((Seq[(Int, Int)](), 0, 0)) {
+ case ((points, ilen, jlen), _) =>
+ val ls = get_valid_points(t1x.tpe, t2x.tpe, flip1, flip2)
+ (
+ points ++ (ls.map { case (x, y) => (x + ilen, y + jlen) }),
+ ilen + get_size(t1x.tpe),
+ jlen + get_size(t2x.tpe)
+ )
+ }
+ ._1
+ case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil
case (AsyncResetType, AsyncResetType) => if (flip1 == flip2) Seq((0, 0)) else Nil
// The following two cases handle driving ResetType from other legal reset types
// Flippedness is important here because ResetType can be driven by other reset types, but it
@@ -495,112 +516,114 @@ object Utils extends LazyLogging {
}
// =========== FLOW/FLIP UTILS ============
- def swap(g: Flow) : Flow = g match {
+ def swap(g: Flow): Flow = g match {
case UnknownFlow => UnknownFlow
- case SourceFlow => SinkFlow
- case SinkFlow => SourceFlow
- case DuplexFlow => DuplexFlow
+ case SourceFlow => SinkFlow
+ case SinkFlow => SourceFlow
+ case DuplexFlow => DuplexFlow
}
- def swap(d: Direction) : Direction = d match {
+ def swap(d: Direction): Direction = d match {
case Output => Input
- case Input => Output
+ case Input => Output
}
- def swap(f: Orientation) : Orientation = f match {
+ def swap(f: Orientation): Orientation = f match {
case Default => Flip
- case Flip => Default
+ case Flip => Default
}
// Input <-> SourceFlow <-> Flip
// Output <-> SinkFlow <-> Default
def to_dir(g: Flow): Direction = g match {
case SourceFlow => Input
- case SinkFlow => Output
+ case SinkFlow => Output
}
def to_dir(o: Orientation): Direction = o match {
- case Flip => Input
+ case Flip => Input
case Default => Output
}
def to_flow(d: Direction): Flow = d match {
- case Input => SourceFlow
+ case Input => SourceFlow
case Output => SinkFlow
}
def to_flip(d: Direction): Orientation = d match {
- case Input => Flip
+ case Input => Flip
case Output => Default
}
def to_flip(g: Flow): Orientation = g match {
case SourceFlow => Flip
- case SinkFlow => Default
+ case SinkFlow => Default
}
def field_flip(v: Type, s: String): Orientation = v match {
- case vx: BundleType => vx.fields find (_.name == s) match {
- case Some(ft) => ft.flip
- case None => Default
- }
+ case vx: BundleType =>
+ vx.fields.find(_.name == s) match {
+ case Some(ft) => ft.flip
+ case None => Default
+ }
case vx => Default
}
def get_field(v: Type, s: String): Field = v match {
- case vx: BundleType => vx.fields find (_.name == s) match {
- case Some(ft) => ft
- case None => throwInternalError(s"get_field: shouldn't be here - $v.$s")
- }
+ case vx: BundleType =>
+ vx.fields.find(_.name == s) match {
+ case Some(ft) => ft
+ case None => throwInternalError(s"get_field: shouldn't be here - $v.$s")
+ }
case vx => throwInternalError(s"get_field: shouldn't be here - $v")
}
- def times(d: Direction,flip: Orientation): Direction = flip match {
+ def times(d: Direction, flip: Orientation): Direction = flip match {
case Default => d
- case Flip => swap(d)
+ case Flip => swap(d)
}
- def times(g: Flow, d: Direction): Direction = times(d, g)
+ def times(g: Flow, d: Direction): Direction = times(d, g)
def times(d: Direction, g: Flow): Direction = g match {
- case SinkFlow => d
+ case SinkFlow => d
case SourceFlow => swap(d) // SourceFlow == INPUT == REVERSE
}
- def times(g: Flow, flip: Orientation): Flow = times(flip, g)
+ def times(g: Flow, flip: Orientation): Flow = times(flip, g)
def times(flip: Orientation, g: Flow): Flow = flip match {
case Default => g
- case Flip => swap(g)
+ case Flip => swap(g)
}
def times(f1: Orientation, f2: Orientation): Orientation = f2 match {
case Default => f1
- case Flip => swap(f1)
+ case Flip => swap(f1)
}
// =========== ACCESSORS =========
def kind(e: Expression): Kind = e match {
- case ex: WRef => ex.kind
- case ex: WSubField => kind(ex.expr)
- case ex: WSubIndex => kind(ex.expr)
+ case ex: WRef => ex.kind
+ case ex: WSubField => kind(ex.expr)
+ case ex: WSubIndex => kind(ex.expr)
case ex: WSubAccess => kind(ex.expr)
case ex => ExpKind
}
def flow(e: Expression): Flow = e match {
- case ex: WRef => ex.flow
- case ex: WSubField => ex.flow
- case ex: WSubIndex => ex.flow
- case ex: WSubAccess => ex.flow
- case ex: DoPrim => SourceFlow
+ case ex: WRef => ex.flow
+ case ex: WSubField => ex.flow
+ case ex: WSubIndex => ex.flow
+ case ex: WSubAccess => ex.flow
+ case ex: DoPrim => SourceFlow
case ex: UIntLiteral => SourceFlow
case ex: SIntLiteral => SourceFlow
- case ex: Mux => SourceFlow
- case ex: ValidIf => SourceFlow
+ case ex: Mux => SourceFlow
+ case ex: ValidIf => SourceFlow
case WInvalid => SourceFlow
- case ex => throwInternalError(s"flow: shouldn't be here - $e")
+ case ex => throwInternalError(s"flow: shouldn't be here - $e")
}
def get_flow(s: Statement): Flow = s match {
- case sx: DefWire => DuplexFlow
- case sx: DefRegister => DuplexFlow
- case sx: WDefInstance => SourceFlow
- case sx: DefNode => SourceFlow
- case sx: DefInstance => SourceFlow
- case sx: DefMemory => SourceFlow
- case sx: Block => UnknownFlow
- case sx: Connect => UnknownFlow
+ case sx: DefWire => DuplexFlow
+ case sx: DefRegister => DuplexFlow
+ case sx: WDefInstance => SourceFlow
+ case sx: DefNode => SourceFlow
+ case sx: DefInstance => SourceFlow
+ case sx: DefMemory => SourceFlow
+ case sx: Block => UnknownFlow
+ case sx: Connect => UnknownFlow
case sx: PartialConnect => UnknownFlow
- case sx: Stop => UnknownFlow
- case sx: Print => UnknownFlow
- case sx: IsInvalid => UnknownFlow
+ case sx: Stop => UnknownFlow
+ case sx: Print => UnknownFlow
+ case sx: IsInvalid => UnknownFlow
case EmptyStmt => UnknownFlow
}
def get_flow(p: Port): Flow = if (p.direction == Input) SourceFlow else SinkFlow
@@ -630,7 +653,7 @@ object Utils extends LazyLogging {
val (root, tail) = splitRef(e.expr)
tail match {
case EmptyExpression => (root, WRef(e.name, e.tpe, root.kind, e.flow))
- case exp => (root, WSubField(tail, e.name, e.tpe, e.flow))
+ case exp => (root, WSubField(tail, e.name, e.tpe, e.flow))
}
}
@@ -657,28 +680,28 @@ object Utils extends LazyLogging {
def getDeclaration(m: Module, expr: Expression): IsDeclaration = {
def getRootDecl(name: String)(s: Statement): Option[IsDeclaration] = s match {
case decl: IsDeclaration => if (decl.name == name) Some(decl) else None
- case c: Conditionally =>
+ case c: Conditionally =>
val m = (getRootDecl(name)(c.conseq), getRootDecl(name)(c.alt))
(m: @unchecked) match {
case (Some(decl), None) => Some(decl)
case (None, Some(decl)) => Some(decl)
- case (None, None) => None
+ case (None, None) => None
}
case begin: Block =>
- val stmts = begin.stmts flatMap getRootDecl(name) // can we short circuit?
+ val stmts = begin.stmts.flatMap(getRootDecl(name)) // can we short circuit?
if (stmts.nonEmpty) Some(stmts.head) else None
case _ => None
}
expr match {
case (_: WRef | _: WSubIndex | _: WSubField) =>
val (root, tail) = splitRef(expr)
- val rootDecl = m.ports find (_.name == root.name) match {
+ val rootDecl = m.ports.find(_.name == root.name) match {
case Some(decl) => decl
case None =>
getRootDecl(root.name)(m.body) match {
case Some(decl) => decl
- case None => throw new DeclarationNotFoundException(
- s"[module ${m.name}] Reference ${expr.serialize} not declared!")
+ case None =>
+ throw new DeclarationNotFoundException(s"[module ${m.name}] Reference ${expr.serialize} not declared!")
}
}
rootDecl
@@ -771,7 +794,7 @@ object Utils extends LazyLogging {
.findAllMatchIn(name)
.map(_.end - 1)
.toSeq
- .foldLeft(Seq[String]()){ case (seq, id) => seq :+ name.splitAt(id)._1 }
+ .foldLeft(Seq[String]()) { case (seq, id) => seq :+ name.splitAt(id)._1 }
}
/** Returns the value masked with the width.
@@ -785,14 +808,14 @@ object Utils extends LazyLogging {
}
object MemoizedHash {
- implicit def convertTo[T](e: T): MemoizedHash[T] = new MemoizedHash(e)
+ implicit def convertTo[T](e: T): MemoizedHash[T] = new MemoizedHash(e)
implicit def convertFrom[T](f: MemoizedHash[T]): T = f.t
}
class MemoizedHash[T](val t: T) {
override lazy val hashCode = t.hashCode
override def equals(that: Any) = that match {
- case x: MemoizedHash[_] => t equals x.t
+ case x: MemoizedHash[_] => t.equals(x.t)
case _ => false
}
}
@@ -833,13 +856,12 @@ class ModuleGraph {
def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = {
nodes.get(child) match {
case Some(children) =>
- if(children(parent)) {
+ if (children(parent)) {
parent :: path
- }
- else {
+ } else {
children.foreach { grandchild =>
val newPath = pathExists(grandchild, parent, grandchild :: path)
- if(newPath.nonEmpty) {
+ if (newPath.nonEmpty) {
return newPath
}
}
diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala
index 502d021d..b14c39c7 100644
--- a/src/main/scala/firrtl/Visitor.scala
+++ b/src/main/scala/firrtl/Visitor.scala
@@ -13,7 +13,6 @@ import Parser.{AppendInfo, GenInfo, IgnoreInfo, InfoMode, UseInfo}
import firrtl.ir._
import Utils.throwInternalError
-
class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] with ParseTreeVisitor[FirrtlNode] {
// Strip file path
private def stripPath(filename: String) = filename.drop(filename.lastIndexOf("/") + 1)
@@ -21,7 +20,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
// Check if identifier is made of legal characters
private def legalId(id: String) = {
val legalChars = ('A' to 'Z').toSet ++ ('a' to 'z').toSet ++ ('0' to '9').toSet ++ Set('_', '$')
- id forall legalChars
+ id.forall(legalChars)
}
def visit(ctx: CircuitContext): Circuit = visitCircuit(ctx)
@@ -37,22 +36,22 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
private def string2BigInt(s: String): BigInt = {
// private define legal patterns
s match {
- case ZeroPattern(_*) => BigInt(0)
- case HexPattern(hexdigits) => BigInt(hexdigits, 16)
- case OctalPattern(octaldigits) => BigInt(octaldigits, 8)
+ case ZeroPattern(_*) => BigInt(0)
+ case HexPattern(hexdigits) => BigInt(hexdigits, 16)
+ case OctalPattern(octaldigits) => BigInt(octaldigits, 8)
case BinaryPattern(binarydigits) => BigInt(binarydigits, 2)
- case DecPattern(num) => BigInt(num, 10)
- case _ => throw new Exception("Invalid String for conversion to BigInt " + s)
+ case DecPattern(num) => BigInt(num, 10)
+ case _ => throw new Exception("Invalid String for conversion to BigInt " + s)
}
}
private def string2BigDecimal(s: String): BigDecimal = {
// private define legal patterns
s match {
- case ZeroPattern(_*) => BigDecimal(0)
- case DecPattern(num) => BigDecimal(num)
+ case ZeroPattern(_*) => BigDecimal(0)
+ case DecPattern(num) => BigDecimal(num)
case DecimalPattern(num) => BigDecimal(num)
- case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s)
+ case _ => throw new Exception("Invalid String for conversion to BigDecimal " + s)
}
}
@@ -64,7 +63,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
parentCtx.getStart.getCharPositionInLine
lazy val useInfo: String = ctx match {
case Some(info) => info.getText.drop(2).init // remove surrounding @[ ... ]
- case None => ""
+ case None => ""
}
infoMode match {
case UseInfo =>
@@ -88,14 +87,19 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
private def visitModule(ctx: ModuleContext): DefModule = {
val info = visitInfo(Option(ctx.info), ctx)
ctx.getChild(0).getText match {
- case "module" => Module(info, ctx.id.getText, ctx.port.asScala.map(visitPort).toSeq,
- if (ctx.moduleBlock() != null)
- visitBlock(ctx.moduleBlock())
- else EmptyStmt)
+ case "module" =>
+ Module(
+ info,
+ ctx.id.getText,
+ ctx.port.asScala.map(visitPort).toSeq,
+ if (ctx.moduleBlock() != null)
+ visitBlock(ctx.moduleBlock())
+ else EmptyStmt
+ )
case "extmodule" =>
val defname = if (ctx.defname != null) ctx.defname.id.getText else ctx.id.getText
- val ports = ctx.port.asScala map visitPort
- val params = ctx.parameter.asScala map visitParameter
+ val ports = ctx.port.asScala.map(visitPort)
+ val params = ctx.parameter.asScala.map(visitParameter)
ExtModule(info, ctx.id.getText, ports.toSeq, defname, params.toSeq)
}
}
@@ -111,22 +115,22 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
case (null, str, null, null) => StringParam(name, visitStringLit(str))
case (null, null, dbl, null) => DoubleParam(name, dbl.getText.toDouble)
case (null, null, null, raw) => RawStringParam(name, raw.getText.tail.init.replace("\\'", "'")) // Remove "\'"s
- case _ => throwInternalError(s"visiting impossible parameter ${ctx.getText}")
+ case _ => throwInternalError(s"visiting impossible parameter ${ctx.getText}")
}
}
private def visitDir(ctx: DirContext): Direction =
ctx.getText match {
- case "input" => Input
+ case "input" => Input
case "output" => Output
}
private def visitMdir(ctx: MdirContext): MPortDir =
ctx.getText match {
case "infer" => MInfer
- case "read" => MRead
+ case "read" => MRead
case "write" => MWrite
- case "rdwr" => MReadWrite
+ case "rdwr" => MReadWrite
}
// Match on a type instead of on strings?
@@ -135,47 +139,53 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
ctx.getChild(0) match {
case term: TerminalNode =>
term.getText match {
- case "UInt" => if (ctx.getChildCount > 1) UIntType(getWidth(ctx.intLit(0)))
- else UIntType(UnknownWidth)
- case "SInt" => if (ctx.getChildCount > 1) SIntType(getWidth(ctx.intLit(0)))
- else SIntType(UnknownWidth)
- case "Fixed" => ctx.intLit.size match {
- case 0 => FixedType(UnknownWidth, UnknownWidth)
- case 1 => ctx.getChild(2).getText match {
- case "<" => FixedType(UnknownWidth, getWidth(ctx.intLit(0)))
- case _ => FixedType(getWidth(ctx.intLit(0)), UnknownWidth)
+ case "UInt" =>
+ if (ctx.getChildCount > 1) UIntType(getWidth(ctx.intLit(0)))
+ else UIntType(UnknownWidth)
+ case "SInt" =>
+ if (ctx.getChildCount > 1) SIntType(getWidth(ctx.intLit(0)))
+ else SIntType(UnknownWidth)
+ case "Fixed" =>
+ ctx.intLit.size match {
+ case 0 => FixedType(UnknownWidth, UnknownWidth)
+ case 1 =>
+ ctx.getChild(2).getText match {
+ case "<" => FixedType(UnknownWidth, getWidth(ctx.intLit(0)))
+ case _ => FixedType(getWidth(ctx.intLit(0)), UnknownWidth)
+ }
+ case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1)))
}
- case 2 => FixedType(getWidth(ctx.intLit(0)), getWidth(ctx.intLit(1)))
- }
- case "Interval" => ctx.boundValue.size match {
- case 0 =>
- val point = ctx.intLit.size match {
- case 0 => UnknownWidth
- case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText))
- }
- IntervalType(UnknownBound, UnknownBound, point)
- case 2 =>
- val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match {
- case (_, "?") => UnknownBound
- case ("(", v) => Open(string2BigDecimal(v))
- case ("[", v) => Closed(string2BigDecimal(v))
- }
- val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match {
- case (_, "?") => UnknownBound
- case (")", v) => Open(string2BigDecimal(v))
- case ("]", v) => Closed(string2BigDecimal(v))
- }
- val point = ctx.intLit.size match {
- case 0 => UnknownWidth
- case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText))
- }
- IntervalType(lower, upper, point)
- }
- case "Clock" => ClockType
+ case "Interval" =>
+ ctx.boundValue.size match {
+ case 0 =>
+ val point = ctx.intLit.size match {
+ case 0 => UnknownWidth
+ case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText))
+ }
+ IntervalType(UnknownBound, UnknownBound, point)
+ case 2 =>
+ val lower = (ctx.lowerBound.getText, ctx.boundValue(0).getText) match {
+ case (_, "?") => UnknownBound
+ case ("(", v) => Open(string2BigDecimal(v))
+ case ("[", v) => Closed(string2BigDecimal(v))
+ }
+ val upper = (ctx.upperBound.getText, ctx.boundValue(1).getText) match {
+ case (_, "?") => UnknownBound
+ case (")", v) => Open(string2BigDecimal(v))
+ case ("]", v) => Closed(string2BigDecimal(v))
+ }
+ val point = ctx.intLit.size match {
+ case 0 => UnknownWidth
+ case 1 => IntWidth(string2BigInt(ctx.intLit(0).getText))
+ }
+ IntervalType(lower, upper, point)
+ }
+ case "Clock" => ClockType
case "AsyncReset" => AsyncResetType
- case "Reset" => ResetType
- case "Analog" => if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0)))
- else AnalogType(UnknownWidth)
+ case "Reset" => ResetType
+ case "Analog" =>
+ if (ctx.getChildCount > 1) AnalogType(getWidth(ctx.intLit(0)))
+ else AnalogType(UnknownWidth)
case "{" => BundleType(ctx.field.asScala.map(visitField).toSeq)
}
case typeContext: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.intLit(0).getText))
@@ -208,11 +218,12 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
private def visitRuw(ctx: Option[RuwContext]): ReadUnderWrite.Value = ctx match {
case None => ReadUnderWrite.Undefined
- case Some(ctx) => ctx.getText match {
- case "undefined" => ReadUnderWrite.Undefined
- case "old" => ReadUnderWrite.Old
- case "new" => ReadUnderWrite.New
- }
+ case Some(ctx) =>
+ ctx.getText match {
+ case "undefined" => ReadUnderWrite.Undefined
+ case "old" => ReadUnderWrite.Old
+ case "new" => ReadUnderWrite.New
+ }
}
// Memories are fairly complicated to translate thus have a dedicated method
@@ -220,7 +231,11 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
val readers = mutable.ArrayBuffer.empty[String]
val writers = mutable.ArrayBuffer.empty[String]
val readwriters = mutable.ArrayBuffer.empty[String]
- case class ParamValue(typ: Option[Type] = None, lit: Option[BigInt] = None, ruw: ReadUnderWrite.Value = ReadUnderWrite.Undefined, unique: Boolean = true)
+ case class ParamValue(
+ typ: Option[Type] = None,
+ lit: Option[BigInt] = None,
+ ruw: ReadUnderWrite.Value = ReadUnderWrite.Undefined,
+ unique: Boolean = true)
val fieldMap = mutable.HashMap[String, ParamValue]()
val memName = ctx.id(0).getText
def parseMemFields(memFields: Seq[MemFieldContext]): Unit =
@@ -228,14 +243,14 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
val fieldName = field.children.asScala(0).getText
fieldName match {
- case "reader" => readers ++= field.id().asScala.map(_.getText)
- case "writer" => writers ++= field.id().asScala.map(_.getText)
+ case "reader" => readers ++= field.id().asScala.map(_.getText)
+ case "writer" => writers ++= field.id().asScala.map(_.getText)
case "readwriter" => readwriters ++= field.id().asScala.map(_.getText)
case _ =>
val paramDef = fieldName match {
- case "data-type" => ParamValue(typ = Some(visitType(field.`type`())))
+ case "data-type" => ParamValue(typ = Some(visitType(field.`type`())))
case "read-under-write" => ParamValue(ruw = visitRuw(Option(field.ruw)))
- case _ => ParamValue(lit = Some(BigInt(field.intLit().getText)))
+ case _ => ParamValue(lit = Some(BigInt(field.intLit().getText)))
}
if (fieldMap.contains(fieldName))
throw new ParameterRedefinedException(s"Redefinition of $fieldName in FIRRTL line:${field.start.getLine}")
@@ -255,20 +270,26 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
}
// Check for required fields
- Seq("data-type", "depth", "read-latency", "write-latency") foreach { field =>
- fieldMap.getOrElse(field, throw new ParameterNotSpecifiedException(s"[$info] Required mem field $field not found"))
+ Seq("data-type", "depth", "read-latency", "write-latency").foreach { field =>
+ fieldMap.getOrElse(
+ field,
+ throw new ParameterNotSpecifiedException(s"[$info] Required mem field $field not found")
+ )
}
def lit(param: String) = fieldMap(param).lit.get
val ruw = fieldMap.get("read-under-write").map(_.ruw).getOrElse(ir.ReadUnderWrite.Undefined)
- DefMemory(info,
+ DefMemory(
+ info,
name = memName,
dataType = fieldMap("data-type").typ.get,
depth = lit("depth"),
writeLatency = lit("write-latency").toInt,
readLatency = lit("read-latency").toInt,
- readers = readers.toSeq, writers = writers.toSeq, readwriters = readwriters.toSeq,
+ readers = readers.toSeq,
+ writers = writers.toSeq,
+ readwriters = readwriters.toSeq,
readUnderWrite = ruw
)
}
@@ -299,56 +320,88 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
val info = visitInfo(Option(ctx.info), ctx)
ctx.getChild(0) match {
case when: WhenContext => visitWhen(when)
- case term: TerminalNode => term.getText match {
- case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`()))
- case "reg" =>
- val name = ctx.id(0).getText
- val tpe = visitType(ctx.`type`())
- val (reset, init, rinfo) = {
- val rb = ctx.reset_block()
- if (rb != null) {
- val sr = rb.simple_reset.simple_reset0()
- val innerInfo = if (info == NoInfo) visitInfo(Option(rb.info), ctx) else info
- (visitExp(sr.exp(0)), visitExp(sr.exp(1)), innerInfo)
+ case term: TerminalNode =>
+ term.getText match {
+ case "wire" => DefWire(info, ctx.id(0).getText, visitType(ctx.`type`()))
+ case "reg" =>
+ val name = ctx.id(0).getText
+ val tpe = visitType(ctx.`type`())
+ val (reset, init, rinfo) = {
+ val rb = ctx.reset_block()
+ if (rb != null) {
+ val sr = rb.simple_reset.simple_reset0()
+ val innerInfo = if (info == NoInfo) visitInfo(Option(rb.info), ctx) else info
+ (visitExp(sr.exp(0)), visitExp(sr.exp(1)), innerInfo)
+ } else
+ (UIntLiteral(0, IntWidth(1)), Reference(name, tpe), info)
}
- else
- (UIntLiteral(0, IntWidth(1)), Reference(name, tpe), info)
- }
- DefRegister(rinfo, name, tpe, visitExp(ctx_exp(0)), reset, init)
- case "mem" => visitMem(ctx)
- case "cmem" =>
- val (tpe, size) = visitCMemType(ctx.`type`())
- CDefMemory(info, ctx.id(0).getText, tpe, size, seq = false)
- case "smem" =>
- val (tpe, size) = visitCMemType(ctx.`type`())
- CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true, readUnderWrite = visitRuw(Option(ctx.ruw)))
- case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText)
- case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0)))
-
- case "stop(" => Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1)))
- case "attach" => Attach(info, ctx_exp.map(visitExp).toSeq)
- case "printf(" => Print(info, visitStringLit(ctx.StringLit), ctx_exp.drop(2).map(visitExp).toSeq,
- visitExp(ctx_exp(0)), visitExp(ctx_exp(1)))
- // formal
- case "assert" => Verification(Formal.Assert, info, visitExp(ctx_exp(0)),
- visitExp(ctx_exp(1)), visitExp(ctx_exp(2)),
- visitStringLit(ctx.StringLit))
- case "assume" => Verification(Formal.Assume, info, visitExp(ctx_exp(0)),
- visitExp(ctx_exp(1)), visitExp(ctx_exp(2)),
- visitStringLit(ctx.StringLit))
- case "cover" => Verification(Formal.Cover, info, visitExp(ctx_exp(0)),
- visitExp(ctx_exp(1)), visitExp(ctx_exp(2)),
- visitStringLit(ctx.StringLit))
- // end formal
- case "skip" => EmptyStmt
- }
+ DefRegister(rinfo, name, tpe, visitExp(ctx_exp(0)), reset, init)
+ case "mem" => visitMem(ctx)
+ case "cmem" =>
+ val (tpe, size) = visitCMemType(ctx.`type`())
+ CDefMemory(info, ctx.id(0).getText, tpe, size, seq = false)
+ case "smem" =>
+ val (tpe, size) = visitCMemType(ctx.`type`())
+ CDefMemory(info, ctx.id(0).getText, tpe, size, seq = true, readUnderWrite = visitRuw(Option(ctx.ruw)))
+ case "inst" => DefInstance(info, ctx.id(0).getText, ctx.id(1).getText)
+ case "node" => DefNode(info, ctx.id(0).getText, visitExp(ctx_exp(0)))
+
+ case "stop(" => Stop(info, string2Int(ctx.intLit().getText), visitExp(ctx_exp(0)), visitExp(ctx_exp(1)))
+ case "attach" => Attach(info, ctx_exp.map(visitExp).toSeq)
+ case "printf(" =>
+ Print(
+ info,
+ visitStringLit(ctx.StringLit),
+ ctx_exp.drop(2).map(visitExp).toSeq,
+ visitExp(ctx_exp(0)),
+ visitExp(ctx_exp(1))
+ )
+ // formal
+ case "assert" =>
+ Verification(
+ Formal.Assert,
+ info,
+ visitExp(ctx_exp(0)),
+ visitExp(ctx_exp(1)),
+ visitExp(ctx_exp(2)),
+ visitStringLit(ctx.StringLit)
+ )
+ case "assume" =>
+ Verification(
+ Formal.Assume,
+ info,
+ visitExp(ctx_exp(0)),
+ visitExp(ctx_exp(1)),
+ visitExp(ctx_exp(2)),
+ visitStringLit(ctx.StringLit)
+ )
+ case "cover" =>
+ Verification(
+ Formal.Cover,
+ info,
+ visitExp(ctx_exp(0)),
+ visitExp(ctx_exp(1)),
+ visitExp(ctx_exp(2)),
+ visitStringLit(ctx.StringLit)
+ )
+ // end formal
+ case "skip" => EmptyStmt
+ }
// If we don't match on the first child, try the next one
case _ =>
ctx.getChild(1).getText match {
case "<=" => Connect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1)))
case "<-" => PartialConnect(info, visitExp(ctx_exp(0)), visitExp(ctx_exp(1)))
case "is" => IsInvalid(info, visitExp(ctx_exp(0)))
- case "mport" => CDefMPort(info, ctx.id(0).getText, UnknownType, ctx.id(1).getText, Seq(visitExp(ctx_exp(0)), visitExp(ctx_exp(1))), visitMdir(ctx.mdir))
+ case "mport" =>
+ CDefMPort(
+ info,
+ ctx.id(0).getText,
+ UnknownType,
+ ctx.id(1).getText,
+ Seq(visitExp(ctx_exp(0)), visitExp(ctx_exp(1))),
+ visitMdir(ctx.mdir)
+ )
}
}
}
@@ -379,10 +432,12 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
new SubAccess(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType)
}
case _: PrimopContext =>
- DoPrim(visitPrimop(ctx.primop),
- ctx_exp.map(visitExp).toSeq,
- ctx.intLit.asScala.map(x => string2BigInt(x.getText)).toSeq,
- UnknownType)
+ DoPrim(
+ visitPrimop(ctx.primop),
+ ctx_exp.map(visitExp).toSeq,
+ ctx.intLit.asScala.map(x => string2BigInt(x.getText)).toSeq,
+ UnknownType
+ )
case _ =>
ctx.getChild(0).getText match {
case "UInt" =>
@@ -405,7 +460,7 @@ class Visitor(infoMode: InfoMode) extends AbstractParseTreeVisitor[FirrtlNode] w
SIntLiteral(value)
}
case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType)
- case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType)
+ case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType)
}
}
}
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 95b24ad0..4153fc74 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -27,96 +27,110 @@ case object DuplexFlow extends Flow
case object UnknownFlow extends Flow
object WRef {
+
/** Creates a WRef from a Wire */
def apply(wire: DefWire): WRef = new WRef(wire.name, wire.tpe, WireKind, UnknownFlow)
+
/** Creates a WRef from a Register */
def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UnknownFlow)
+
/** Creates a WRef from a Node */
def apply(node: DefNode): WRef = new WRef(node.name, node.value.tpe, NodeKind, SourceFlow)
+
/** Creates a WRef from a Port */
def apply(port: Port): WRef = new WRef(port.name, port.tpe, PortKind, UnknownFlow)
+
/** Creates a WRef from a WDefInstance */
def apply(wi: WDefInstance): WRef = new WRef(wi.name, wi.tpe, InstanceKind, UnknownFlow)
+
/** Creates a WRef from a DefMemory */
def apply(mem: DefMemory): WRef = new WRef(mem.name, passes.MemPortUtils.memType(mem), MemKind, UnknownFlow)
+
/** Creates a WRef from an arbitrary string name */
def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = Reference(n, t, k, UnknownFlow)
- def apply(name: String, tpe: Type , kind: Kind, flow: Flow): WRef = Reference(name, tpe, kind, flow)
+ def apply(name: String, tpe: Type, kind: Kind, flow: Flow): WRef = Reference(name, tpe, kind, flow)
def unapply(ref: Reference): Option[(String, Type, Kind, Flow)] = Some((ref.name, ref.tpe, ref.kind, ref.flow))
}
object WSubField {
- def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UnknownFlow)
- def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UnknownFlow)
- def apply(expr: Expression, name: String, tpe: Type, flow: Flow): WSubField = new WSubField(expr, name, tpe, flow)
+ def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UnknownFlow)
+ def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UnknownFlow)
+ def apply(expr: Expression, name: String, tpe: Type, flow: Flow): WSubField = new WSubField(expr, name, tpe, flow)
def unapply(wsf: WSubField): Option[(Expression, String, Type, Flow)] = Some((wsf.expr, wsf.name, wsf.tpe, wsf.flow))
}
object WSubIndex {
- def apply(expr: Expression, value: Int, tpe: Type, flow: Flow): WSubIndex = new WSubIndex(expr, value, tpe, flow)
+ def apply(expr: Expression, value: Int, tpe: Type, flow: Flow): WSubIndex = new WSubIndex(expr, value, tpe, flow)
def unapply(wsi: WSubIndex): Option[(Expression, Int, Type, Flow)] = Some((wsi.expr, wsi.value, wsi.tpe, wsi.flow))
}
object WSubAccess {
- def apply(expr: Expression, index: Expression, tpe: Type, flow: Flow): WSubAccess = new WSubAccess(expr, index, tpe, flow)
- def unapply(wsa: WSubAccess): Option[(Expression, Expression, Type, Flow)] = Some((wsa.expr, wsa.index, wsa.tpe, wsa.flow))
+ def apply(expr: Expression, index: Expression, tpe: Type, flow: Flow): WSubAccess =
+ new WSubAccess(expr, index, tpe, flow)
+ def unapply(wsa: WSubAccess): Option[(Expression, Expression, Type, Flow)] = Some(
+ (wsa.expr, wsa.index, wsa.tpe, wsa.flow)
+ )
}
case object WVoid extends Expression with UseSerializer {
def tpe = UnknownType
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = ()
}
case object WInvalid extends Expression with UseSerializer {
def tpe = UnknownType
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = ()
}
// Useful for splitting then remerging references
case object EmptyExpression extends Expression with UseSerializer {
def tpe = UnknownType
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = ()
}
object WDefInstance {
def apply(name: String, module: String): WDefInstance = new WDefInstance(NoInfo, name, module, UnknownType)
- def apply(info: Info, name: String, module: String, tpe: Type): WDefInstance = new WDefInstance(info, name, module, tpe)
+ def apply(info: Info, name: String, module: String, tpe: Type): WDefInstance =
+ new WDefInstance(info, name, module, tpe)
def unapply(wi: WDefInstance): Option[(Info, String, String, Type)] = {
Some((wi.info, wi.name, wi.module, wi.tpe))
}
}
case class WDefInstanceConnector(
- info: Info,
- name: String,
- module: String,
- tpe: Type,
- portCons: Seq[(Expression, Expression)]) extends Statement with IsDeclaration with UseSerializer {
+ info: Info,
+ name: String,
+ module: String,
+ tpe: Type,
+ portCons: Seq[(Expression, Expression)])
+ extends Statement
+ with IsDeclaration
+ with UseSerializer {
def mapExpr(f: Expression => Expression): Statement =
- this.copy(portCons = portCons map { case (e1, e2) => (f(e1), f(e2)) })
- def mapStmt(f: Statement => Statement): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def mapInfo(f: Info => Info): Statement = this.copy(f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = portCons foreach { case (e1, e2) => (f(e1), f(e2)) }
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ this.copy(portCons = portCons.map { case (e1, e2) => (f(e1), f(e2)) })
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = portCons.foreach { case (e1, e2) => (f(e1), f(e2)) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
// Resultant width is the same as the maximum input width
@@ -172,12 +186,12 @@ case object Dshlw extends PrimOp {
* @note This is not allowed to leak from any transform
*/
private[firrtl] case class InfoExpr(info: Info, expr: Expression) extends Expression {
- def foreachExpr(f: Expression => Unit): Unit = f(expr)
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = ()
- def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(this.expr))
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(this.expr))
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = this
def tpe: Type = expr.tpe
// Members declared in firrtl.ir.FirrtlNode
@@ -198,33 +212,35 @@ private[firrtl] object InfoExpr {
// TODO this the right name?
def map(expr: Expression)(f: Expression => Expression): Expression = expr match {
case ie: InfoExpr => ie.mapExpr(f)
- case e => f(e)
+ case e => f(e)
}
}
object WrappedExpression {
def apply(e: Expression) = new WrappedExpression(e)
- def we(e: Expression) = new WrappedExpression(e)
- def weq(e1: Expression, e2: Expression) = we(e1) == we(e2)
+ def we(e: Expression) = new WrappedExpression(e)
+ def weq(e1: Expression, e2: Expression) = we(e1) == we(e2)
}
class WrappedExpression(val e1: Expression) {
override def equals(we: Any) = we match {
- case (we: WrappedExpression) => (e1,we.e1) match {
- case (e1x: UIntLiteral, e2x: UIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width)
- case (e1x: SIntLiteral, e2x: SIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width)
- case (e1x: WRef, e2x: WRef) => e1x.name equals e2x.name
- case (e1x: WSubField, e2x: WSubField) => (e1x.name equals e2x.name) && weq(e1x.expr,e2x.expr)
- case (e1x: WSubIndex, e2x: WSubIndex) => (e1x.value == e2x.value) && weq(e1x.expr,e2x.expr)
- case (e1x: WSubAccess, e2x: WSubAccess) => weq(e1x.index,e2x.index) && weq(e1x.expr,e2x.expr)
- case (WVoid, WVoid) => true
- case (WInvalid, WInvalid) => true
- case (e1x: DoPrim, e2x: DoPrim) => e1x.op == e2x.op &&
- ((e1x.consts zip e2x.consts) forall {case (x, y) => x == y}) &&
- ((e1x.args zip e2x.args) forall {case (x, y) => weq(x, y)})
- case (e1x: Mux, e2x: Mux) => weq(e1x.cond,e2x.cond) && weq(e1x.tval,e2x.tval) && weq(e1x.fval,e2x.fval)
- case (e1x: ValidIf, e2x: ValidIf) => weq(e1x.cond,e2x.cond) && weq(e1x.value,e2x.value)
- case (e1x, e2x) => false
- }
+ case (we: WrappedExpression) =>
+ (e1, we.e1) match {
+ case (e1x: UIntLiteral, e2x: UIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width)
+ case (e1x: SIntLiteral, e2x: SIntLiteral) => e1x.value == e2x.value && eqw(e1x.width, e2x.width)
+ case (e1x: WRef, e2x: WRef) => e1x.name.equals(e2x.name)
+ case (e1x: WSubField, e2x: WSubField) => (e1x.name.equals(e2x.name)) && weq(e1x.expr, e2x.expr)
+ case (e1x: WSubIndex, e2x: WSubIndex) => (e1x.value == e2x.value) && weq(e1x.expr, e2x.expr)
+ case (e1x: WSubAccess, e2x: WSubAccess) => weq(e1x.index, e2x.index) && weq(e1x.expr, e2x.expr)
+ case (WVoid, WVoid) => true
+ case (WInvalid, WInvalid) => true
+ case (e1x: DoPrim, e2x: DoPrim) =>
+ e1x.op == e2x.op &&
+ ((e1x.consts.zip(e2x.consts)).forall { case (x, y) => x == y }) &&
+ ((e1x.args.zip(e2x.args)).forall { case (x, y) => weq(x, y) })
+ case (e1x: Mux, e2x: Mux) => weq(e1x.cond, e2x.cond) && weq(e1x.tval, e2x.tval) && weq(e1x.fval, e2x.fval)
+ case (e1x: ValidIf, e2x: ValidIf) => weq(e1x.cond, e2x.cond) && weq(e1x.value, e2x.value)
+ case (e1x, e2x) => false
+ }
case _ => false
}
override def hashCode = e1.serialize.hashCode
@@ -237,7 +253,7 @@ private[firrtl] sealed trait HasMapWidth {
object WrappedType {
def apply(t: Type) = new WrappedType(t)
- def wt(t: Type) = apply(t)
+ def wt(t: Type) = apply(t)
// Check if it is legal for the source type to drive the sink type
// Which is which matters because ResetType can be driven by itself, Bool, or AsyncResetType, but
// it cannot drive Bool nor AsyncResetType
@@ -245,10 +261,10 @@ object WrappedType {
(sink, source) match {
case (_: UIntType, _: UIntType) => true
case (_: SIntType, _: SIntType) => true
- case (ClockType, ClockType) => true
+ case (ClockType, ClockType) => true
case (AsyncResetType, AsyncResetType) => true
- case (ResetType, tpe) => legalResetType(tpe)
- case (tpe, ResetType) => legalResetType(tpe)
+ case (ResetType, tpe) => legalResetType(tpe)
+ case (tpe, ResetType) => legalResetType(tpe)
case (_: FixedType, _: FixedType) => true
case (_: IntervalType, _: IntervalType) => true
// Analog totally skips out of the Firrtl type system.
@@ -260,13 +276,14 @@ object WrappedType {
sink.size == source.size && compare(sink.tpe, source.tpe)
case (sink: BundleType, source: BundleType) =>
(sink.fields.size == source.fields.size) &&
- sink.fields.zip(source.fields).forall { case (f1, f2) =>
- (f1.flip == f2.flip) && (f1.name == f2.name) && (f1.flip match {
- case Default => compare(f1.tpe, f2.tpe)
- // We allow UInt<1> and AsyncReset to drive Reset but not the other way around
- case Flip => compare(f2.tpe, f1.tpe)
- })
- }
+ sink.fields.zip(source.fields).forall {
+ case (f1, f2) =>
+ (f1.flip == f2.flip) && (f1.name == f2.name) && (f1.flip match {
+ case Default => compare(f1.tpe, f2.tpe)
+ // We allow UInt<1> and AsyncReset to drive Reset but not the other way around
+ case Flip => compare(f2.tpe, f1.tpe)
+ })
+ }
case _ => false
}
}
@@ -287,7 +304,7 @@ object WrappedWidth {
def eqw(w1: Width, w2: Width): Boolean = new WrappedWidth(w1) == new WrappedWidth(w2)
}
-class WrappedWidth (val w: Width) {
+class WrappedWidth(val w: Width) {
def ww(w: Width): WrappedWidth = new WrappedWidth(w)
override def toString = w match {
case (w: VarWidth) => w.name
@@ -295,12 +312,13 @@ class WrappedWidth (val w: Width) {
case UnknownWidth => "?"
}
override def equals(o: Any): Boolean = o match {
- case (w2: WrappedWidth) => (w, w2.w) match {
- case (w1: VarWidth, w2: VarWidth) => w1.name.equals(w2.name)
- case (w1: IntWidth, w2: IntWidth) => w1.width == w2.width
- case (UnknownWidth, UnknownWidth) => true
- case _ => false
- }
+ case (w2: WrappedWidth) =>
+ (w, w2.w) match {
+ case (w1: VarWidth, w2: VarWidth) => w1.name.equals(w2.name)
+ case (w1: IntWidth, w2: IntWidth) => w1.width == w2.width
+ case (UnknownWidth, UnknownWidth) => true
+ case _ => false
+ }
case _ => false
}
}
@@ -320,37 +338,38 @@ case object MReadWrite extends MPortDir {
}
case class CDefMemory(
- info: Info,
- name: String,
- tpe: Type,
- size: BigInt,
- seq: Boolean,
- readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined) extends Statement with HasInfo with UseSerializer {
- def mapExpr(f: Expression => Expression): Statement = this
- def mapStmt(f: Statement => Statement): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def mapInfo(f: Info => Info): Statement = this.copy(f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ info: Info,
+ name: String,
+ tpe: Type,
+ size: BigInt,
+ seq: Boolean,
+ readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined)
+ extends Statement
+ with HasInfo
+ with UseSerializer {
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
-case class CDefMPort(info: Info,
- name: String,
- tpe: Type,
- mem: String,
- exps: Seq[Expression],
- direction: MPortDir) extends Statement with HasInfo with UseSerializer {
- def mapExpr(f: Expression => Expression): Statement = this.copy(exps = exps map f)
- def mapStmt(f: Statement => Statement): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def mapInfo(f: Info => Info): Statement = this.copy(f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = exps.foreach(f)
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+case class CDefMPort(info: Info, name: String, tpe: Type, mem: String, exps: Seq[Expression], direction: MPortDir)
+ extends Statement
+ with HasInfo
+ with UseSerializer {
+ def mapExpr(f: Expression => Expression): Statement = this.copy(exps = exps.map(f))
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = exps.foreach(f)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/analyses/CircuitGraph.scala b/src/main/scala/firrtl/analyses/CircuitGraph.scala
index 506bba57..a1fb0f19 100644
--- a/src/main/scala/firrtl/analyses/CircuitGraph.scala
+++ b/src/main/scala/firrtl/analyses/CircuitGraph.scala
@@ -80,9 +80,10 @@ class CircuitGraph private[analyses] (connectionGraph: ConnectionGraph) {
* @return
*/
def absolutePaths(mt: ModuleTarget): Seq[IsModule] = instanceGraph.findInstancesInHierarchy(mt.module).map {
- case seq if seq.nonEmpty => seq.foldLeft(CircuitTarget(circuit.main).module(circuit.main): IsModule) {
- case (it, InstanceKey(instance, ofModule)) => it.instOf(instance, ofModule)
- }
+ case seq if seq.nonEmpty =>
+ seq.foldLeft(CircuitTarget(circuit.main).module(circuit.main): IsModule) {
+ case (it, InstanceKey(instance, ofModule)) => it.instOf(instance, ofModule)
+ }
}
/** Return the sequence of nodes from source to sink, inclusive
diff --git a/src/main/scala/firrtl/analyses/ConnectionGraph.scala b/src/main/scala/firrtl/analyses/ConnectionGraph.scala
index 0e13711a..f98cf14c 100644
--- a/src/main/scala/firrtl/analyses/ConnectionGraph.scala
+++ b/src/main/scala/firrtl/analyses/ConnectionGraph.scala
@@ -16,22 +16,24 @@ import scala.collection.mutable
* @param circuit firrtl AST of this graph.
* @param digraph Directed graph of ReferenceTarget in the AST.
* @param irLookup [[IRLookup]] instance of circuit graph.
- * */
-class ConnectionGraph protected(val circuit: Circuit,
- val digraph: DiGraph[ReferenceTarget],
- val irLookup: IRLookup)
- extends DiGraph[ReferenceTarget](digraph.getEdgeMap.asInstanceOf[mutable.LinkedHashMap[ReferenceTarget, mutable.LinkedHashSet[ReferenceTarget]]]) {
+ */
+class ConnectionGraph protected (val circuit: Circuit, val digraph: DiGraph[ReferenceTarget], val irLookup: IRLookup)
+ extends DiGraph[ReferenceTarget](
+ digraph.getEdgeMap.asInstanceOf[mutable.LinkedHashMap[ReferenceTarget, mutable.LinkedHashSet[ReferenceTarget]]]
+ ) {
lazy val serialize: String = s"""{
- |${getEdgeMap.map { case (k, vs) =>
+ |${getEdgeMap.map {
+ case (k, vs) =>
s""" "$k": {
- | "kind": "${irLookup.kind(k)}",
- | "type": "${irLookup.tpe(k)}",
- | "expr": "${irLookup.expr(k, irLookup.flow(k))}",
- | "sinks": [${vs.map { v => s""""$v"""" }.mkString(", ")}],
- | "declaration": "${irLookup.declaration(k)}"
- | }""".stripMargin }.mkString(",\n")}
- |}""".stripMargin
+ | "kind": "${irLookup.kind(k)}",
+ | "type": "${irLookup.tpe(k)}",
+ | "expr": "${irLookup.expr(k, irLookup.flow(k))}",
+ | "sinks": [${vs.map { v => s""""$v"""" }.mkString(", ")}],
+ | "declaration": "${irLookup.declaration(k)}"
+ | }""".stripMargin
+ }.mkString(",\n")}
+ |}""".stripMargin
/** Used by BFS to map each visited node to the list of instance inputs visited thus far
*
@@ -134,7 +136,10 @@ class ConnectionGraph protected(val circuit: Circuit,
/** @return a new, reversed connection graph where edges point from sinks to sources. */
def reverseConnectionGraph: ConnectionGraph = new ConnectionGraph(circuit, digraph.reverse, irLookup)
- override def BFS(root: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): collection.Map[ReferenceTarget, ReferenceTarget] = {
+ override def BFS(
+ root: ReferenceTarget,
+ blacklist: collection.Set[ReferenceTarget]
+ ): collection.Map[ReferenceTarget, ReferenceTarget] = {
val prev = new mutable.LinkedHashMap[ReferenceTarget, ReferenceTarget]()
val ordering = new Ordering[ReferenceTarget] {
override def compare(x: ReferenceTarget, y: ReferenceTarget): Int = x.path.size - y.path.size
@@ -216,7 +221,6 @@ class ConnectionGraph protected(val circuit: Circuit,
bfsShortCuts.get(localSource) match {
case Some(set) => set.map { x => x.setPathTarget(source.pathTarget) }
case None =>
-
val pathlessEdges = super.getEdges(localSource)
val ret = pathlessEdges.flatMap {
@@ -246,7 +250,9 @@ class ConnectionGraph protected(val circuit: Circuit,
// Exiting to parent, but had unresolved trip through child, so don't update shortcut
portConnectivityStack(localSink) = localSource +: currentStack
}
- Set[ReferenceTarget](localSink.setPathTarget(source.noComponents.targetParent.asInstanceOf[IsComponent].pathTarget))
+ Set[ReferenceTarget](
+ localSink.setPathTarget(source.noComponents.targetParent.asInstanceOf[IsComponent].pathTarget)
+ )
case localSink if enteringChildInstance(source)(localSink) =>
portConnectivityStack(localSink) = localSource +: portConnectivityStack.getOrElse(localSource, Nil)
@@ -265,24 +271,31 @@ class ConnectionGraph protected(val circuit: Circuit,
}
- override def path(start: ReferenceTarget, end: ReferenceTarget, blacklist: collection.Set[ReferenceTarget]): Seq[ReferenceTarget] = {
+ override def path(
+ start: ReferenceTarget,
+ end: ReferenceTarget,
+ blacklist: collection.Set[ReferenceTarget]
+ ): Seq[ReferenceTarget] = {
insertShortCuts(super.path(start, end, blacklist))
}
private def insertShortCuts(path: Seq[ReferenceTarget]): Seq[ReferenceTarget] = {
val soFar = mutable.HashSet[ReferenceTarget]()
if (path.size > 1) {
- path.head +: path.sliding(2).flatMap {
- case Seq(from, to) =>
- getShortCut(from) match {
- case Some(set) if set.contains(to) && soFar.contains(from.pathlessTarget) =>
- soFar += from.pathlessTarget
- Seq(from.pathTarget.ref("..."), to)
- case _ =>
- soFar += from.pathlessTarget
- Seq(to)
- }
- }.toSeq
+ path.head +: path
+ .sliding(2)
+ .flatMap {
+ case Seq(from, to) =>
+ getShortCut(from) match {
+ case Some(set) if set.contains(to) && soFar.contains(from.pathlessTarget) =>
+ soFar += from.pathlessTarget
+ Seq(from.pathTarget.ref("..."), to)
+ case _ =>
+ soFar += from.pathlessTarget
+ Seq(to)
+ }
+ }
+ .toSeq
} else path
}
@@ -325,16 +338,16 @@ object ConnectionGraph {
* @return
*/
def asTarget(m: ModuleTarget, tagger: TokenTagger)(e: FirrtlNode): ReferenceTarget = e match {
- case l: Literal => m.ref(tagger.getRef(l.value.toString))
+ case l: Literal => m.ref(tagger.getRef(l.value.toString))
case r: Reference => m.ref(r.name)
- case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value)
- case s: SubField => asTarget(m, tagger)(s.expr).field(s.name)
- case d: DoPrim => m.ref(tagger.getRef(d.op.serialize))
- case _: Mux => m.ref(tagger.getRef("mux"))
- case _: ValidIf => m.ref(tagger.getRef("validif"))
+ case s: SubIndex => asTarget(m, tagger)(s.expr).index(s.value)
+ case s: SubField => asTarget(m, tagger)(s.expr).field(s.name)
+ case d: DoPrim => m.ref(tagger.getRef(d.op.serialize))
+ case _: Mux => m.ref(tagger.getRef("mux"))
+ case _: ValidIf => m.ref(tagger.getRef("validif"))
case WInvalid => m.ref(tagger.getRef("invalid"))
case _: Print => m.ref(tagger.getRef("print"))
- case _: Stop => m.ref(tagger.getRef("print"))
+ case _: Stop => m.ref(tagger.getRef("print"))
case other => sys.error(s"Unsupported: $other")
}
@@ -354,30 +367,31 @@ object ConnectionGraph {
def enteringNonParentInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = {
source.path.nonEmpty &&
- (source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule != localSink.module ||
- localSink.ref != source.path.last._1.value)
+ (source.noComponents.targetParent.asInstanceOf[InstanceTarget].encapsulatingModule != localSink.module ||
+ localSink.ref != source.path.last._1.value)
}
def enteringChildInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match {
case ReferenceTarget(_, _, _, _, TargetToken.Field(port) +: comps)
- if port == localSink.ref && comps == localSink.component => true
+ if port == localSink.ref && comps == localSink.component =>
+ true
case _ => false
}
def leavingRootInstance(source: ReferenceTarget)(localSink: ReferenceTarget): Boolean = source match {
case ReferenceTarget(_, _, Seq(), port, comps)
- if port == localSink.component.head.value && comps == localSink.component.tail => true
+ if port == localSink.component.head.value && comps == localSink.component.tail =>
+ true
case _ => false
}
-
private def buildCircuitGraph(circuit: Circuit): ConnectionGraph = {
val mdg = new MutableDiGraph[ReferenceTarget]()
val declarations = mutable.LinkedHashMap[ModuleTarget, mutable.LinkedHashMap[ReferenceTarget, FirrtlNode]]()
val circuitTarget = CircuitTarget(circuit.main)
val moduleMap = circuit.modules.map { m => circuitTarget.module(m.name) -> m }.toMap
- circuit map buildModule(circuitTarget)
+ circuit.map(buildModule(circuitTarget))
def addLabeledVertex(v: ReferenceTarget, f: FirrtlNode): Unit = {
mdg.addVertex(v)
@@ -386,7 +400,7 @@ object ConnectionGraph {
def buildModule(c: CircuitTarget)(module: DefModule): DefModule = {
val m = c.module(module.name)
- module map buildPort(m) map buildStatement(m, new TokenTagger())
+ module.map(buildPort(m)).map(buildStatement(m, new TokenTagger()))
}
def buildPort(m: ModuleTarget)(port: Port): Port = {
@@ -412,7 +426,7 @@ object ConnectionGraph {
(Utils.flow(instExp), Utils.flow(modExp)) match {
case (SourceFlow, SinkFlow) => mdg.addPairWithEdge(it, mt)
case (SinkFlow, SourceFlow) => mdg.addPairWithEdge(mt, it)
- case _ => sys.error("Something went wrong...")
+ case _ => sys.error("Something went wrong...")
}
}
}
@@ -461,13 +475,14 @@ object ConnectionGraph {
// Connect each subTarget to the corresponding init subTarget
val allRegTargets = regTarget.leafSubTargets(d.tpe)
val allInitTargets = initTarget.leafSubTargets(d.tpe).zip(Utils.create_exps(d.init))
- allRegTargets.zip(allInitTargets).foreach { case (r, (i, e)) =>
- mdg.addVertex(i)
- mdg.addVertex(r)
- mdg.addEdge(clockTarget, r)
- mdg.addEdge(resetTarget, r)
- mdg.addEdge(i, r)
- buildExpression(m, tagger, i)(e)
+ allRegTargets.zip(allInitTargets).foreach {
+ case (r, (i, e)) =>
+ mdg.addVertex(i)
+ mdg.addVertex(r)
+ mdg.addEdge(clockTarget, r)
+ mdg.addEdge(resetTarget, r)
+ mdg.addEdge(i, r)
+ buildExpression(m, tagger, i)(e)
}
}
@@ -480,9 +495,10 @@ object ConnectionGraph {
val sinkTarget = m.ref(d.name)
addLabeledVertex(sinkTarget, stmt)
val nodeTargets = sinkTarget.leafSubTargets(d.value.tpe)
- nodeTargets.zip(Utils.create_exps(d.value)).foreach { case (n, e) =>
- mdg.addVertex(n)
- buildExpression(m, tagger, n)(e)
+ nodeTargets.zip(Utils.create_exps(d.value)).foreach {
+ case (n, e) =>
+ mdg.addVertex(n)
+ buildExpression(m, tagger, n)(e)
}
case c: Connect =>
@@ -512,10 +528,10 @@ object ConnectionGraph {
addLabeledVertex(m.ref(d.name), d)
buildMemory(m, d)
- /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.passes.ExpandWhensAndCheck]]*/
+ /** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.passes.ExpandWhensAndCheck]] */
case _: Conditionally => sys.error("Unsupported! Only works on Middle Firrtl")
- case s: Block => s map buildStatement(m, tagger)
+ case s: Block => s.map(buildStatement(m, tagger))
case a: Attach =>
val attachTargets = a.exprs.map { r =>
@@ -523,18 +539,25 @@ object ConnectionGraph {
mdg.addVertex(at)
at
}
- attachTargets.combinations(2).foreach { case Seq(l, r) =>
- mdg.addEdge(l, r)
- mdg.addEdge(r, l)
+ attachTargets.combinations(2).foreach {
+ case Seq(l, r) =>
+ mdg.addEdge(l, r)
+ mdg.addEdge(r, l)
}
case p: Print => addLabeledVertex(asTarget(m, tagger)(p), p)
- case s: Stop => addLabeledVertex(asTarget(m, tagger)(s), s)
+ case s: Stop => addLabeledVertex(asTarget(m, tagger)(s), s)
case EmptyStmt =>
}
stmt
}
- def buildExpression(m: ModuleTarget, tagger: TokenTagger, sinkTarget: ReferenceTarget)(expr: Expression): Expression = {
+ def buildExpression(
+ m: ModuleTarget,
+ tagger: TokenTagger,
+ sinkTarget: ReferenceTarget
+ )(expr: Expression
+ ): Expression = {
+
/** @todo [[firrtl.Transform.prerequisites]] ++ [[firrtl.stage.Forms.Resolved]]. */
val sourceTarget = asTarget(m, tagger)(expr)
mdg.addVertex(sourceTarget)
@@ -542,7 +565,7 @@ object ConnectionGraph {
expr match {
case _: DoPrim | _: Mux | _: ValidIf | _: Literal =>
addLabeledVertex(sourceTarget, expr)
- expr map buildExpression(m, tagger, sourceTarget)
+ expr.map(buildExpression(m, tagger, sourceTarget))
case _ =>
}
expr
@@ -552,7 +575,6 @@ object ConnectionGraph {
}
}
-
/** Used for obtaining a tag for a given label unnamed Target. */
class TokenTagger {
private val counterMap = mutable.HashMap[String, Int]()
diff --git a/src/main/scala/firrtl/analyses/IRLookup.scala b/src/main/scala/firrtl/analyses/IRLookup.scala
index f9819ebd..b8528a95 100644
--- a/src/main/scala/firrtl/analyses/IRLookup.scala
+++ b/src/main/scala/firrtl/analyses/IRLookup.scala
@@ -6,7 +6,22 @@ import firrtl.annotations.TargetToken._
import firrtl.annotations._
import firrtl.ir._
import firrtl.passes.MemPortUtils
-import firrtl.{DuplexFlow, ExpKind, Flow, InstanceKind, Kind, MemKind, PortKind, RegKind, SinkFlow, SourceFlow, UnknownFlow, Utils, WInvalid, WireKind}
+import firrtl.{
+ DuplexFlow,
+ ExpKind,
+ Flow,
+ InstanceKind,
+ Kind,
+ MemKind,
+ PortKind,
+ RegKind,
+ SinkFlow,
+ SourceFlow,
+ UnknownFlow,
+ Utils,
+ WInvalid,
+ WireKind
+}
import scala.collection.mutable
@@ -19,26 +34,33 @@ object IRLookup {
* @param declarations Maps references (not subreferences) to declarations
* @param modules Maps module targets to modules
*/
-class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map[ReferenceTarget, FirrtlNode]],
- private val modules: Map[ModuleTarget, DefModule]) {
+class IRLookup private[analyses] (
+ private val declarations: Map[ModuleTarget, Map[ReferenceTarget, FirrtlNode]],
+ private val modules: Map[ModuleTarget, DefModule]) {
private val flowCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Flow]]()
private val kindCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Kind]]()
private val tpeCache = mutable.HashMap[ModuleTarget, mutable.HashMap[ReferenceTarget, Type]]()
private val exprCache = mutable.HashMap[ModuleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]]()
- private val refCache = mutable.HashMap[ModuleTarget, mutable.LinkedHashMap[Kind, mutable.ArrayBuffer[ReferenceTarget]]]()
-
+ private val refCache =
+ mutable.HashMap[ModuleTarget, mutable.LinkedHashMap[Kind, mutable.ArrayBuffer[ReferenceTarget]]]()
/** @example Given ~Top|MyModule/inst:Other>foo.bar, returns ~Top|Other>foo
* @return the target converted to its local reference
*/
def asLocalRef(t: ReferenceTarget): ReferenceTarget = t.pathlessTarget.copy(component = Nil)
- def flow(t: ReferenceTarget): Flow = flowCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Flow]()).getOrElseUpdate(t.pathlessTarget, Utils.flow(expr(t.pathlessTarget)))
+ def flow(t: ReferenceTarget): Flow = flowCache
+ .getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Flow]())
+ .getOrElseUpdate(t.pathlessTarget, Utils.flow(expr(t.pathlessTarget)))
- def kind(t: ReferenceTarget): Kind = kindCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Kind]()).getOrElseUpdate(t.pathlessTarget, Utils.kind(expr(t.pathlessTarget)))
+ def kind(t: ReferenceTarget): Kind = kindCache
+ .getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Kind]())
+ .getOrElseUpdate(t.pathlessTarget, Utils.kind(expr(t.pathlessTarget)))
- def tpe(t: ReferenceTarget): Type = tpeCache.getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Type]()).getOrElseUpdate(t.pathlessTarget, expr(t.pathlessTarget).tpe)
+ def tpe(t: ReferenceTarget): Type = tpeCache
+ .getOrElseUpdate(t.moduleTarget, mutable.HashMap[ReferenceTarget, Type]())
+ .getOrElseUpdate(t.pathlessTarget, expr(t.pathlessTarget).tpe)
/** get expression of the target.
* It can return None for many reasons, including
@@ -54,7 +76,7 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
val pathless = t.pathlessTarget
inCache(pathless, flow) match {
- case e@Some(_) => return e
+ case e @ Some(_) => return e
case None =>
val mt = pathless.moduleTarget
val emt = t.encapsulatingModuleTarget
@@ -62,36 +84,50 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
declarations(emt)(asLocalRef(t)) match {
case e: Expression =>
require(e.tpe.isInstanceOf[GroundType])
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).getOrElseUpdate((pathless, Utils.flow(e)), e)
- case d: IsDeclaration => d match {
- case n: DefNode =>
- updateExpr(mt, Reference(n.name, n.value.tpe, ExpKind, SourceFlow))
- case p: Port =>
- updateExpr(mt, Reference(p.name, p.tpe, PortKind, Utils.get_flow(p)))
- case w: DefInstance =>
- updateExpr(mt, Reference(w.name, w.tpe, InstanceKind, SourceFlow))
- case w: DefWire =>
- updateExpr(mt, Reference(w.name, w.tpe, WireKind, SourceFlow))
- updateExpr(mt, Reference(w.name, w.tpe, WireKind, SinkFlow))
- updateExpr(mt, Reference(w.name, w.tpe, WireKind, DuplexFlow))
- case r: DefRegister if pathless.tokens.last == Clock =>
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.clock
- case r: DefRegister if pathless.tokens.isDefinedAt(1) && pathless.tokens(1) == Init =>
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.init
- updateExpr(pathless, r.init)
- case r: DefRegister if pathless.tokens.last == Reset =>
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = r.reset
- case r: DefRegister =>
- updateExpr(mt, Reference(r.name, r.tpe, RegKind, SourceFlow))
- updateExpr(mt, Reference(r.name, r.tpe, RegKind, SinkFlow))
- updateExpr(mt, Reference(r.name, r.tpe, RegKind, DuplexFlow))
- case m: DefMemory =>
- updateExpr(mt, Reference(m.name, MemPortUtils.memType(m), MemKind, SourceFlow))
- case other =>
- sys.error(s"Cannot call expr with: $t, given declaration $other")
- }
+ exprCache
+ .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())
+ .getOrElseUpdate((pathless, Utils.flow(e)), e)
+ case d: IsDeclaration =>
+ d match {
+ case n: DefNode =>
+ updateExpr(mt, Reference(n.name, n.value.tpe, ExpKind, SourceFlow))
+ case p: Port =>
+ updateExpr(mt, Reference(p.name, p.tpe, PortKind, Utils.get_flow(p)))
+ case w: DefInstance =>
+ updateExpr(mt, Reference(w.name, w.tpe, InstanceKind, SourceFlow))
+ case w: DefWire =>
+ updateExpr(mt, Reference(w.name, w.tpe, WireKind, SourceFlow))
+ updateExpr(mt, Reference(w.name, w.tpe, WireKind, SinkFlow))
+ updateExpr(mt, Reference(w.name, w.tpe, WireKind, DuplexFlow))
+ case r: DefRegister if pathless.tokens.last == Clock =>
+ exprCache.getOrElseUpdate(
+ pathless.moduleTarget,
+ mutable.HashMap[(ReferenceTarget, Flow), Expression]()
+ )((pathless, SourceFlow)) = r.clock
+ case r: DefRegister if pathless.tokens.isDefinedAt(1) && pathless.tokens(1) == Init =>
+ exprCache.getOrElseUpdate(
+ pathless.moduleTarget,
+ mutable.HashMap[(ReferenceTarget, Flow), Expression]()
+ )((pathless, SourceFlow)) = r.init
+ updateExpr(pathless, r.init)
+ case r: DefRegister if pathless.tokens.last == Reset =>
+ exprCache.getOrElseUpdate(
+ pathless.moduleTarget,
+ mutable.HashMap[(ReferenceTarget, Flow), Expression]()
+ )((pathless, SourceFlow)) = r.reset
+ case r: DefRegister =>
+ updateExpr(mt, Reference(r.name, r.tpe, RegKind, SourceFlow))
+ updateExpr(mt, Reference(r.name, r.tpe, RegKind, SinkFlow))
+ updateExpr(mt, Reference(r.name, r.tpe, RegKind, DuplexFlow))
+ case m: DefMemory =>
+ updateExpr(mt, Reference(m.name, MemPortUtils.memType(m), MemKind, SourceFlow))
+ case other =>
+ sys.error(s"Cannot call expr with: $t, given declaration $other")
+ }
case _: IsInvalid =>
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)) = WInvalid
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
+ (pathless, SourceFlow)
+ ) = WInvalid
}
}
}
@@ -118,7 +154,8 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
*
* @param moduleTarget [[firrtl.annotations.ModuleTarget]] to be queried.
* @param kind [[firrtl.Kind]] to be find.
- * @return all [[firrtl.annotations.ReferenceTarget]] in this node. */
+ * @return all [[firrtl.annotations.ReferenceTarget]] in this node.
+ */
def kindFinder(moduleTarget: ModuleTarget, kind: Kind): Seq[ReferenceTarget] = {
def updateRefs(kind: Kind, rt: ReferenceTarget): Unit = refCache
.getOrElseUpdate(rt.moduleTarget, mutable.LinkedHashMap.empty[Kind, mutable.ArrayBuffer[ReferenceTarget]])
@@ -136,7 +173,11 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
case (rt, _: Port) => updateRefs(PortKind, rt)
case _ =>
}
- refCache.get(moduleTarget).map(_.getOrElse(kind, Seq.empty[ReferenceTarget])).getOrElse(Seq.empty[ReferenceTarget]).toSeq
+ refCache
+ .get(moduleTarget)
+ .map(_.getOrElse(kind, Seq.empty[ReferenceTarget]))
+ .getOrElse(Seq.empty[ReferenceTarget])
+ .toSeq
}
}
@@ -181,7 +222,7 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
def moduleLeafPortTargets(m: ModuleTarget): (Seq[(ReferenceTarget, Type)], Seq[(ReferenceTarget, Type)]) =
modules(m).ports.flatMap {
case Port(_, name, Output, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SourceFlow))
- case Port(_, name, Input, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow))
+ case Port(_, name, Input, tpe) => Utils.create_exps(Reference(name, tpe, PortKind, SinkFlow))
}.foldLeft((Vector.empty[(ReferenceTarget, Type)], Vector.empty[(ReferenceTarget, Type)])) {
case ((inputs, outputs), e) if Utils.flow(e) == SourceFlow =>
(inputs, outputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe))
@@ -189,7 +230,6 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
(inputs :+ (ConnectionGraph.asTarget(m, new TokenTagger())(e), e.tpe), outputs)
}
-
/** @param t [[firrtl.annotations.ReferenceTarget]] to be queried.
* @return whether a ReferenceTarget is contained in this IRLookup
*/
@@ -213,10 +253,10 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
val all = i.pathAsTargets :+ i.encapsulatingModuleTarget.instOf(i.instance, i.ofModule)
all.map { x =>
declarations.contains(x.moduleTarget) && declarations(x.moduleTarget).contains(x.asReference) &&
- (declarations(x.moduleTarget)(x.asReference) match {
- case DefInstance(_, _, of, _) if of == x.ofModule => validPath(x.ofModuleTarget)
- case _ => false
- })
+ (declarations(x.moduleTarget)(x.asReference) match {
+ case DefInstance(_, _, of, _) if of == x.ofModule => validPath(x.ofModuleTarget)
+ case _ => false
+ })
}.reduce(_ && _)
}
}
@@ -248,17 +288,54 @@ class IRLookup private[analyses](private val declarations: Map[ModuleTarget, Map
/** Optionally returns the expression corresponding to the target if contained in the expression cache. */
private def inCache(pathless: ReferenceTarget, flow: Flow): Option[Expression] = {
- (flow,
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SourceFlow)),
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains((pathless, SinkFlow)),
- exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]()).contains(pathless, DuplexFlow)
+ (
+ flow,
+ exprCache
+ .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())
+ .contains((pathless, SourceFlow)),
+ exprCache
+ .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())
+ .contains((pathless, SinkFlow)),
+ exprCache
+ .getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())
+ .contains(pathless, DuplexFlow)
) match {
- case (SourceFlow, true, _, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow)))
- case (SinkFlow, _, true, _) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, flow)))
- case (DuplexFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow)))
- case (UnknownFlow, _, _, true) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, DuplexFlow)))
- case (UnknownFlow, true, false, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SourceFlow)))
- case (UnknownFlow, false, true, false) => Some(exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())((pathless, SinkFlow)))
+ case (SourceFlow, true, _, _) =>
+ Some(
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
+ (pathless, flow)
+ )
+ )
+ case (SinkFlow, _, true, _) =>
+ Some(
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
+ (pathless, flow)
+ )
+ )
+ case (DuplexFlow, _, _, true) =>
+ Some(
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
+ (pathless, DuplexFlow)
+ )
+ )
+ case (UnknownFlow, _, _, true) =>
+ Some(
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
+ (pathless, DuplexFlow)
+ )
+ )
+ case (UnknownFlow, true, false, false) =>
+ Some(
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
+ (pathless, SourceFlow)
+ )
+ )
+ case (UnknownFlow, false, true, false) =>
+ Some(
+ exprCache.getOrElseUpdate(pathless.moduleTarget, mutable.HashMap[(ReferenceTarget, Flow), Expression]())(
+ (pathless, SinkFlow)
+ )
+ )
case _ => None
}
}
diff --git a/src/main/scala/firrtl/analyses/InstanceGraph.scala b/src/main/scala/firrtl/analyses/InstanceGraph.scala
index f994b39a..4aab9a3a 100644
--- a/src/main/scala/firrtl/analyses/InstanceGraph.scala
+++ b/src/main/scala/firrtl/analyses/InstanceGraph.scala
@@ -10,7 +10,6 @@ import firrtl.Utils._
import firrtl.traversals.Foreachers._
import firrtl.annotations.TargetToken._
-
/** A class representing the instance hierarchy of a working IR Circuit
*
* @constructor constructs an instance graph from a Circuit
@@ -29,7 +28,7 @@ import firrtl.annotations.TargetToken._
class InstanceGraph(c: Circuit) {
@deprecated("Use InstanceKeyGraph.moduleMap instead.", "FIRRTL 1.4")
- val moduleMap = c.modules.map({m => (m.name,m) }).toMap
+ val moduleMap = c.modules.map({ m => (m.name, m) }).toMap
private val instantiated = new mutable.LinkedHashSet[String]
private val childInstances =
new mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]]
@@ -43,7 +42,7 @@ class InstanceGraph(c: Circuit) {
private val instanceQueue = new mutable.Queue[DefInstance]
for (subTop <- c.modules.view.map(_.name).filterNot(instantiated)) {
- val topInstance = DefInstance(subTop,subTop)
+ val topInstance = DefInstance(subTop, subTop)
instanceQueue.enqueue(topInstance)
while (instanceQueue.nonEmpty) {
val current = instanceQueue.dequeue
@@ -53,7 +52,7 @@ class InstanceGraph(c: Circuit) {
instanceQueue.enqueue(child)
instanceGraph.addVertex(child)
}
- instanceGraph.addEdge(current,child)
+ instanceGraph.addEdge(current, child)
}
}
}
@@ -73,7 +72,7 @@ class InstanceGraph(c: Circuit) {
* of all module instances in the Circuit.
*/
@deprecated("Use InstanceKeyGraph.fullHierarchy instead.", "FIRRTL 1.4")
- lazy val fullHierarchy: mutable.LinkedHashMap[DefInstance,Seq[Seq[DefInstance]]] = graph.pathsInDAG(trueTopInstance)
+ lazy val fullHierarchy: mutable.LinkedHashMap[DefInstance, Seq[Seq[DefInstance]]] = graph.pathsInDAG(trueTopInstance)
/** A count of the *static* number of instances of each module. For any module other than the top (main) module, this is
* equivalent to the number of inst statements in the circuit instantiating each module, irrespective of the number
@@ -85,7 +84,7 @@ class InstanceGraph(c: Circuit) {
lazy val staticInstanceCount: Map[OfModule, Int] = {
val foo = mutable.LinkedHashMap.empty[OfModule, Int]
childInstances.keys.foreach {
- case main if main == c.main => foo += main.OfModule -> 1
+ case main if main == c.main => foo += main.OfModule -> 1
case other => foo += other.OfModule -> 0
}
childInstances.values.flatten.map(_.OfModule).foreach {
@@ -106,7 +105,7 @@ class InstanceGraph(c: Circuit) {
@deprecated("Use InstanceKeyGraph.findInstancesInHierarchy instead (now with caching of vertices!).", "FIRRTL 1.4")
def findInstancesInHierarchy(module: String): Seq[Seq[DefInstance]] = {
val instances = graph.getVertices.filter(_.module == module).toSeq
- instances flatMap { i => fullHierarchy.getOrElse(i, Nil) }
+ instances.flatMap { i => fullHierarchy.getOrElse(i, Nil) }
}
/** An [[firrtl.graph.EulerTour EulerTour]] representation of the [[firrtl.graph.DiGraph DiGraph]] */
@@ -117,8 +116,7 @@ class InstanceGraph(c: Circuit) {
* a design
*/
@deprecated("Use InstanceKeyGraph and EulerTour(iGraph.graph, iGraph.top).rmq(moduleA, moduleB).", "FIRRTL 1.4")
- def lowestCommonAncestor(moduleA: Seq[DefInstance],
- moduleB: Seq[DefInstance]): Seq[DefInstance] = {
+ def lowestCommonAncestor(moduleA: Seq[DefInstance], moduleB: Seq[DefInstance]): Seq[DefInstance] = {
tour.rmq(moduleA, moduleB)
}
@@ -131,10 +129,9 @@ class InstanceGraph(c: Circuit) {
graph.transformNodes(_.module).linearize.map(moduleMap(_))
}
-
/** Given a circuit, returns a map from module name to children
- * instance/module definitions
- */
+ * instance/module definitions
+ */
@deprecated("Use InstanceKeyGraph.getChildInstances instead.", "FIRRTL 1.4")
def getChildrenInstances: mutable.LinkedHashMap[String, mutable.LinkedHashSet[DefInstance]] = childInstances
@@ -172,7 +169,7 @@ class InstanceGraph(c: Circuit) {
/** The set of all modules *not* reachable in the circuit */
@deprecated("Use InstanceKeyGraph.unreachableModules instead.", "FIRRTL 1.4")
- lazy val unreachableModules: collection.Set[OfModule] = modules diff reachableModules
+ lazy val unreachableModules: collection.Set[OfModule] = modules.diff(reachableModules)
}
@@ -186,10 +183,9 @@ object InstanceGraph {
* @return
*/
@deprecated("Use InstanceKeyGraph.collectInstances instead.", "FIRRTL 1.4")
- def collectInstances(insts: mutable.Set[DefInstance])
- (s: Statement): Unit = s match {
- case i: DefInstance => insts += i
- case i: DefInstance => throwInternalError("Expecting DefInstance, found a DefInstance!")
+ def collectInstances(insts: mutable.Set[DefInstance])(s: Statement): Unit = s match {
+ case i: DefInstance => insts += i
+ case i: DefInstance => throwInternalError("Expecting DefInstance, found a DefInstance!")
case i: WDefInstanceConnector => throwInternalError("Expecting DefInstance, found a DefInstanceConnector!")
case _ => s.foreach(collectInstances(insts))
}
diff --git a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala
index 761315dc..5354888d 100644
--- a/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala
+++ b/src/main/scala/firrtl/analyses/InstanceKeyGraph.scala
@@ -14,10 +14,10 @@ import scala.collection.mutable
* pairs of InstanceName and Module name as vertex keys instead of using WDefInstance
* which will hash the instance type causing some performance issues.
*/
-class InstanceKeyGraph private(c: ir.Circuit) {
+class InstanceKeyGraph private (c: ir.Circuit) {
import InstanceKeyGraph._
- private val nameToModule: Map[String, ir.DefModule] = c.modules.map({m => (m.name,m) }).toMap
+ private val nameToModule: Map[String, ir.DefModule] = c.modules.map({ m => (m.name, m) }).toMap
private val childInstances: Seq[(String, Seq[InstanceKey])] = c.modules.map { m =>
m.name -> InstanceKeyGraph.collectInstances(m)
}
@@ -37,8 +37,8 @@ class InstanceKeyGraph private(c: ir.Circuit) {
circuitTopInstance.OfModule +: internalGraph.reachableFrom(circuitTopInstance).toSeq.map(_.OfModule)
private lazy val cachedUnreachableModules: Seq[OfModule] = {
- val all = mutable.LinkedHashSet(childInstances.map(c => OfModule(c._1)):_*)
- val reachable = mutable.LinkedHashSet(cachedReachableModules:_*)
+ val all = mutable.LinkedHashSet(childInstances.map(c => OfModule(c._1)): _*)
+ val reachable = mutable.LinkedHashSet(cachedReachableModules: _*)
all.diff(reachable).toSeq
}
@@ -68,11 +68,11 @@ class InstanceKeyGraph private(c: ir.Circuit) {
private lazy val cachedStaticInstanceCount = {
val foo = mutable.LinkedHashMap.empty[OfModule, Int]
childInstances.foreach {
- case (main, _) if main == c.main => foo += main.OfModule -> 1
+ case (main, _) if main == c.main => foo += main.OfModule -> 1
case (other, _) => foo += other.OfModule -> 0
}
- childInstances.flatMap(_._2).map(_.OfModule).foreach {
- mod => foo += mod -> (foo(mod) + 1)
+ childInstances.flatMap(_._2).map(_.OfModule).foreach { mod =>
+ foo += mod -> (foo(mod) + 1)
}
foo.toMap
}
@@ -88,17 +88,18 @@ class InstanceKeyGraph private(c: ir.Circuit) {
*/
def findInstancesInHierarchy(module: String): Seq[Seq[InstanceKey]] = {
val instances = vertices.filter(_.module == module).toSeq
- instances.flatMap{ i => cachedFullHierarchy.getOrElse(i, Nil) }
+ instances.flatMap { i => cachedFullHierarchy.getOrElse(i, Nil) }
}
/** Given a circuit, returns a map from module name to a map
* in turn mapping instances names to corresponding module names
*/
def getChildInstanceMap: mutable.LinkedHashMap[OfModule, mutable.LinkedHashMap[Instance, OfModule]] =
- mutable.LinkedHashMap(childInstances.map { case (k, v) =>
- val moduleMap: mutable.LinkedHashMap[Instance, OfModule] = mutable.LinkedHashMap(v.map(_.toTokens):_*)
- TargetToken.OfModule(k) -> moduleMap
- }:_*)
+ mutable.LinkedHashMap(childInstances.map {
+ case (k, v) =>
+ val moduleMap: mutable.LinkedHashMap[Instance, OfModule] = mutable.LinkedHashMap(v.map(_.toTokens): _*)
+ TargetToken.OfModule(k) -> moduleMap
+ }: _*)
/** All modules in the circuit reachable from the top module */
def reachableModules: Seq[OfModule] = cachedReachableModules
@@ -110,7 +111,6 @@ class InstanceKeyGraph private(c: ir.Circuit) {
def fullHierarchy: mutable.LinkedHashMap[InstanceKey, Seq[Seq[InstanceKey]]] = cachedFullHierarchy
}
-
object InstanceKeyGraph {
def apply(c: ir.Circuit): InstanceKeyGraph = new InstanceKeyGraph(c)
@@ -126,12 +126,12 @@ object InstanceKeyGraph {
/** Finds all instance definitions in a firrtl Module. */
def collectInstances(m: ir.DefModule): Seq[InstanceKey] = m match {
- case _ : ir.ExtModule => Seq()
+ case _: ir.ExtModule => Seq()
case ir.Module(_, _, _, body) => {
val instances = mutable.ArrayBuffer[InstanceKey]()
def onStmt(s: ir.Statement): Unit = s match {
case firrtl.WDefInstance(_, name, module, _) => instances += InstanceKey(name, module)
- case ir.DefInstance(_, name, module, _) => instances += InstanceKey(name, module)
+ case ir.DefInstance(_, name, module, _) => instances += InstanceKey(name, module)
case _: firrtl.WDefInstanceConnector =>
firrtl.Utils.throwInternalError("Expecting WDefInstance, found a WDefInstanceConnector!")
case other => other.foreachStmt(onStmt)
@@ -143,8 +143,10 @@ object InstanceKeyGraph {
private def topKey(module: String): InstanceKey = InstanceKey(module, module)
- private def buildGraph(childInstances: Seq[(String, Seq[InstanceKey])], roots: Iterable[String]):
- DiGraph[InstanceKey] = {
+ private def buildGraph(
+ childInstances: Seq[(String, Seq[InstanceKey])],
+ roots: Iterable[String]
+ ): DiGraph[InstanceKey] = {
val instanceGraph = new MutableDiGraph[InstanceKey]
val childInstanceMap = childInstances.toMap
diff --git a/src/main/scala/firrtl/analyses/NodeCount.scala b/src/main/scala/firrtl/analyses/NodeCount.scala
index 0276f4f5..63571503 100644
--- a/src/main/scala/firrtl/analyses/NodeCount.scala
+++ b/src/main/scala/firrtl/analyses/NodeCount.scala
@@ -21,18 +21,19 @@ class NodeCount private (node: FirrtlNode) {
@tailrec
private final def rec(xs: List[Any]): Unit =
- if (xs.isEmpty) { }
- else {
+ if (xs.isEmpty) {} else {
val node = xs.head
- require(node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode],
- "Unexpected FirrtlNode that does not implement Product!")
+ require(
+ node.isInstanceOf[Product] || !node.isInstanceOf[FirrtlNode],
+ "Unexpected FirrtlNode that does not implement Product!"
+ )
val moreToVisit =
if (identityMap.containsKey(node)) List.empty
else { // Haven't seen yet
identityMap.put(node, true)
regularSet += node
node match { // FirrtlNodes are Products
- case p: Product => p.productIterator
+ case p: Product => p.productIterator
case i: Iterable[Any] => i
case _ => List.empty
}
diff --git a/src/main/scala/firrtl/analyses/SymbolTable.scala b/src/main/scala/firrtl/analyses/SymbolTable.scala
index 53ad1614..36549160 100644
--- a/src/main/scala/firrtl/analyses/SymbolTable.scala
+++ b/src/main/scala/firrtl/analyses/SymbolTable.scala
@@ -17,26 +17,27 @@ import scala.collection.mutable
* Different implementations of SymbolTable might want to store different
* information (e.g., only the names without the types) or build
* different indices depending on what information the transform needs.
- * */
+ */
trait SymbolTable {
// methods that need to be implemented by any Symbol table
- def declare(name: String, tpe: Type, kind: Kind): Unit
+ def declare(name: String, tpe: Type, kind: Kind): Unit
def declareInstance(name: String, module: String): Unit
// convenience methods
def declare(d: DefInstance): Unit = declareInstance(d.name, d.module)
- def declare(d: DefMemory): Unit = declare(d.name, MemPortUtils.memType(d), firrtl.MemKind)
- def declare(d: DefNode): Unit = declare(d.name, d.value.tpe, firrtl.NodeKind)
- def declare(d: DefWire): Unit = declare(d.name, d.tpe, firrtl.WireKind)
+ def declare(d: DefMemory): Unit = declare(d.name, MemPortUtils.memType(d), firrtl.MemKind)
+ def declare(d: DefNode): Unit = declare(d.name, d.value.tpe, firrtl.NodeKind)
+ def declare(d: DefWire): Unit = declare(d.name, d.tpe, firrtl.WireKind)
def declare(d: DefRegister): Unit = declare(d.name, d.tpe, firrtl.RegKind)
- def declare(d: Port): Unit = declare(d.name, d.tpe, firrtl.PortKind)
+ def declare(d: Port): Unit = declare(d.name, d.tpe, firrtl.PortKind)
}
/** Trusts the type annotation on DefInstance nodes instead of re-deriving the type from
- * the module ports which would require global (cross-module) information. */
+ * the module ports which would require global (cross-module) information.
+ */
private[firrtl] abstract class LocalSymbolTable extends SymbolTable {
def declareInstance(name: String, module: String): Unit = declare(name, UnknownType, InstanceKind)
- override def declare(d: WDefInstance): Unit = declare(d.name, d.tpe, InstanceKind)
+ override def declare(d: WDefInstance): Unit = declare(d.name, d.tpe, InstanceKind)
}
/** Uses a function to derive instance types from module names */
@@ -63,10 +64,10 @@ private[firrtl] trait WithMap extends SymbolTable {
}
private case class Sym(name: String, tpe: Type, kind: Kind) extends Symbol
-private[firrtl] trait Symbol { def name: String; def tpe: Type; def kind: Kind }
+private[firrtl] trait Symbol { def name: String; def tpe: Type; def kind: Kind }
/** only remembers the names of symbols */
-private[firrtl] class NamespaceTable extends LocalSymbolTable {
+private[firrtl] class NamespaceTable extends LocalSymbolTable {
private var names = List[String]()
override def declare(name: String, tpe: Type, kind: Kind): Unit = names = name :: names
def getNames: Seq[String] = names
@@ -82,9 +83,9 @@ object SymbolTable {
}
private def scanStatement(s: Statement)(implicit table: SymbolTable): Unit = s match {
case d: DefInstance => table.declare(d)
- case d: DefMemory => table.declare(d)
- case d: DefNode => table.declare(d)
- case d: DefWire => table.declare(d)
+ case d: DefMemory => table.declare(d)
+ case d: DefNode => table.declare(d)
+ case d: DefWire => table.declare(d)
case d: DefRegister => table.declare(d)
case other => other.foreachStmt(scanStatement)
}
diff --git a/src/main/scala/firrtl/annotations/Annotation.scala b/src/main/scala/firrtl/annotations/Annotation.scala
index a382f685..16f85e67 100644
--- a/src/main/scala/firrtl/annotations/Annotation.scala
+++ b/src/main/scala/firrtl/annotations/Annotation.scala
@@ -5,7 +5,6 @@ package annotations
import firrtl.options.StageUtils
-
case class AnnotationException(message: String) extends Exception(message)
/** Base type of auxiliary information */
@@ -26,8 +25,8 @@ trait Annotation extends Product {
*/
private def extractComponents(ls: scala.collection.Traversable[_]): Seq[Target] = {
ls.collect {
- case c: Target => Seq(c)
- case o: Product => extractComponents(o.productIterator.toIterable)
+ case c: Target => Seq(c)
+ case o: Product => extractComponents(o.productIterator.toIterable)
case x: scala.collection.Traversable[_] => extractComponents(x)
}.foldRight(Seq.empty[Target])((seq, c) => c ++ seq)
}
@@ -62,52 +61,54 @@ trait SingleTargetAnnotation[T <: Named] extends Annotation {
x.map(newTargets => newTargets.map(t => duplicate(t.asInstanceOf[T]))).getOrElse(List(this))
case from: Named =>
val ret = renames.get(Target.convertNamed2Target(target))
- ret.map(_.map { newT =>
- val result = newT match {
- case c: InstanceTarget => ModuleName(c.ofModule, CircuitName(c.circuit))
- case c: IsMember =>
- val local = Target.referringModule(c)
- c.setPathTarget(local)
- case c: CircuitTarget => c.toNamed
- case other => throw Target.NamedException(s"Cannot convert $other to [[Named]]")
- }
- Target.convertTarget2Named(result) match {
- case newTarget: T @unchecked =>
- try {
- duplicate(newTarget)
- }
- catch {
- case _: java.lang.ClassCastException =>
- val msg = s"${this.getClass.getName} target ${target.getClass.getName} " +
- s"cannot be renamed to ${newTarget.getClass}"
- throw AnnotationException(msg)
- }
- }
- }).getOrElse(List(this))
+ ret
+ .map(_.map { newT =>
+ val result = newT match {
+ case c: InstanceTarget => ModuleName(c.ofModule, CircuitName(c.circuit))
+ case c: IsMember =>
+ val local = Target.referringModule(c)
+ c.setPathTarget(local)
+ case c: CircuitTarget => c.toNamed
+ case other => throw Target.NamedException(s"Cannot convert $other to [[Named]]")
+ }
+ Target.convertTarget2Named(result) match {
+ case newTarget: T @unchecked =>
+ try {
+ duplicate(newTarget)
+ } catch {
+ case _: java.lang.ClassCastException =>
+ val msg = s"${this.getClass.getName} target ${target.getClass.getName} " +
+ s"cannot be renamed to ${newTarget.getClass}"
+ throw AnnotationException(msg)
+ }
+ }
+ })
+ .getOrElse(List(this))
}
}
}
/** [[MultiTargetAnnotation]] keeps the renamed targets grouped within a single annotation. */
trait MultiTargetAnnotation extends Annotation {
+
/** Contains a sequence of [[firrtl.annotations.Target Target]].
* When created, [[targets]] should be assigned by `Seq(Seq(TargetA), Seq(TargetB), Seq(TargetC))`
*/
val targets: Seq[Seq[Target]]
- /** Create another instance of this Annotation*/
+ /** Create another instance of this Annotation */
def duplicate(n: Seq[Seq[Target]]): Annotation
/** Assume [[RenameMap]] is `Map(TargetA -> Seq(TargetA1, TargetA2, TargetA3), TargetB -> Seq(TargetB1, TargetB2))`
* in the update, this Annotation is still one annotation, but the contents are renamed in the below form
* Seq(Seq(TargetA1, TargetA2, TargetA3), Seq(TargetB1, TargetB2), Seq(TargetC))
- **/
+ */
def update(renames: RenameMap): Seq[Annotation] = Seq(duplicate(targets.map(ts => ts.flatMap(renames(_)))))
private def crossJoin[T](list: Seq[Seq[T]]): Seq[Seq[T]] =
list match {
- case Nil => Nil
- case x :: Nil => x map (Seq(_))
+ case Nil => Nil
+ case x :: Nil => x.map(Seq(_))
case x :: xs =>
val xsJoin = crossJoin(xs)
for {
@@ -123,7 +124,7 @@ trait MultiTargetAnnotation extends Annotation {
* Seq(Seq(TargetA1), Seq(TargetB1), Seq(TargetC)); Seq(Seq(TargetA1), Seq(TargetB2), Seq(TargetC))
* Seq(Seq(TargetA2), Seq(TargetB1), Seq(TargetC)); Seq(Seq(TargetA2), Seq(TargetB2), Seq(TargetC))
* Seq(Seq(TargetA3), Seq(TargetB1), Seq(TargetC)); Seq(Seq(TargetA3), Seq(TargetB2), Seq(TargetC))
- * */
+ */
def flat(): AnnotationSeq = crossJoin(targets).map(r => duplicate(r.map(Seq(_))))
}
diff --git a/src/main/scala/firrtl/annotations/AnnotationUtils.scala b/src/main/scala/firrtl/annotations/AnnotationUtils.scala
index 58cc0097..a1276e0e 100644
--- a/src/main/scala/firrtl/annotations/AnnotationUtils.scala
+++ b/src/main/scala/firrtl/annotations/AnnotationUtils.scala
@@ -8,14 +8,16 @@ import java.io.File
import firrtl.ir._
case class InvalidAnnotationFileException(file: File, cause: FirrtlUserException = null)
- extends FirrtlUserException(s"$file", cause)
+ extends FirrtlUserException(s"$file", cause)
case class InvalidAnnotationJSONException(msg: String) extends FirrtlUserException(msg)
-case class AnnotationFileNotFoundException(file: File) extends FirrtlUserException(
- s"Annotation file $file not found!"
-)
-case class AnnotationClassNotFoundException(className: String) extends FirrtlUserException(
- s"Annotation class $className not found! Please check spelling and classpath"
-)
+case class AnnotationFileNotFoundException(file: File)
+ extends FirrtlUserException(
+ s"Annotation file $file not found!"
+ )
+case class AnnotationClassNotFoundException(className: String)
+ extends FirrtlUserException(
+ s"Annotation class $className not found! Please check spelling and classpath"
+ )
object AnnotationUtils {
@@ -23,33 +25,33 @@ object AnnotationUtils {
val SerializedModuleName = """([a-zA-Z_][a-zA-Z_0-9~!@#$%^*\-+=?/]*)""".r
def validModuleName(s: String): Boolean = s match {
case SerializedModuleName(name) => true
- case _ => false
+ case _ => false
}
/** Returns true if a valid component/subcomponent name */
val SerializedComponentName = """([a-zA-Z_][a-zA-Z_0-9\[\]\.~!@#$%^*\-+=?/]*)""".r
def validComponentName(s: String): Boolean = s match {
case SerializedComponentName(name) => true
- case _ => false
+ case _ => false
}
/** Tokenizes a string with '[', ']', '.' as tokens, e.g.:
- * "foo.bar[boo.far]" becomes Seq("foo" "." "bar" "[" "boo" "." "far" "]")
- */
+ * "foo.bar[boo.far]" becomes Seq("foo" "." "bar" "[" "boo" "." "far" "]")
+ */
def tokenize(s: String): Seq[String] = s.find(c => "[].".contains(c)) match {
case Some(_) =>
val i = s.indexWhere(c => "[].".contains(c))
s.slice(0, i) match {
case "" => s(i).toString +: tokenize(s.drop(i + 1))
- case x => x +: s(i).toString +: tokenize(s.drop(i + 1))
+ case x => x +: s(i).toString +: tokenize(s.drop(i + 1))
}
case None if s == "" => Nil
- case None => Seq(s)
+ case None => Seq(s)
}
def toNamed(s: String): Named = s.split("\\.", 3) match {
- case Array(n) => CircuitName(n)
- case Array(c, m) => ModuleName(m, CircuitName(c))
+ case Array(n) => CircuitName(n)
+ case Array(c, m) => ModuleName(m, CircuitName(c))
case Array(c, m, x) => ComponentName(x, ModuleName(m, CircuitName(c)))
}
@@ -60,38 +62,39 @@ object AnnotationUtils {
def toSubComponents(s: String): Seq[TargetToken] = {
import TargetToken._
def exp2subcomp(e: ir.Expression): Seq[TargetToken] = e match {
- case ir.Reference(name, _, _, _) => Seq(Ref(name))
+ case ir.Reference(name, _, _, _) => Seq(Ref(name))
case ir.SubField(expr, name, _, _) => exp2subcomp(expr) :+ Field(name)
case ir.SubIndex(expr, idx, _, _) => exp2subcomp(expr) :+ Index(idx)
- case ir.SubAccess(expr, idx, _, _) => Utils.throwInternalError(s"For string $s, cannot convert a subaccess $e into a Target")
+ case ir.SubAccess(expr, idx, _, _) =>
+ Utils.throwInternalError(s"For string $s, cannot convert a subaccess $e into a Target")
}
exp2subcomp(toExp(s))
}
-
/** Given a serialized component/subcomponent reference, subindex, subaccess,
- * or subfield, return the corresponding IR expression.
- * E.g. "foo.bar" becomes SubField(Reference("foo", UnknownType), "bar", UnknownType)
- */
+ * or subfield, return the corresponding IR expression.
+ * E.g. "foo.bar" becomes SubField(Reference("foo", UnknownType), "bar", UnknownType)
+ */
def toExp(s: String): Expression = {
def parse(tokens: Seq[String]): Expression = {
val DecPattern = """(\d+)""".r
def findClose(tokens: Seq[String], index: Int, nOpen: Int): Seq[String] = {
- if(index >= tokens.size) {
+ if (index >= tokens.size) {
Utils.error("Cannot find closing bracket ]")
- } else tokens(index) match {
- case "[" => findClose(tokens, index + 1, nOpen + 1)
- case "]" if nOpen == 1 => tokens.slice(1, index)
- case "]" => findClose(tokens, index + 1, nOpen - 1)
- case _ => findClose(tokens, index + 1, nOpen)
- }
+ } else
+ tokens(index) match {
+ case "[" => findClose(tokens, index + 1, nOpen + 1)
+ case "]" if nOpen == 1 => tokens.slice(1, index)
+ case "]" => findClose(tokens, index + 1, nOpen - 1)
+ case _ => findClose(tokens, index + 1, nOpen)
+ }
}
def buildup(e: Expression, tokens: Seq[String]): Expression = tokens match {
case "[" :: tail =>
val indexOrAccess = findClose(tokens, 0, 0)
val exp = indexOrAccess.head match {
case DecPattern(d) => SubIndex(e, d.toInt, UnknownType)
- case _ => SubAccess(e, parse(indexOrAccess), UnknownType)
+ case _ => SubAccess(e, parse(indexOrAccess), UnknownType)
}
buildup(exp, tokens.drop(2 + indexOrAccess.size))
case "." :: tail =>
@@ -101,7 +104,7 @@ object AnnotationUtils {
val root = Reference(tokens.head, UnknownType)
buildup(root, tokens.tail)
}
- if(validComponentName(s)) {
+ if (validComponentName(s)) {
parse(tokenize(s))
} else {
Utils.error(s"Cannot convert $s into an expression.")
diff --git a/src/main/scala/firrtl/annotations/JsonProtocol.scala b/src/main/scala/firrtl/annotations/JsonProtocol.scala
index 941bf003..0ef8b020 100644
--- a/src/main/scala/firrtl/annotations/JsonProtocol.scala
+++ b/src/main/scala/firrtl/annotations/JsonProtocol.scala
@@ -5,7 +5,7 @@ package annotations
import firrtl.ir._
-import scala.util.{Try, Failure}
+import scala.util.{Failure, Try}
import org.json4s._
import org.json4s.native.JsonMethods._
@@ -20,112 +20,189 @@ trait HasSerializationHints {
}
object JsonProtocol {
- class TransformClassSerializer extends CustomSerializer[Class[_ <: Transform]](format => (
- { case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] },
- { case x: Class[_] => JString(x.getName) }
- ))
+ class TransformClassSerializer
+ extends CustomSerializer[Class[_ <: Transform]](format =>
+ (
+ { case JString(s) => Class.forName(s).asInstanceOf[Class[_ <: Transform]] },
+ { case x: Class[_] => JString(x.getName) }
+ )
+ )
// TODO Reduce boilerplate?
- class NamedSerializer extends CustomSerializer[Named](format => (
- { case JString(s) => AnnotationUtils.toNamed(s) },
- { case named: Named => JString(named.serialize) }
- ))
- class CircuitNameSerializer extends CustomSerializer[CircuitName](format => (
- { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[CircuitName] },
- { case named: CircuitName => JString(named.serialize) }
- ))
- class ModuleNameSerializer extends CustomSerializer[ModuleName](format => (
- { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ModuleName] },
- { case named: ModuleName => JString(named.serialize) }
- ))
- class ComponentNameSerializer extends CustomSerializer[ComponentName](format => (
- { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ComponentName] },
- { case named: ComponentName => JString(named.serialize) }
- ))
- class TransformSerializer extends CustomSerializer[Transform](format => (
- { case JString(s) =>
- try {
- Class.forName(s).asInstanceOf[Class[_ <: Transform]].newInstance()
- } catch {
- case e: java.lang.InstantiationException => throw new FirrtlInternalException(
- "NoSuchMethodException during construction of serialized Transform. Is your Transform an inner class?", e)
- case t: Throwable => throw t
- }},
- { case x: Transform => JString(x.getClass.getName) }
- ))
- class LoadMemoryFileTypeSerializer extends CustomSerializer[MemoryLoadFileType](format => (
- { case JString(s) => MemoryLoadFileType.deserialize(s) },
- { case named: MemoryLoadFileType => JString(named.serialize) }
- ))
+ class NamedSerializer
+ extends CustomSerializer[Named](format =>
+ (
+ { case JString(s) => AnnotationUtils.toNamed(s) },
+ { case named: Named => JString(named.serialize) }
+ )
+ )
+ class CircuitNameSerializer
+ extends CustomSerializer[CircuitName](format =>
+ (
+ { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[CircuitName] },
+ { case named: CircuitName => JString(named.serialize) }
+ )
+ )
+ class ModuleNameSerializer
+ extends CustomSerializer[ModuleName](format =>
+ (
+ { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ModuleName] },
+ { case named: ModuleName => JString(named.serialize) }
+ )
+ )
+ class ComponentNameSerializer
+ extends CustomSerializer[ComponentName](format =>
+ (
+ { case JString(s) => AnnotationUtils.toNamed(s).asInstanceOf[ComponentName] },
+ { case named: ComponentName => JString(named.serialize) }
+ )
+ )
+ class TransformSerializer
+ extends CustomSerializer[Transform](format =>
+ (
+ {
+ case JString(s) =>
+ try {
+ Class.forName(s).asInstanceOf[Class[_ <: Transform]].newInstance()
+ } catch {
+ case e: java.lang.InstantiationException =>
+ throw new FirrtlInternalException(
+ "NoSuchMethodException during construction of serialized Transform. Is your Transform an inner class?",
+ e
+ )
+ case t: Throwable => throw t
+ }
+ },
+ { case x: Transform => JString(x.getClass.getName) }
+ )
+ )
+ class LoadMemoryFileTypeSerializer
+ extends CustomSerializer[MemoryLoadFileType](format =>
+ (
+ { case JString(s) => MemoryLoadFileType.deserialize(s) },
+ { case named: MemoryLoadFileType => JString(named.serialize) }
+ )
+ )
- class TargetSerializer extends CustomSerializer[Target](format => (
- { case JString(s) => Target.deserialize(s) },
- { case named: Target => JString(named.serialize) }
- ))
- class GenericTargetSerializer extends CustomSerializer[GenericTarget](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[GenericTarget] },
- { case named: GenericTarget => JString(named.serialize) }
- ))
- class CircuitTargetSerializer extends CustomSerializer[CircuitTarget](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[CircuitTarget] },
- { case named: CircuitTarget => JString(named.serialize) }
- ))
- class ModuleTargetSerializer extends CustomSerializer[ModuleTarget](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[ModuleTarget] },
- { case named: ModuleTarget => JString(named.serialize) }
- ))
- class InstanceTargetSerializer extends CustomSerializer[InstanceTarget](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[InstanceTarget] },
- { case named: InstanceTarget => JString(named.serialize) }
- ))
- class ReferenceTargetSerializer extends CustomSerializer[ReferenceTarget](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[ReferenceTarget] },
- { case named: ReferenceTarget => JString(named.serialize) }
- ))
- class IsModuleSerializer extends CustomSerializer[IsModule](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[IsModule] },
- { case named: IsModule => JString(named.serialize) }
- ))
- class IsMemberSerializer extends CustomSerializer[IsMember](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[IsMember] },
- { case named: IsMember => JString(named.serialize) }
- ))
- class CompleteTargetSerializer extends CustomSerializer[CompleteTarget](format => (
- { case JString(s) => Target.deserialize(s).asInstanceOf[CompleteTarget] },
- { case named: CompleteTarget => JString(named.serialize) }
- ))
+ class TargetSerializer
+ extends CustomSerializer[Target](format =>
+ (
+ { case JString(s) => Target.deserialize(s) },
+ { case named: Target => JString(named.serialize) }
+ )
+ )
+ class GenericTargetSerializer
+ extends CustomSerializer[GenericTarget](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[GenericTarget] },
+ { case named: GenericTarget => JString(named.serialize) }
+ )
+ )
+ class CircuitTargetSerializer
+ extends CustomSerializer[CircuitTarget](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[CircuitTarget] },
+ { case named: CircuitTarget => JString(named.serialize) }
+ )
+ )
+ class ModuleTargetSerializer
+ extends CustomSerializer[ModuleTarget](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[ModuleTarget] },
+ { case named: ModuleTarget => JString(named.serialize) }
+ )
+ )
+ class InstanceTargetSerializer
+ extends CustomSerializer[InstanceTarget](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[InstanceTarget] },
+ { case named: InstanceTarget => JString(named.serialize) }
+ )
+ )
+ class ReferenceTargetSerializer
+ extends CustomSerializer[ReferenceTarget](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[ReferenceTarget] },
+ { case named: ReferenceTarget => JString(named.serialize) }
+ )
+ )
+ class IsModuleSerializer
+ extends CustomSerializer[IsModule](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[IsModule] },
+ { case named: IsModule => JString(named.serialize) }
+ )
+ )
+ class IsMemberSerializer
+ extends CustomSerializer[IsMember](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[IsMember] },
+ { case named: IsMember => JString(named.serialize) }
+ )
+ )
+ class CompleteTargetSerializer
+ extends CustomSerializer[CompleteTarget](format =>
+ (
+ { case JString(s) => Target.deserialize(s).asInstanceOf[CompleteTarget] },
+ { case named: CompleteTarget => JString(named.serialize) }
+ )
+ )
// FIRRTL Serializers
- class TypeSerializer extends CustomSerializer[Type](format => (
- { case JString(s) => Parser.parseType(s) },
- { case tpe: Type => JString(tpe.serialize) }
- ))
- class ExpressionSerializer extends CustomSerializer[Expression](format => (
- { case JString(s) => Parser.parseExpression(s) },
- { case expr: Expression => JString(expr.serialize) }
- ))
- class StatementSerializer extends CustomSerializer[Statement](format => (
- { case JString(s) => Parser.parseStatement(s) },
- { case statement: Statement => JString(statement.serialize) }
- ))
- class PortSerializer extends CustomSerializer[Port](format => (
- { case JString(s) => Parser.parsePort(s) },
- { case port: Port => JString(port.serialize) }
- ))
- class DefModuleSerializer extends CustomSerializer[DefModule](format => (
- { case JString(s) => Parser.parseDefModule(s) },
- { case mod: DefModule => JString(mod.serialize) }
- ))
- class CircuitSerializer extends CustomSerializer[Circuit](format => (
- { case JString(s) => Parser.parse(s) },
- { case cir: Circuit => JString(cir.serialize) }
- ))
- class InfoSerializer extends CustomSerializer[Info](format => (
- { case JString(s) => Parser.parseInfo(s) },
- { case info: Info => JString(info.serialize) }
- ))
- class GroundTypeSerializer extends CustomSerializer[GroundType](format => (
- { case JString(s) => Parser.parseType(s).asInstanceOf[GroundType] },
- { case tpe: GroundType => JString(tpe.serialize) }
- ))
+ class TypeSerializer
+ extends CustomSerializer[Type](format =>
+ (
+ { case JString(s) => Parser.parseType(s) },
+ { case tpe: Type => JString(tpe.serialize) }
+ )
+ )
+ class ExpressionSerializer
+ extends CustomSerializer[Expression](format =>
+ (
+ { case JString(s) => Parser.parseExpression(s) },
+ { case expr: Expression => JString(expr.serialize) }
+ )
+ )
+ class StatementSerializer
+ extends CustomSerializer[Statement](format =>
+ (
+ { case JString(s) => Parser.parseStatement(s) },
+ { case statement: Statement => JString(statement.serialize) }
+ )
+ )
+ class PortSerializer
+ extends CustomSerializer[Port](format =>
+ (
+ { case JString(s) => Parser.parsePort(s) },
+ { case port: Port => JString(port.serialize) }
+ )
+ )
+ class DefModuleSerializer
+ extends CustomSerializer[DefModule](format =>
+ (
+ { case JString(s) => Parser.parseDefModule(s) },
+ { case mod: DefModule => JString(mod.serialize) }
+ )
+ )
+ class CircuitSerializer
+ extends CustomSerializer[Circuit](format =>
+ (
+ { case JString(s) => Parser.parse(s) },
+ { case cir: Circuit => JString(cir.serialize) }
+ )
+ )
+ class InfoSerializer
+ extends CustomSerializer[Info](format =>
+ (
+ { case JString(s) => Parser.parseInfo(s) },
+ { case info: Info => JString(info.serialize) }
+ )
+ )
+ class GroundTypeSerializer
+ extends CustomSerializer[GroundType](format =>
+ (
+ { case JString(s) => Parser.parseType(s).asInstanceOf[GroundType] },
+ { case tpe: GroundType => JString(tpe.serialize) }
+ )
+ )
/** Construct Json formatter for annotations */
def jsonFormat(tags: Seq[Class[_]]) = {
@@ -133,7 +210,7 @@ object JsonProtocol {
new TransformClassSerializer + new NamedSerializer + new CircuitNameSerializer +
new ModuleNameSerializer + new ComponentNameSerializer + new TargetSerializer +
new GenericTargetSerializer + new CircuitTargetSerializer + new ModuleTargetSerializer +
- new InstanceTargetSerializer + new ReferenceTargetSerializer + new TransformSerializer +
+ new InstanceTargetSerializer + new ReferenceTargetSerializer + new TransformSerializer +
new LoadMemoryFileTypeSerializer + new IsModuleSerializer + new IsMemberSerializer +
new CompleteTargetSerializer + new TypeSerializer + new ExpressionSerializer +
new StatementSerializer + new PortSerializer + new DefModuleSerializer +
@@ -144,10 +221,12 @@ object JsonProtocol {
def serialize(annos: Seq[Annotation]): String = serializeTry(annos).get
def serializeTry(annos: Seq[Annotation]): Try[String] = {
- val tags = annos.flatMap({
- case anno: HasSerializationHints => anno.getClass +: anno.typeHints
- case anno => Seq(anno.getClass)
- }).distinct
+ val tags = annos
+ .flatMap({
+ case anno: HasSerializationHints => anno.getClass +: anno.typeHints
+ case anno => Seq(anno.getClass)
+ })
+ .distinct
implicit val formats = jsonFormat(tags)
Try(writePretty(annos))
@@ -159,20 +238,25 @@ object JsonProtocol {
val parsed = parse(in)
val annos = parsed match {
case JArray(objs) => objs
- case x => throw new InvalidAnnotationJSONException(
- s"Annotations must be serialized as a JArray, got ${x.getClass.getName} instead!")
+ case x =>
+ throw new InvalidAnnotationJSONException(
+ s"Annotations must be serialized as a JArray, got ${x.getClass.getName} instead!"
+ )
}
// Recursively gather typeHints by pulling the "class" field from JObjects
// Json4s should emit this as the first field in all serialized classes
// Setting requireClassField mandates that all JObjects must provide a typeHint,
// this used on the first invocation to check all annotations do so
- def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst.flatMap({
- case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2))
- case obj: JObject if requireClassField => throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
- case JObject(fields) => findTypeHints(fields.map(_._2))
- case JArray(arr) => findTypeHints(arr)
- case oJValue => Seq()
- }).distinct
+ def findTypeHints(classInst: Seq[JValue], requireClassField: Boolean = false): Seq[String] = classInst
+ .flatMap({
+ case JObject(("class", JString(name)) :: fields) => name +: findTypeHints(fields.map(_._2))
+ case obj: JObject if requireClassField =>
+ throw new InvalidAnnotationJSONException(s"Expected field 'class' not found! $obj")
+ case JObject(fields) => findTypeHints(fields.map(_._2))
+ case JArray(arr) => findTypeHints(arr)
+ case oJValue => Seq()
+ })
+ .distinct
val classes = findTypeHints(annos, true)
val loaded = classes.map(Class.forName(_))
@@ -186,10 +270,11 @@ object JsonProtocol {
case e @ (_: org.json4s.ParserUtil.ParseException | _: org.json4s.MappingException) =>
Failure(new InvalidAnnotationJSONException(e.getMessage))
}.recoverWith { // If the input is a file, wrap in InvalidAnnotationFileException
- case e: FirrtlUserException => in match {
- case FileInput(file) =>
- Failure(new InvalidAnnotationFileException(file, e))
- case _ => Failure(e)
- }
+ case e: FirrtlUserException =>
+ in match {
+ case FileInput(file) =>
+ Failure(new InvalidAnnotationFileException(file, e))
+ case _ => Failure(e)
+ }
}
}
diff --git a/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala b/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala
index 64c30bdb..043c1b3b 100644
--- a/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala
+++ b/src/main/scala/firrtl/annotations/LoadMemoryAnnotation.scala
@@ -21,7 +21,7 @@ object MemoryLoadFileType {
def deserialize(s: String): MemoryLoadFileType = s match {
case "h" => MemoryLoadFileType.Hex
case "b" => MemoryLoadFileType.Binary
- case _ => throw new FirrtlUserException(s"Unrecognized MemoryLoadFileType: $s")
+ case _ => throw new FirrtlUserException(s"Unrecognized MemoryLoadFileType: $s")
}
}
@@ -31,11 +31,11 @@ object MemoryLoadFileType {
* @param hexOrBinary use `\$readmemh` or `\$readmemb`
*/
case class LoadMemoryAnnotation(
- target: ComponentName,
- fileName: String,
- hexOrBinary: MemoryLoadFileType = MemoryLoadFileType.Hex,
- originalMemoryNameOpt: Option[String] = None
-) extends SingleTargetAnnotation[Named] {
+ target: ComponentName,
+ fileName: String,
+ hexOrBinary: MemoryLoadFileType = MemoryLoadFileType.Hex,
+ originalMemoryNameOpt: Option[String] = None)
+ extends SingleTargetAnnotation[Named] {
val (prefix, suffix) = {
fileName.split("""\.""").toList match {
@@ -57,7 +57,7 @@ case class LoadMemoryAnnotation(
def getPrefix: String =
prefix + originalMemoryNameOpt.map(n => target.name.drop(n.length)).getOrElse("")
- def getSuffix: String = suffix
+ def getSuffix: String = suffix
def getFileName: String = getPrefix + getSuffix
def duplicate(newNamed: Named): LoadMemoryAnnotation = {
diff --git a/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala b/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala
index 44a8e3b5..7cefdef8 100644
--- a/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala
+++ b/src/main/scala/firrtl/annotations/MemoryInitAnnotation.scala
@@ -5,10 +5,10 @@ package firrtl.annotations
import firrtl.{MemoryArrayInit, MemoryEmissionOption, MemoryInitValue, MemoryRandomInit, MemoryScalarInit}
/**
- * Represents the initial value of the annotated memory.
- * While not supported on normal ASIC flows, it can be useful for simulation and FPGA flows.
- * This annotation is consumed by the verilog emitter.
- */
+ * Represents the initial value of the annotated memory.
+ * While not supported on normal ASIC flows, it can be useful for simulation and FPGA flows.
+ * This annotation is consumed by the verilog emitter.
+ */
sealed trait MemoryInitAnnotation extends SingleTargetAnnotation[ReferenceTarget] with MemoryEmissionOption {
def isRandomInit: Boolean
}
@@ -16,20 +16,20 @@ sealed trait MemoryInitAnnotation extends SingleTargetAnnotation[ReferenceTarget
/** Randomly initialize the `target` memory. This is the same as the default behavior. */
case class MemoryRandomInitAnnotation(target: ReferenceTarget) extends MemoryInitAnnotation {
override def duplicate(n: ReferenceTarget): Annotation = copy(n)
- override def initValue: MemoryInitValue = MemoryRandomInit
+ override def initValue: MemoryInitValue = MemoryRandomInit
override def isRandomInit: Boolean = true
}
/** Initialize all entries of the `target` memory with the scalar `value`. */
case class MemoryScalarInitAnnotation(target: ReferenceTarget, value: BigInt) extends MemoryInitAnnotation {
override def duplicate(n: ReferenceTarget): Annotation = copy(n)
- override def initValue: MemoryInitValue = MemoryScalarInit(value)
- override def isRandomInit: Boolean = false
+ override def initValue: MemoryInitValue = MemoryScalarInit(value)
+ override def isRandomInit: Boolean = false
}
/** Initialize the `target` memory with the array of `values` which must be the same size as the memory depth. */
case class MemoryArrayInitAnnotation(target: ReferenceTarget, values: Seq[BigInt]) extends MemoryInitAnnotation {
override def duplicate(n: ReferenceTarget): Annotation = copy(n)
- override def initValue: MemoryInitValue = MemoryArrayInit(values)
- override def isRandomInit: Boolean = false
-} \ No newline at end of file
+ override def initValue: MemoryInitValue = MemoryArrayInit(values)
+ override def isRandomInit: Boolean = false
+}
diff --git a/src/main/scala/firrtl/annotations/PresetAnnotations.scala b/src/main/scala/firrtl/annotations/PresetAnnotations.scala
index 727417c1..d6066aa7 100644
--- a/src/main/scala/firrtl/annotations/PresetAnnotations.scala
+++ b/src/main/scala/firrtl/annotations/PresetAnnotations.scala
@@ -10,11 +10,11 @@ package annotations
* @param target ReferenceTarget to an AsyncReset
*/
case class PresetAnnotation(target: ReferenceTarget)
- extends SingleTargetAnnotation[ReferenceTarget] with firrtl.transforms.DontTouchAllTargets {
+ extends SingleTargetAnnotation[ReferenceTarget]
+ with firrtl.transforms.DontTouchAllTargets {
override def duplicate(n: ReferenceTarget) = this.copy(target = n)
}
-
/**
* Transform the targeted asynchronously-reset Reg into a bitstream preset Reg
* Used internally to annotate all registers associated to an AsyncReset tree
@@ -22,12 +22,10 @@ case class PresetAnnotation(target: ReferenceTarget)
* @param target ReferenceTarget to a Reg
*/
private[firrtl] case class PresetRegAnnotation(
- target: ReferenceTarget
-) extends SingleTargetAnnotation[ReferenceTarget] with RegisterEmissionOption {
+ target: ReferenceTarget)
+ extends SingleTargetAnnotation[ReferenceTarget]
+ with RegisterEmissionOption {
def duplicate(n: ReferenceTarget) = this.copy(target = n)
override def useInitAsPreset = true
override def disableRandomization = true
}
-
-
-
diff --git a/src/main/scala/firrtl/annotations/Target.scala b/src/main/scala/firrtl/annotations/Target.scala
index 4d1cdc2f..afde84dc 100644
--- a/src/main/scala/firrtl/annotations/Target.scala
+++ b/src/main/scala/firrtl/annotations/Target.scala
@@ -4,7 +4,7 @@ package firrtl
package annotations
import firrtl.ir.{Field => _, _}
-import firrtl.Utils.{sub_type, field_type}
+import firrtl.Utils.{field_type, sub_type}
import AnnotationUtils.{toExp, validComponentName, validModuleName}
import TargetToken._
@@ -29,27 +29,29 @@ sealed trait Target extends Named {
def tokens: Seq[TargetToken]
/** @return Returns a new [[GenericTarget]] with new values */
- def modify(circuitOpt: Option[String] = circuitOpt,
- moduleOpt: Option[String] = moduleOpt,
- tokens: Seq[TargetToken] = tokens): GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector)
+ def modify(
+ circuitOpt: Option[String] = circuitOpt,
+ moduleOpt: Option[String] = moduleOpt,
+ tokens: Seq[TargetToken] = tokens
+ ): GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector)
/** @return Human-readable serialization */
def serialize: String = {
val circuitString = "~" + circuitOpt.getOrElse("???")
val moduleString = "|" + moduleOpt.getOrElse("???")
val tokensString = tokens.map {
- case Ref(r) => s">$r"
- case Instance(i) => s"/$i"
- case OfModule(o) => s":$o"
+ case Ref(r) => s">$r"
+ case Instance(i) => s"/$i"
+ case OfModule(o) => s":$o"
case TargetToken.Field(f) => s".$f"
- case Index(v) => s"[$v]"
- case Clock => s"@clock"
- case Reset => s"@reset"
- case Init => s"@init"
+ case Index(v) => s"[$v]"
+ case Clock => s"@clock"
+ case Reset => s"@reset"
+ case Init => s"@init"
}.mkString("")
- if(moduleOpt.isEmpty && tokens.isEmpty) {
+ if (moduleOpt.isEmpty && tokens.isEmpty) {
circuitString
- } else if(tokens.isEmpty) {
+ } else if (tokens.isEmpty) {
circuitString + moduleString
} else {
circuitString + moduleString + tokensString
@@ -64,24 +66,23 @@ sealed trait Target extends Named {
val moduleString = s"""\n$tab└── module ${moduleOpt.getOrElse("???")}:"""
var depth = 4
val tokenString = tokens.map {
- case Ref(r) => val rx = s"""\n$tab${" "*depth}└── $r"""; depth += 4; rx
- case Instance(i) => val ix = s"""\n$tab${" "*depth}└── inst $i """; ix
+ case Ref(r) => val rx = s"""\n$tab${" " * depth}└── $r"""; depth += 4; rx
+ case Instance(i) => val ix = s"""\n$tab${" " * depth}└── inst $i """; ix
case OfModule(o) => val ox = s"of $o:"; depth += 4; ox
- case Field(f) => s".$f"
- case Index(v) => s"[$v]"
- case Clock => s"@clock"
- case Reset => s"@reset"
- case Init => s"@init"
+ case Field(f) => s".$f"
+ case Index(v) => s"[$v]"
+ case Clock => s"@clock"
+ case Reset => s"@reset"
+ case Init => s"@init"
}.mkString("")
(moduleOpt.isEmpty, tokens.isEmpty) match {
case (true, true) => circuitString
- case (_, true) => circuitString + moduleString
- case (_, _) => circuitString + moduleString + tokenString
+ case (_, true) => circuitString + moduleString
+ case (_, _) => circuitString + moduleString + tokenString
}
}
-
/** @return Converts this [[Target]] into a [[GenericTarget]] */
def toGenericTarget: GenericTarget = GenericTarget(circuitOpt, moduleOpt, tokens.toVector)
@@ -113,13 +114,13 @@ sealed trait Target extends Named {
object Target {
def asTarget(m: ModuleTarget)(e: Expression): ReferenceTarget = e match {
case r: ir.Reference => m.ref(r.name)
- case s: ir.SubIndex => asTarget(m)(s.expr).index(s.value)
- case s: ir.SubField => asTarget(m)(s.expr).field(s.name)
+ case s: ir.SubIndex => asTarget(m)(s.expr).index(s.value)
+ case s: ir.SubField => asTarget(m)(s.expr).field(s.name)
case s: ir.SubAccess => asTarget(m)(s.expr).field("@" + s.index.serialize)
- case d: DoPrim => m.ref("@" + d.serialize)
- case d: Mux => m.ref("@" + d.serialize)
- case d: ValidIf => m.ref("@" + d.serialize)
- case d: Literal => m.ref("@" + d.serialize)
+ case d: DoPrim => m.ref("@" + d.serialize)
+ case d: Mux => m.ref("@" + d.serialize)
+ case d: ValidIf => m.ref("@" + d.serialize)
+ case d: Literal => m.ref("@" + d.serialize)
case other => sys.error(s"Unsupported: $other")
}
@@ -131,14 +132,14 @@ object Target {
case class NamedException(message: String) extends Exception(message)
- implicit def convertCircuitTarget2CircuitName(c: CircuitTarget): CircuitName = c.toNamed
- implicit def convertModuleTarget2ModuleName(c: ModuleTarget): ModuleName = c.toNamed
- implicit def convertIsComponent2ComponentName(c: IsComponent): ComponentName = c.toNamed
- implicit def convertTarget2Named(c: Target): Named = c.toNamed
- implicit def convertCircuitName2CircuitTarget(c: CircuitName): CircuitTarget = c.toTarget
- implicit def convertModuleName2ModuleTarget(c: ModuleName): ModuleTarget = c.toTarget
+ implicit def convertCircuitTarget2CircuitName(c: CircuitTarget): CircuitName = c.toNamed
+ implicit def convertModuleTarget2ModuleName(c: ModuleTarget): ModuleName = c.toNamed
+ implicit def convertIsComponent2ComponentName(c: IsComponent): ComponentName = c.toNamed
+ implicit def convertTarget2Named(c: Target): Named = c.toNamed
+ implicit def convertCircuitName2CircuitTarget(c: CircuitName): CircuitTarget = c.toTarget
+ implicit def convertModuleName2ModuleTarget(c: ModuleName): ModuleTarget = c.toTarget
implicit def convertComponentName2ReferenceTarget(c: ComponentName): ReferenceTarget = c.toTarget
- implicit def convertNamed2Target(n: Named): CompleteTarget = n.toTarget
+ implicit def convertNamed2Target(n: Named): CompleteTarget = n.toTarget
/** Converts [[ComponentName]]'s name into TargetTokens
* @param name
@@ -148,7 +149,7 @@ object Target {
val tokens = AnnotationUtils.tokenize(name)
val subComps = mutable.ArrayBuffer[TargetToken]()
subComps += Ref(tokens.head)
- if(tokens.tail.nonEmpty) {
+ if (tokens.tail.nonEmpty) {
tokens.tail.zip(tokens.tail.tail).foreach {
case (".", value: String) => subComps += Field(value)
case ("[", value: String) => subComps += Index(value.toInt)
@@ -163,31 +164,33 @@ object Target {
* @param keywords
* @return
*/
- def isOnly(seq: Seq[TargetToken], keywords:String*): Boolean = {
- seq.map(_.is(keywords:_*)).foldLeft(false)(_ || _) && keywords.nonEmpty
+ def isOnly(seq: Seq[TargetToken], keywords: String*): Boolean = {
+ seq.map(_.is(keywords: _*)).foldLeft(false)(_ || _) && keywords.nonEmpty
}
/** @return [[Target]] from human-readable serialization */
def deserialize(s: String): Target = {
val regex = """(?=[~|>/:.\[@])"""
- s.split(regex).foldLeft(GenericTarget(None, None, Vector.empty)) { (t, tokenString) =>
- val value = tokenString.tail
- tokenString(0) match {
- case '~' if t.circuitOpt.isEmpty && t.moduleOpt.isEmpty && t.tokens.isEmpty =>
- if(value == "???") t else t.copy(circuitOpt = Some(value))
- case '|' if t.moduleOpt.isEmpty && t.tokens.isEmpty =>
- if(value == "???") t else t.copy(moduleOpt = Some(value))
- case '/' => t.add(Instance(value))
- case ':' => t.add(OfModule(value))
- case '>' => t.add(Ref(value))
- case '.' => t.add(Field(value))
- case '[' if value.dropRight(1).toInt >= 0 => t.add(Index(value.dropRight(1).toInt))
- case '@' if value == "clock" => t.add(Clock)
- case '@' if value == "init" => t.add(Init)
- case '@' if value == "reset" => t.add(Reset)
- case other => throw NamedException(s"Cannot deserialize Target: $s")
+ s.split(regex)
+ .foldLeft(GenericTarget(None, None, Vector.empty)) { (t, tokenString) =>
+ val value = tokenString.tail
+ tokenString(0) match {
+ case '~' if t.circuitOpt.isEmpty && t.moduleOpt.isEmpty && t.tokens.isEmpty =>
+ if (value == "???") t else t.copy(circuitOpt = Some(value))
+ case '|' if t.moduleOpt.isEmpty && t.tokens.isEmpty =>
+ if (value == "???") t else t.copy(moduleOpt = Some(value))
+ case '/' => t.add(Instance(value))
+ case ':' => t.add(OfModule(value))
+ case '>' => t.add(Ref(value))
+ case '.' => t.add(Field(value))
+ case '[' if value.dropRight(1).toInt >= 0 => t.add(Index(value.dropRight(1).toInt))
+ case '@' if value == "clock" => t.add(Clock)
+ case '@' if value == "init" => t.add(Init)
+ case '@' if value == "reset" => t.add(Reset)
+ case other => throw NamedException(s"Cannot deserialize Target: $s")
+ }
}
- }.tryToComplete
+ .tryToComplete
}
/** Returns the module that a [[Target]] "refers" to.
@@ -217,14 +220,16 @@ object Target {
def getReferenceTarget(t: Target): Target = {
(t.toGenericTarget match {
case t: GenericTarget if t.isLegal =>
- val newTokens = t.tokens.reverse.dropWhile({
- case x: Field => true
- case x: Index => true
- case Clock => true
- case Init => true
- case Reset => true
- case other => false
- }).reverse
+ val newTokens = t.tokens.reverse
+ .dropWhile({
+ case x: Field => true
+ case x: Index => true
+ case Clock => true
+ case Init => true
+ case Reset => true
+ case other => false
+ })
+ .reverse
GenericTarget(t.circuitOpt, t.moduleOpt, newTokens)
case other => sys.error(s"Can't make $other pathless!")
}).tryToComplete
@@ -236,9 +241,8 @@ object Target {
* @param moduleOpt Optional module name
* @param tokens [[TargetToken]]s to represent the target in a circuit and module
*/
-case class GenericTarget(circuitOpt: Option[String],
- moduleOpt: Option[String],
- tokens: Vector[TargetToken]) extends Target {
+case class GenericTarget(circuitOpt: Option[String], moduleOpt: Option[String], tokens: Vector[TargetToken])
+ extends Target {
override def toGenericTarget: GenericTarget = this
@@ -252,11 +256,12 @@ case class GenericTarget(circuitOpt: Option[String],
override def toTarget: CompleteTarget = getComplete.get
override def getComplete: Option[CompleteTarget] = {
- if(!isComplete) None else {
+ if (!isComplete) None
+ else {
val target = this match {
- case GenericTarget(Some(c), None, Vector()) => CircuitTarget(c)
- case GenericTarget(Some(c), Some(m), Vector()) => ModuleTarget(c, m)
- case GenericTarget(Some(c), Some(m), Ref(r) +: component) => ReferenceTarget(c, m, Nil, r, component)
+ case GenericTarget(Some(c), None, Vector()) => CircuitTarget(c)
+ case GenericTarget(Some(c), Some(m), Vector()) => ModuleTarget(c, m)
+ case GenericTarget(Some(c), Some(m), Ref(r) +: component) => ReferenceTarget(c, m, Nil, r, component)
case GenericTarget(Some(c), Some(m), Instance(i) +: OfModule(o) +: Vector()) => InstanceTarget(c, m, Nil, i, o)
case GenericTarget(Some(c), Some(m), component) =>
val path = getPath.getOrElse(Nil)
@@ -271,7 +276,7 @@ case class GenericTarget(circuitOpt: Option[String],
override def isLocal: Boolean = !(getPath.nonEmpty && getPath.get.nonEmpty)
- def path: Vector[(Instance, OfModule)] = if(isComplete){
+ def path: Vector[(Instance, OfModule)] = if (isComplete) {
tokens.zip(tokens.tail).collect {
case (i: Instance, o: OfModule) => (i, o)
}
@@ -280,9 +285,9 @@ case class GenericTarget(circuitOpt: Option[String],
/** If complete, return this [[GenericTarget]]'s path
* @return
*/
- def getPath: Option[Seq[(Instance, OfModule)]] = if(isComplete) {
- val allInstOfs = tokens.grouped(2).collect { case Seq(i: Instance, o:OfModule) => (i, o)}.toSeq
- if(tokens.nonEmpty && tokens.last.isInstanceOf[OfModule]) Some(allInstOfs.dropRight(1)) else Some(allInstOfs)
+ def getPath: Option[Seq[(Instance, OfModule)]] = if (isComplete) {
+ val allInstOfs = tokens.grouped(2).collect { case Seq(i: Instance, o: OfModule) => (i, o) }.toSeq
+ if (tokens.nonEmpty && tokens.last.isInstanceOf[OfModule]) Some(allInstOfs.dropRight(1)) else Some(allInstOfs)
} else {
None
}
@@ -290,7 +295,7 @@ case class GenericTarget(circuitOpt: Option[String],
/** If complete and a reference, return the reference and subcomponents
* @return
*/
- def getRef: Option[(String, Seq[TargetToken])] = if(isComplete) {
+ def getRef: Option[(String, Seq[TargetToken])] = if (isComplete) {
val (optRef, comps) = tokens.foldLeft((None: Option[String], Vector.empty[TargetToken])) {
case ((None, v), Ref(r)) => (Some(r), v)
case ((r: Some[String], comps), c) => (r, comps :+ c)
@@ -304,7 +309,7 @@ case class GenericTarget(circuitOpt: Option[String],
/** If complete and an instance target, return the instance and ofmodule
* @return
*/
- def getInstanceOf: Option[(String, String)] = if(isComplete) {
+ def getInstanceOf: Option[(String, String)] = if (isComplete) {
tokens.grouped(2).foldLeft(None: Option[(String, String)]) {
case (instOf, Seq(i: Instance, o: OfModule)) => Some((i.value, o.value))
case (instOf, _) => None
@@ -328,14 +333,14 @@ case class GenericTarget(circuitOpt: Option[String],
*/
def add(token: TargetToken): GenericTarget = {
token match {
- case _: Instance => requireLast(true, "inst", "of")
- case _: OfModule => requireLast(false, "inst")
- case _: Ref => requireLast(true, "inst", "of")
- case _: Field => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
- case _: Index => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
- case Init => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
- case Clock => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
- case Reset => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
+ case _: Instance => requireLast(true, "inst", "of")
+ case _: OfModule => requireLast(false, "inst")
+ case _: Ref => requireLast(true, "inst", "of")
+ case _: Field => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
+ case _: Index => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
+ case Init => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
+ case Clock => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
+ case Reset => requireLast(true, "ref", "[]", ".", "init", "clock", "reset")
}
this.copy(tokens = tokens :+ token)
}
@@ -345,7 +350,7 @@ case class GenericTarget(circuitOpt: Option[String],
/** Optionally tries to append token to tokens, fails return is not a legal Target */
def optAdd(token: TargetToken): Option[Target] = {
- try{
+ try {
Some(add(token))
} catch {
case _: IllegalArgumentException => None
@@ -358,7 +363,7 @@ case class GenericTarget(circuitOpt: Option[String],
def isLegal: Boolean = {
try {
var comp: GenericTarget = this.copy(tokens = Vector.empty)
- for(token <- tokens) {
+ for (token <- tokens) {
comp = comp.add(token)
}
true
@@ -374,19 +379,18 @@ case class GenericTarget(circuitOpt: Option[String],
def isComplete: Boolean = {
isLegal && (isCircuitTarget || isModuleTarget || (isComponentTarget && tokens.tails.forall {
case Instance(_) +: OfModule(_) +: tail => true
- case Instance(_) +: x +: tail => false
- case x +: OfModule(_) +: tail => false
- case _ => true
- } ))
+ case Instance(_) +: x +: tail => false
+ case x +: OfModule(_) +: tail => false
+ case _ => true
+ }))
}
-
- def isCircuitTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.isEmpty && tokens.isEmpty
- def isModuleTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.isEmpty
+ def isCircuitTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.isEmpty && tokens.isEmpty
+ def isModuleTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.isEmpty
def isComponentTarget: Boolean = circuitOpt.nonEmpty && moduleOpt.nonEmpty && tokens.nonEmpty
lazy val (parentModule: Option[String], astModule: Option[String]) = path match {
- case Seq() => (None, moduleOpt)
+ case Seq() => (None, moduleOpt)
case Seq((i, OfModule(o))) => (moduleOpt, Some(o))
case seq if seq.size > 1 =>
val reversed = seq.reverse
@@ -421,7 +425,6 @@ trait CompleteTarget extends Target {
override def toString: String = serialize
}
-
/** A member of a FIRRTL Circuit (e.g. cannot point to a CircuitTarget)
* Concrete Subclasses are: [[ModuleTarget]], [[InstanceTarget]], and [[ReferenceTarget]]
*/
@@ -456,10 +459,12 @@ trait IsMember extends CompleteTarget {
/** @return List of local Instance Targets refering to each instance/ofModule in this member's path */
def pathAsTargets: Seq[InstanceTarget] = {
- path.foldLeft((module, Vector.empty[InstanceTarget])) {
- case ((m, vec), (Instance(i), OfModule(o))) =>
- (o, vec :+ InstanceTarget(circuit, m, Nil, i, o))
- }._2
+ path
+ .foldLeft((module, Vector.empty[InstanceTarget])) {
+ case ((m, vec), (Instance(i), OfModule(o))) =>
+ (o, vec :+ InstanceTarget(circuit, m, Nil, i, o))
+ }
+ ._2
}
/** Resets this target to have a new path
@@ -469,7 +474,7 @@ trait IsMember extends CompleteTarget {
def setPathTarget(newPath: IsModule): CompleteTarget
/** @return The [[ModuleTarget]] of the module that directly contains this component */
- def encapsulatingModule: String = if(path.isEmpty) module else path.last._2.value
+ def encapsulatingModule: String = if (path.isEmpty) module else path.last._2.value
def encapsulatingModuleTarget: ModuleTarget = ModuleTarget(circuit, encapsulatingModule)
@@ -492,6 +497,7 @@ trait IsModule extends IsMember {
/** A component of a FIRRTL Module (e.g. cannot point to a CircuitTarget or ModuleTarget)
*/
trait IsComponent extends IsMember {
+
/** Removes n levels of instance hierarchy
*
* Example: n=1, transforms (Top, A)/b:B/c:C -> (Top, B)/c:C
@@ -501,13 +507,13 @@ trait IsComponent extends IsMember {
def stripHierarchy(n: Int): IsMember
override def toNamed: ComponentName = {
- if(isLocal){
+ if (isLocal) {
val mn = ModuleName(module, CircuitName(circuit))
- Seq(tokens:_*) match {
+ Seq(tokens: _*) match {
case Seq(Ref(name)) => ComponentName(name, mn)
case Ref(_) :: tail if Target.isOnly(tail, ".", "[]") =>
- val name = tokens.foldLeft(""){
- case ("", Ref(name)) => name
+ val name = tokens.foldLeft("") {
+ case ("", Ref(name)) => name
case (string, Field(value)) => s"$string.$value"
case (string, Index(value)) => s"$string[$value]"
}
@@ -524,7 +530,8 @@ trait IsComponent extends IsMember {
}
override def pathTarget: IsModule = {
- if(path.isEmpty) moduleTarget else {
+ if (path.isEmpty) moduleTarget
+ else {
val (i, o) = path.last
InstanceTarget(circuit, module, path.dropRight(1), i.value, o.value)
}
@@ -535,7 +542,6 @@ trait IsComponent extends IsMember {
override def isLocal = path.isEmpty
}
-
/** Target pointing to a FIRRTL [[firrtl.ir.Circuit]]
* @param circuit Name of a FIRRTL circuit
*/
@@ -577,7 +583,8 @@ case class ModuleTarget(circuit: String, module: String) extends IsModule {
override def targetParent: CircuitTarget = CircuitTarget(circuit)
- override def addHierarchy(root: String, instance: String): InstanceTarget = InstanceTarget(circuit, root, Nil, instance, module)
+ override def addHierarchy(root: String, instance: String): InstanceTarget =
+ InstanceTarget(circuit, root, Nil, instance, module)
override def ref(value: String): ReferenceTarget = ReferenceTarget(circuit, module, Nil, value, Nil)
@@ -613,11 +620,13 @@ case class ModuleTarget(circuit: String, module: String) extends IsModule {
* @param ref Name of component
* @param component Subcomponent of this reference, e.g. field or index
*/
-case class ReferenceTarget(circuit: String,
- module: String,
- override val path: Seq[(Instance, OfModule)],
- ref: String,
- component: Seq[TargetToken]) extends IsComponent {
+case class ReferenceTarget(
+ circuit: String,
+ module: String,
+ override val path: Seq[(Instance, OfModule)],
+ ref: String,
+ component: Seq[TargetToken])
+ extends IsComponent {
/** @param value Index value of this target
* @return A new [[ReferenceTarget]] to the specified index of this [[ReferenceTarget]]
@@ -648,7 +657,7 @@ case class ReferenceTarget(circuit: String,
baseType
} else {
val headType = tokens.head match {
- case Index(idx) => sub_type(baseType)
+ case Index(idx) => sub_type(baseType)
case Field(field) => field_type(baseType, field)
case _: Ref => baseType
}
@@ -662,7 +671,8 @@ case class ReferenceTarget(circuit: String,
override def targetParent: CompleteTarget = component match {
case Nil =>
- if(path.isEmpty) moduleTarget else {
+ if (path.isEmpty) moduleTarget
+ else {
val (i, o) = path.last
InstanceTarget(circuit, module, path.dropRight(1), i.value, o.value)
}
@@ -676,7 +686,8 @@ case class ReferenceTarget(circuit: String,
override def stripHierarchy(n: Int): ReferenceTarget = {
require(path.size >= n, s"Cannot strip $n levels of hierarchy from $this")
- if(n == 0) this else {
+ if (n == 0) this
+ else {
val newModule = path(n - 1)._2.value
ReferenceTarget(circuit, newModule, path.drop(n), ref, component)
}
@@ -700,15 +711,15 @@ case class ReferenceTarget(circuit: String,
def leafSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match {
case _: firrtl.ir.GroundType => Vector(this)
case firrtl.ir.VectorType(t, size) => (0 until size).flatMap { i => index(i).leafSubTargets(t) }
- case firrtl.ir.BundleType(fields) => fields.flatMap { f => field(f.name).leafSubTargets(f.tpe)}
- case other => sys.error(s"Error! Unexpected type $other")
+ case firrtl.ir.BundleType(fields) => fields.flatMap { f => field(f.name).leafSubTargets(f.tpe) }
+ case other => sys.error(s"Error! Unexpected type $other")
}
def allSubTargets(tpe: firrtl.ir.Type): Seq[ReferenceTarget] = tpe match {
case _: firrtl.ir.GroundType => Vector(this)
case firrtl.ir.VectorType(t, size) => this +: (0 until size).flatMap { i => index(i).allSubTargets(t) }
- case firrtl.ir.BundleType(fields) => this +: fields.flatMap { f => field(f.name).allSubTargets(f.tpe)}
- case other => sys.error(s"Error! Unexpected type $other")
+ case firrtl.ir.BundleType(fields) => this +: fields.flatMap { f => field(f.name).allSubTargets(f.tpe) }
+ case other => sys.error(s"Error! Unexpected type $other")
}
override def leafModule: String = encapsulatingModule
@@ -721,11 +732,14 @@ case class ReferenceTarget(circuit: String,
* @param instance Name of the instance
* @param ofModule Name of the instance's module
*/
-case class InstanceTarget(circuit: String,
- module: String,
- override val path: Seq[(Instance, OfModule)],
- instance: String,
- ofModule: String) extends IsModule with IsComponent {
+case class InstanceTarget(
+ circuit: String,
+ module: String,
+ override val path: Seq[(Instance, OfModule)],
+ instance: String,
+ ofModule: String)
+ extends IsModule
+ with IsComponent {
/** @return a [[ReferenceTarget]] referring to this declaration of this instance */
def asReference: ReferenceTarget = ReferenceTarget(circuit, module, path, instance, Nil)
@@ -744,7 +758,8 @@ case class InstanceTarget(circuit: String,
override def moduleOpt: Option[String] = Some(module)
override def targetParent: IsModule = {
- if(isLocal) ModuleTarget(circuit, module) else {
+ if (isLocal) ModuleTarget(circuit, module)
+ else {
val (newInstance, newOfModule) = path.last
InstanceTarget(circuit, module, path.dropRight(1), newInstance.value, newOfModule.value)
}
@@ -759,8 +774,9 @@ case class InstanceTarget(circuit: String,
override def stripHierarchy(n: Int): IsModule = {
require(path.size + 1 >= n, s"Cannot strip $n levels of hierarchy from $this")
- if(n == 0) this else {
- if(path.size < n){
+ if (n == 0) this
+ else {
+ if (path.size < n) {
ModuleTarget(circuit, ofModule)
} else {
val newModule = path(n - 1)._2.value
@@ -769,7 +785,7 @@ case class InstanceTarget(circuit: String,
}
}
- override def asPath: Seq[(Instance, OfModule)] = path :+( (Instance(instance), OfModule(ofModule)) )
+ override def asPath: Seq[(Instance, OfModule)] = path :+ ((Instance(instance), OfModule(ofModule)))
override def pathlessTarget: InstanceTarget = InstanceTarget(circuit, encapsulatingModule, Nil, instance, ofModule)
@@ -781,33 +797,32 @@ case class InstanceTarget(circuit: String,
override def leafModule: String = ofModule
}
-
/** Named classes associate an annotation with a component in a Firrtl circuit */
sealed trait Named {
def serialize: String
- def toTarget: CompleteTarget
+ def toTarget: CompleteTarget
}
final case class CircuitName(name: String) extends Named {
- if(!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name")
+ if (!validModuleName(name)) throw AnnotationException(s"Illegal circuit name: $name")
def serialize: String = name
- def toTarget: CircuitTarget = CircuitTarget(name)
+ def toTarget: CircuitTarget = CircuitTarget(name)
}
final case class ModuleName(name: String, circuit: CircuitName) extends Named {
- if(!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name")
+ if (!validModuleName(name)) throw AnnotationException(s"Illegal module name: $name")
def serialize: String = circuit.serialize + "." + name
- def toTarget: ModuleTarget = ModuleTarget(circuit.name, name)
+ def toTarget: ModuleTarget = ModuleTarget(circuit.name, name)
}
final case class ComponentName(name: String, module: ModuleName) extends Named {
- if(!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name")
- def expr: Expression = toExp(name)
+ if (!validComponentName(name)) throw AnnotationException(s"Illegal component name: $name")
+ def expr: Expression = toExp(name)
def serialize: String = module.serialize + "." + name
def toTarget: ReferenceTarget = {
Target.toTargetTokens(name).toList match {
case Ref(r) :: components => ReferenceTarget(module.circuit.name, module.name, Nil, r, components)
- case other => throw Target.NamedException(s"Cannot convert $this into [[ReferenceTarget]]: $other")
+ case other => throw Target.NamedException(s"Cannot convert $this into [[ReferenceTarget]]: $other")
}
}
}
diff --git a/src/main/scala/firrtl/annotations/TargetToken.scala b/src/main/scala/firrtl/annotations/TargetToken.scala
index 765102a6..a4a98eed 100644
--- a/src/main/scala/firrtl/annotations/TargetToken.scala
+++ b/src/main/scala/firrtl/annotations/TargetToken.scala
@@ -3,12 +3,12 @@
package firrtl.annotations
import firrtl._
-import ir.{DefModule, DefInstance}
+import ir.{DefInstance, DefModule}
/** Building block to represent a [[Target]] of a FIRRTL component */
sealed trait TargetToken {
def keyword: String
- def value: Any
+ def value: Any
/** Returns whether this token is one of the type of tokens whose keyword is passed as an argument
* @param keywords
@@ -16,8 +16,10 @@ sealed trait TargetToken {
*/
def is(keywords: String*): Boolean = {
keywords.map { kw =>
- require(TargetToken.keyword2targettoken.keySet.contains(kw),
- s"Keyword $kw must be in set ${TargetToken.keyword2targettoken.keys}")
+ require(
+ TargetToken.keyword2targettoken.keySet.contains(kw),
+ s"Keyword $kw must be in set ${TargetToken.keyword2targettoken.keys}"
+ )
val lastClass = this.getClass
lastClass == TargetToken.keyword2targettoken(kw)("0").getClass
}.reduce(_ || _)
@@ -26,20 +28,20 @@ sealed trait TargetToken {
/** Object containing all [[TargetToken]] subclasses */
case object TargetToken {
- case class Instance(value: String) extends TargetToken { override def keyword: String = "inst" }
- case class OfModule(value: String) extends TargetToken { override def keyword: String = "of" }
- case class Ref(value: String) extends TargetToken { override def keyword: String = "ref" }
- case class Index(value: Int) extends TargetToken { override def keyword: String = "[]" }
- case class Field(value: String) extends TargetToken { override def keyword: String = "." }
- case object Clock extends TargetToken { override def keyword: String = "clock"; val value = "" }
- case object Init extends TargetToken { override def keyword: String = "init"; val value = "" }
- case object Reset extends TargetToken { override def keyword: String = "reset"; val value = "" }
+ case class Instance(value: String) extends TargetToken { override def keyword: String = "inst" }
+ case class OfModule(value: String) extends TargetToken { override def keyword: String = "of" }
+ case class Ref(value: String) extends TargetToken { override def keyword: String = "ref" }
+ case class Index(value: Int) extends TargetToken { override def keyword: String = "[]" }
+ case class Field(value: String) extends TargetToken { override def keyword: String = "." }
+ case object Clock extends TargetToken { override def keyword: String = "clock"; val value = "" }
+ case object Init extends TargetToken { override def keyword: String = "init"; val value = "" }
+ case object Reset extends TargetToken { override def keyword: String = "reset"; val value = "" }
implicit class fromStringToTargetToken(s: String) {
def Instance: Instance = new TargetToken.Instance(s)
def OfModule: OfModule = new TargetToken.OfModule(s)
- def Ref: Ref = new TargetToken.Ref(s)
- def Field: Field = new TargetToken.Field(s)
+ def Ref: Ref = new TargetToken.Ref(s)
+ def Field: Field = new TargetToken.Field(s)
}
implicit class fromIntToTargetToken(i: Int) {
@@ -67,4 +69,3 @@ case object TargetToken {
"reset" -> ((value: String) => Reset)
)
}
-
diff --git a/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala
index 8f925ee7..31d13139 100644
--- a/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala
+++ b/src/main/scala/firrtl/annotations/analysis/DuplicationHelper.scala
@@ -88,10 +88,12 @@ case class DuplicationHelper(existingModules: Set[String]) {
* @param originalOfModule original module being instantiated in originalModule
* @return
*/
- def getNewOfModule(originalModule: String,
- newModule: String,
- instance: Instance,
- originalOfModule: OfModule): OfModule = {
+ def getNewOfModule(
+ originalModule: String,
+ newModule: String,
+ instance: Instance,
+ originalOfModule: OfModule
+ ): OfModule = {
dupMap.get(originalModule) match {
case None => // No duplication, can return originalOfModule
originalOfModule
@@ -129,18 +131,18 @@ case class DuplicationHelper(existingModules: Set[String]) {
val newTops = getDuplicates(top)
newTops.map { newTop =>
val newPath = mutable.ArrayBuffer[TargetToken]()
- path.foldLeft((top, newTop)) { case ((originalModule, newModule), (instance, ofModule)) =>
- val newOfModule = getNewOfModule(originalModule, newModule, instance, ofModule)
- newPath ++= Seq(instance, newOfModule)
- (ofModule.value, newOfModule.value)
+ path.foldLeft((top, newTop)) {
+ case ((originalModule, newModule), (instance, ofModule)) =>
+ val newOfModule = getNewOfModule(originalModule, newModule, instance, ofModule)
+ newPath ++= Seq(instance, newOfModule)
+ (ofModule.value, newOfModule.value)
}
- val module = if(newPath.nonEmpty) newPath.last.value.toString else newTop
+ val module = if (newPath.nonEmpty) newPath.last.value.toString else newTop
t.notPath match {
- case Seq() => ModuleTarget(t.circuit, module)
+ case Seq() => ModuleTarget(t.circuit, module)
case Instance(i) +: OfModule(m) +: Seq() => ModuleTarget(t.circuit, module)
- case Ref(r) +: components => ReferenceTarget(t.circuit, module, Nil, r, components)
+ case Ref(r) +: components => ReferenceTarget(t.circuit, module, Nil, r, components)
}
}.toSeq
}
}
-
diff --git a/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala b/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala
index a4219a03..20304378 100644
--- a/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala
+++ b/src/main/scala/firrtl/annotations/transforms/CleanupNamedTargets.scala
@@ -27,19 +27,25 @@ class CleanupNamedTargets extends Transform with DependencyAPIMigration {
override def invalidates(a: Transform) = false
- private def onStatement(statement: ir.Statement)
- (implicit references: ISet[ReferenceTarget],
- renameMap: RenameMap,
- module: ModuleTarget): Unit = statement match {
+ private def onStatement(
+ statement: ir.Statement
+ )(
+ implicit references: ISet[ReferenceTarget],
+ renameMap: RenameMap,
+ module: ModuleTarget
+ ): Unit = statement match {
case ir.DefInstance(_, a, b, _) if references(module.instOf(a, b).asReference) =>
renameMap.record(module.instOf(a, b).asReference, module.instOf(a, b))
case a => statement.foreach(onStatement)
}
- private def onModule(module: ir.DefModule)
- (implicit references: ISet[ReferenceTarget],
- renameMap: RenameMap,
- circuit: CircuitTarget): Unit = {
+ private def onModule(
+ module: ir.DefModule
+ )(
+ implicit references: ISet[ReferenceTarget],
+ renameMap: RenameMap,
+ circuit: CircuitTarget
+ ): Unit = {
implicit val mTarget = circuit.module(module.name)
module.foreach(onStatement)
}
@@ -49,7 +55,7 @@ class CleanupNamedTargets extends Transform with DependencyAPIMigration {
implicit val rTargets: ISet[ReferenceTarget] = state.annotations.flatMap {
case a: SingleTargetAnnotation[_] => Some(a.target)
case a: MultiTargetAnnotation => a.targets.flatten
- case _ => None
+ case _ => None
}.collect {
case a: ReferenceTarget => a
}.toSet
diff --git a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala
index d92d3b5e..596a344f 100644
--- a/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala
+++ b/src/main/scala/firrtl/annotations/transforms/EliminateTargetPaths.scala
@@ -5,7 +5,7 @@ package firrtl.annotations.transforms
import firrtl.Mappers._
import firrtl.analyses.InstanceKeyGraph
import firrtl.annotations.ModuleTarget
-import firrtl.annotations.TargetToken.{Instance, OfModule, fromDefModuleToTargetToken}
+import firrtl.annotations.TargetToken.{fromDefModuleToTargetToken, Instance, OfModule}
import firrtl.annotations.analysis.DuplicationHelper
import firrtl.annotations._
import firrtl.ir._
@@ -15,7 +15,6 @@ import firrtl.transforms.DedupedResult
import scala.collection.mutable
-
/** Group of targets that should become local targets
* @param targets
*/
@@ -36,7 +35,7 @@ case class DupedResult(newModules: Set[IsModule], originalModule: ModuleTarget)
override def duplicate(n: Seq[Seq[Target]]): Annotation = {
n.toList match {
case Seq(newMods) => DupedResult(newMods.collect { case x: IsModule => x }.toSet, originalModule)
- case _ => DupedResult(Set.empty, originalModule)
+ case _ => DupedResult(Set.empty, originalModule)
}
}
}
@@ -47,35 +46,35 @@ object EliminateTargetPaths {
def renameModules(c: Circuit, toRename: Map[String, String], renameMap: RenameMap): Circuit = {
val ct = CircuitTarget(c.main)
- val cx = if(toRename.contains(c.main)) {
+ val cx = if (toRename.contains(c.main)) {
renameMap.record(ct, CircuitTarget(toRename(c.main)))
c.copy(main = toRename(c.main))
} else {
c
}
def onMod(m: DefModule): DefModule = {
- m map onStmt match {
+ m.map(onStmt) match {
case e: ExtModule if toRename.contains(e.name) =>
renameMap.record(ct.module(e.name), ct.module(toRename(e.name)))
e.copy(name = toRename(e.name))
- case e: Module if toRename.contains(e.name) =>
+ case e: Module if toRename.contains(e.name) =>
renameMap.record(ct.module(e.name), ct.module(toRename(e.name)))
e.copy(name = toRename(e.name))
case o => o
}
}
- def onStmt(s: Statement): Statement = s map onStmt match {
- case w@DefInstance(info, name, module, _) if toRename.contains(module) => w.copy(module = toRename(module))
- case other => other
+ def onStmt(s: Statement): Statement = s.map(onStmt) match {
+ case w @ DefInstance(info, name, module, _) if toRename.contains(module) => w.copy(module = toRename(module))
+ case other => other
}
- cx map onMod
+ cx.map(onMod)
}
def reorderModules(c: Circuit, toReorder: Map[String, Double]): Circuit = {
val newOrderMap = c.modules.zipWithIndex.map {
case (m, _) if toReorder.contains(m.name) => m.name -> toReorder(m.name)
- case (m, i) if c.modules.size > 1 => m.name -> i.toDouble / (c.modules.size - 1)
- case (m, _) => m.name -> 1.0
+ case (m, i) if c.modules.size > 1 => m.name -> i.toDouble / (c.modules.size - 1)
+ case (m, _) => m.name -> 1.0
}.toMap
val newOrder = c.modules.sortBy { m => newOrderMap(m.name) }
@@ -83,7 +82,6 @@ object EliminateTargetPaths {
c.copy(modules = newOrder)
}
-
}
/** For a set of non-local targets, modify the instance/module hierarchy of the circuit such that
@@ -116,24 +114,20 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
* @param s
* @return
*/
- private def onStmt(dupMap: DuplicationHelper)
- (originalModule: String, newModule: String)
- (s: Statement): Statement = s match {
- case d@DefInstance(_, name, module, _) =>
- val ofModule = dupMap.getNewOfModule(originalModule, newModule, Instance(name), OfModule(module)).value
- d.copy(module = ofModule)
- case other => other map onStmt(dupMap)(originalModule, newModule)
- }
+ private def onStmt(dupMap: DuplicationHelper)(originalModule: String, newModule: String)(s: Statement): Statement =
+ s match {
+ case d @ DefInstance(_, name, module, _) =>
+ val ofModule = dupMap.getNewOfModule(originalModule, newModule, Instance(name), OfModule(module)).value
+ d.copy(module = ofModule)
+ case other => other.map(onStmt(dupMap)(originalModule, newModule))
+ }
/** Returns a modified circuit and [[RenameMap]] containing the associated target remapping
* @param cir
* @param targets
* @return
*/
- def run(cir: Circuit,
- targets: Seq[IsMember],
- iGraph: InstanceKeyGraph
- ): (Circuit, RenameMap, AnnotationSeq) = {
+ def run(cir: Circuit, targets: Seq[IsMember], iGraph: InstanceKeyGraph): (Circuit, RenameMap, AnnotationSeq) = {
val dupMap = DuplicationHelper(cir.modules.map(_.name).toSet)
@@ -161,7 +155,7 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
}
val finalModuleList = duplicatedModuleList
- lazy val finalModuleSet = finalModuleList.map{ case a: DefModule => a.name }.toSet
+ lazy val finalModuleSet = finalModuleList.map { case a: DefModule => a.name }.toSet
// Records how targets have been renamed
val renameMap = RenameMap()
@@ -203,8 +197,9 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
duplicatedParents.foreach { parent =>
val paths = iGraph.findInstancesInHierarchy(parent.value)
val newTargets = paths.map { path =>
- path.tail.foldLeft(topMod: IsModule) { case (mod, wDefInst) =>
- mod.instOf(wDefInst.name, wDefInst.module)
+ path.tail.foldLeft(topMod: IsModule) {
+ case (mod, wDefInst) =>
+ mod.instOf(wDefInst.name, wDefInst.module)
}
}
newTargets.foreach(addSelfRecord(_))
@@ -219,13 +214,11 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
val (remainingAnnotations, targetsToEliminate, previouslyDeduped) =
state.annotations.foldLeft(
- ( Vector.empty[Annotation],
- Seq.empty[CompleteTarget],
- Map.empty[IsModule, (ModuleTarget, Double)]
- )
- ) { case ((remainingAnnos, targets, dedupedResult), anno) =>
+ (Vector.empty[Annotation], Seq.empty[CompleteTarget], Map.empty[IsModule, (ModuleTarget, Double)])
+ ) {
+ case ((remainingAnnos, targets, dedupedResult), anno) =>
anno match {
- case ResolvePaths(ts) =>
+ case ResolvePaths(ts) =>
(remainingAnnos, ts ++ targets, dedupedResult)
case DedupedResult(orig, dups, idx) if dups.nonEmpty =>
(remainingAnnos, targets, dedupedResult ++ dups.map(_ -> (orig, idx)).toMap)
@@ -234,29 +227,29 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
}
}
-
// Collect targets that are not local
val targets = targetsToEliminate.collect { case x: IsMember => x }
// Check validity of paths in targets
val iGraph = InstanceKeyGraph(state.circuit)
- val instanceOfModules = iGraph.getChildInstances.map { case(k,v) => k -> v.map(_.toTokens) }.toMap
+ val instanceOfModules = iGraph.getChildInstances.map { case (k, v) => k -> v.map(_.toTokens) }.toMap
val targetsWithInvalidPaths = mutable.ArrayBuffer[IsMember]()
targets.foreach { t =>
val path = t match {
- case _: ModuleTarget => Nil
- case i: InstanceTarget => i.asPath
+ case _: ModuleTarget => Nil
+ case i: InstanceTarget => i.asPath
case r: ReferenceTarget => r.path
}
- path.foldLeft(t.module) { case (module, (inst: Instance, of: OfModule)) =>
- val childrenOpt = instanceOfModules.get(module)
- if(childrenOpt.isEmpty || !childrenOpt.get.contains((inst, of))) {
- targetsWithInvalidPaths += t
- }
- of.value
+ path.foldLeft(t.module) {
+ case (module, (inst: Instance, of: OfModule)) =>
+ val childrenOpt = instanceOfModules.get(module)
+ if (childrenOpt.isEmpty || !childrenOpt.get.contains((inst, of))) {
+ targetsWithInvalidPaths += t
+ }
+ of.value
}
}
- if(targetsWithInvalidPaths.nonEmpty) {
+ if (targetsWithInvalidPaths.nonEmpty) {
val string = targetsWithInvalidPaths.mkString(",")
throw NoSuchTargetException(s"""Some targets have illegal paths that cannot be resolved/eliminated: $string""")
}
@@ -292,7 +285,7 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
}
val newTarget = t match {
case r: ReferenceTarget => r.setPathTarget(newIsModule)
- case i: InstanceTarget => newIsModule
+ case i: InstanceTarget => newIsModule
}
firstRenameMap.record(t, Seq(newTarget))
newTarget +: acc
@@ -312,10 +305,10 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
}
val iGraphx = InstanceKeyGraph(newCircuit)
- val newlyUnreachableModules = iGraphx.unreachableModules.toSet diff iGraph.unreachableModules.toSet
+ val newlyUnreachableModules = iGraphx.unreachableModules.toSet.diff(iGraph.unreachableModules.toSet)
val newCircuitGC = {
- val modulesx = newCircuit.modules.flatMap{
+ val modulesx = newCircuit.modules.flatMap {
case dead if newlyUnreachableModules(dead.OfModule) => None
case live =>
val m = CircuitTarget(newCircuit.main).module(live.name)
@@ -338,7 +331,8 @@ class EliminateTargetPaths extends Transform with DependencyAPIMigration {
val renamedCircuit = renameModules(newCircuitGC, newModuleNameMapping, renamedModuleMap)
- val reorderedCircuit = reorderModules(renamedCircuit,
+ val reorderedCircuit = reorderModules(
+ renamedCircuit,
previouslyDeduped.map {
case (current: IsModule, (orig: ModuleTarget, idx)) =>
orig.name -> idx
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
index f7ab9927..66690f56 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
@@ -26,7 +26,8 @@ private class Btor2Serializer private () {
private def comment(c: String): Unit = { lines += s"; $c" }
private def trailingComment(c: String): Unit = {
val lastLine = lines.last
- val newLine = if(lastLine.contains(';')) { lastLine + " " + c} else { lastLine + " ; " + c }
+ val newLine = if (lastLine.contains(';')) { lastLine + " " + c }
+ else { lastLine + " ; " + c }
lines(lines.size - 1) = newLine
}
@@ -38,54 +39,55 @@ private class Btor2Serializer private () {
// bit vector expression serialization
private def s(expr: BVExpr): Int = expr match {
case BVLiteral(value, width) => lit(value, width)
- case BVSymbol(name, _) => symbols(name)
- case BVExtend(e, 0, _) => s(e)
- case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by")
- case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by")
+ case BVSymbol(name, _) => symbols(name)
+ case BVExtend(e, 0, _) => s(e)
+ case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by")
+ case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by")
case BVSlice(e, hi, lo) =>
- if (lo == 0 && hi == e.width - 1) { s(e) } else {
+ if (lo == 0 && hi == e.width - 1) { s(e) }
+ else {
line(s"slice ${t(expr.width)} ${s(e)} $hi $lo")
}
- case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b)
- case BVNot(BVNot(e)) => s(e)
- case BVNot(e) => unary("not", expr.width, e)
- case BVNegate(e) => unary("neg", expr.width, e)
- case BVReduceAnd(e) => unary("redand", expr.width, e)
- case BVReduceOr(e) => unary("redor", expr.width, e)
- case BVReduceXor(e) => unary("redxor", expr.width, e)
- case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b)
- case BVImplies(a, b) => binary("implies", expr.width, a, b)
- case BVEqual(a, b) => binary("eq", expr.width, a, b)
- case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}")
- case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b)
+ case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b)
+ case BVNot(BVNot(e)) => s(e)
+ case BVNot(e) => unary("not", expr.width, e)
+ case BVNegate(e) => unary("neg", expr.width, e)
+ case BVReduceAnd(e) => unary("redand", expr.width, e)
+ case BVReduceOr(e) => unary("redor", expr.width, e)
+ case BVReduceXor(e) => unary("redxor", expr.width, e)
+ case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b)
+ case BVImplies(a, b) => binary("implies", expr.width, a, b)
+ case BVEqual(a, b) => binary("eq", expr.width, a, b)
+ case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}")
+ case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b)
case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b)
- case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b)
- case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b)
- case BVOp(op, a, b) => binary(s(op), expr.width, a, b)
- case BVConcat(a, b) => binary("concat", expr.width, a, b)
+ case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b)
+ case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b)
+ case BVOp(op, a, b) => binary(s(op), expr.width, a, b)
+ case BVConcat(a, b) => binary("concat", expr.width, a, b)
case ArrayRead(array, index) =>
line(s"read ${t(expr.width)} ${s(array)} ${s(index)}")
case BVIte(cond, tru, fals) =>
line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}")
- case r : BVRawExpr =>
+ case r: BVRawExpr =>
throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
}
private def s(op: Op.Value): String = op match {
- case Op.And => "and"
- case Op.Or => "or"
- case Op.Xor => "xor"
+ case Op.And => "and"
+ case Op.Or => "or"
+ case Op.Xor => "xor"
case Op.ArithmeticShiftRight => "sra"
- case Op.ShiftRight => "srl"
- case Op.ShiftLeft => "sll"
- case Op.Add => "add"
- case Op.Mul => "mul"
- case Op.Sub => "sub"
- case Op.SignedDiv => "sdiv"
- case Op.UnsignedDiv => "udiv"
- case Op.SignedMod => "smod"
- case Op.SignedRem => "srem"
- case Op.UnsignedRem => "urem"
+ case Op.ShiftRight => "srl"
+ case Op.ShiftLeft => "sll"
+ case Op.Add => "add"
+ case Op.Mul => "mul"
+ case Op.Sub => "sub"
+ case Op.SignedDiv => "sdiv"
+ case Op.UnsignedDiv => "udiv"
+ case Op.SignedMod => "smod"
+ case Op.SignedRem => "srem"
+ case Op.UnsignedRem => "urem"
}
private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}")
@@ -123,18 +125,18 @@ private class Btor2Serializer private () {
// It is essential to model memories, so any support in the wild should be fairly well tested.
line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}")
case ArrayConstant(e, _) => s(e)
- case r : ArrayRawExpr =>
+ case r: ArrayRawExpr =>
throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
}
private def s(expr: SMTExpr): Int = expr match {
- case b: BVExpr => s(b)
+ case b: BVExpr => s(b)
case a: ArrayExpr => s(a)
}
// serialize the type of the expression
private def t(expr: SMTExpr): Int = expr match {
- case b: BVExpr => t(b.width)
+ case b: BVExpr => t(b.width)
case a: ArrayExpr => t(a.indexWidth, a.dataWidth)
}
@@ -145,7 +147,7 @@ private class Btor2Serializer private () {
symbols(name) = id
if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name")
if (sys.assumes.contains(name)) line(s"constraint $id ; $name")
- if (sys.asserts.contains(name)){
+ if (sys.asserts.contains(name)) {
val invertedId = line(s"not ${t(1)} $id")
line(s"bad $invertedId ; $name")
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index 0a223840..efa89687 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -9,26 +9,26 @@ import firrtl.passes.CheckWidths.WidthTooBig
private trait TranslationContext {
def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe))
- def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe))
- def getRandom(width: Int): BVExpr
+ def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe))
+ def getRandom(width: Int): BVExpr
}
private object FirrtlExpressionSemantics {
def getWidth(tpe: ir.Type): Int = tpe match {
- case ir.UIntType(ir.IntWidth(w)) => w.toInt
- case ir.SIntType(ir.IntWidth(w)) => w.toInt
- case ir.ClockType => 1
- case ir.ResetType => 1
+ case ir.UIntType(ir.IntWidth(w)) => w.toInt
+ case ir.SIntType(ir.IntWidth(w)) => w.toInt
+ case ir.ClockType => 1
+ case ir.ResetType => 1
case ir.AnalogType(ir.IntWidth(w)) => w.toInt
- case other => throw new RuntimeException(s"Cannot handle type $other")
+ case other => throw new RuntimeException(s"Cannot handle type $other")
}
def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = {
val eSMT = e match {
case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts)
- case r : ir.Reference => ctx.getReference(r.serialize, r.tpe)
- case r : ir.SubField => ctx.getReference(r.serialize, r.tpe)
- case r : ir.SubIndex => ctx.getReference(r.serialize, r.tpe)
+ case r: ir.Reference => ctx.getReference(r.serialize, r.tpe)
+ case r: ir.SubField => ctx.getReference(r.serialize, r.tpe)
+ case r: ir.SubIndex => ctx.getReference(r.serialize, r.tpe)
case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
case ir.Mux(cond, tval, fval, _) =>
@@ -38,7 +38,10 @@ private object FirrtlExpressionSemantics {
val tru = toSMT(value)
BVIte(toSMT(cond), tru, ctx.getRandom(tpe))
}
- assert(eSMT.width == getWidth(e), "We aim to always produce a SMT expression of the same width as the firrtl expression.")
+ assert(
+ eSMT.width == getWidth(e),
+ "We aim to always produce a SMT expression of the same width as the firrtl expression."
+ )
eSMT
}
@@ -47,8 +50,8 @@ private object FirrtlExpressionSemantics {
forceWidth(toSMT(e), isSigned(e), width, allowNarrow)
private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = {
- if(eSMT.width == width) { eSMT }
- else if(width < eSMT.width) {
+ if (eSMT.width == width) { eSMT }
+ else if (width < eSMT.width) {
assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!")
BVSlice(eSMT, width - 1, 0)
} else {
@@ -57,8 +60,13 @@ private object FirrtlExpressionSemantics {
}
// see "Primitive Operations" section in the Firrtl Specification
- private def onPrim(op: ir.PrimOp, args: Seq[ir.Expression], consts: Seq[BigInt])(implicit ctx: TranslationContext):
- BVExpr = {
+ private def onPrim(
+ op: ir.PrimOp,
+ args: Seq[ir.Expression],
+ consts: Seq[BigInt]
+ )(
+ implicit ctx: TranslationContext
+ ): BVExpr = {
(op, args, consts) match {
case (PrimOps.Add, Seq(e1, e2), _) =>
val width = args.map(getWidth).max + 1
@@ -70,7 +78,7 @@ private object FirrtlExpressionSemantics {
val width = args.map(getWidth).sum
BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width))
case (PrimOps.Div, Seq(num, den), _) =>
- val (width, op) = if(isSigned(num)) {
+ val (width, op) = if (isSigned(num)) {
(getWidth(num) + 1, Op.SignedDiv)
} else { (getWidth(num), Op.UnsignedDiv) }
// "The result of a division where den is zero is undefined."
@@ -83,11 +91,12 @@ private object FirrtlExpressionSemantics {
val width = getWidth(num) + 1
BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width))
case (PrimOps.Rem, Seq(num, den), _) =>
- val op = if(isSigned(num)) Op.SignedRem else Op.UnsignedRem
+ val op = if (isSigned(num)) Op.SignedRem else Op.UnsignedRem
val width = args.map(getWidth).max
val resWidth = args.map(getWidth).min
val res = BVOp(op, toSMT(num, width), toSMT(den, width))
- if(res.width > resWidth) { BVSlice(res, resWidth - 1, 0) } else { res }
+ if (res.width > resWidth) { BVSlice(res, resWidth - 1, 0) }
+ else { res }
case (PrimOps.Lt, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)))
@@ -108,25 +117,29 @@ private object FirrtlExpressionSemantics {
BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width)))
case (PrimOps.Pad, Seq(e), Seq(n)) =>
val width = getWidth(e)
- if(n <= width) { toSMT(e) } else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) }
- case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e) ; toSMT(e)
- case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e) ; toSMT(e)
+ if (n <= width) { toSMT(e) }
+ else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) }
+ case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e); toSMT(e)
+ case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e); toSMT(e)
case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!")
- case (PrimOps.AsClock, Seq(e), _) => toSMT(e)
+ case (PrimOps.AsClock, Seq(e), _) => toSMT(e)
case (PrimOps.AsAsyncReset, Seq(e), _) =>
checkForClockInCast(PrimOps.AsAsyncReset, e)
throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.")
- case (PrimOps.Shl, Seq(e), Seq(n)) => if(n == 0) { toSMT(e) } else {
- val zeros = BVLiteral(0, n.toInt)
- BVConcat(toSMT(e), zeros)
- }
+ case (PrimOps.Shl, Seq(e), Seq(n)) =>
+ if (n == 0) { toSMT(e) }
+ else {
+ val zeros = BVLiteral(0, n.toInt)
+ BVConcat(toSMT(e), zeros)
+ }
case (PrimOps.Shr, Seq(e), Seq(n)) =>
val width = getWidth(e)
// "If n is greater than or equal to the bit-width of e,
// the resulting value will be zero for unsigned types
// and the sign bit for signed types"
- if(n >= width) {
- if(isSigned(e)) { BV1BitZero } else { BVSlice(toSMT(e), width - 1, width - 1) }
+ if (n >= width) {
+ if (isSigned(e)) { BV1BitZero }
+ else { BVSlice(toSMT(e), width - 1, width - 1) }
} else {
BVSlice(toSMT(e), width - 1, n.toInt)
}
@@ -135,9 +148,11 @@ private object FirrtlExpressionSemantics {
BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width))
case (PrimOps.Dshr, Seq(e1, e2), _) =>
val width = getWidth(e1)
- val o = if(isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight
+ val o = if (isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight
BVOp(o, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Cvt, Seq(e), _) => if(isSigned(e)) { toSMT(e) } else { BVConcat(BV1BitZero, toSMT(e)) }
+ case (PrimOps.Cvt, Seq(e), _) =>
+ if (isSigned(e)) { toSMT(e) }
+ else { BVConcat(BV1BitZero, toSMT(e)) }
case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e)))
case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e))
case (PrimOps.And, Seq(e1, e2), _) =>
@@ -149,10 +164,10 @@ private object FirrtlExpressionSemantics {
case (PrimOps.Xor, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e))
- case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e))
- case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e))
- case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2))
+ case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e))
+ case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e))
+ case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e))
+ case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2))
case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt)
case (PrimOps.Head, Seq(e), Seq(n)) =>
val width = getWidth(e)
@@ -167,7 +182,8 @@ private object FirrtlExpressionSemantics {
}
/** For now we strictly forbid casting clocks to anything else.
- * Eventually this should be replaced by a more sophisticated clock analysis pass. */
+ * Eventually this should be replaced by a more sophisticated clock analysis pass.
+ */
private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = {
assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!")
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index b3a2ff17..0888b062 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -11,7 +11,16 @@ import firrtl.passes.PassException
import firrtl.stage.Forms
import firrtl.stage.TransformManager.TransformDependency
import firrtl.transforms.PropagatePresetAnnotations
-import firrtl.{CircuitState, DependencyAPIMigration, MemoryArrayInit, MemoryInitValue, MemoryScalarInit, Transform, Utils, ir}
+import firrtl.{
+ ir,
+ CircuitState,
+ DependencyAPIMigration,
+ MemoryArrayInit,
+ MemoryInitValue,
+ MemoryScalarInit,
+ Transform,
+ Utils
+}
import logger.LazyLogging
import scala.collection.mutable
@@ -22,15 +31,21 @@ import scala.collection.mutable
private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr])
private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) }
private case class TransitionSystem(
- name: String, inputs: Array[BVSymbol], states: Array[State], signals: Array[Signal],
- outputs: Set[String], assumes: Set[String], asserts: Set[String], fair: Set[String],
- comments: Map[String, String] = Map(), header: Array[String] = Array()) {
+ name: String,
+ inputs: Array[BVSymbol],
+ states: Array[State],
+ signals: Array[Signal],
+ outputs: Set[String],
+ assumes: Set[String],
+ asserts: Set[String],
+ fair: Set[String],
+ comments: Map[String, String] = Map(),
+ header: Array[String] = Array()) {
def serialize: String = {
(Iterator(name) ++
inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
- states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")
- ).mkString("\n")
+ states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n")
}
}
@@ -53,26 +68,30 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
// run the preset pass to extract all preset registers and remove preset reset signals
val afterPreset = presetPass.execute(state)
val circuit = afterPreset.circuit
- val presetRegs = afterPreset.annotations
- .collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref }.toSet
+ val presetRegs = afterPreset.annotations.collect {
+ case PresetRegAnnotation(target) if target.module == circuit.main => target.ref
+ }.toSet
// collect all non-random memory initialization
val memInit = afterPreset.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a }
- .filter(_.target.module == circuit.main).map(a => a.target.ref -> a.initValue).toMap
+ .filter(_.target.module == circuit.main)
+ .map(a => a.target.ref -> a.initValue)
+ .toMap
// convert the main module
val main = circuit.modules.find(_.name == circuit.main).get
val sys = main match {
case x: ir.ExtModule =>
throw new ExtModuleException(
- "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog.")
+ "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog."
+ )
case m: ir.Module =>
- new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit=memInit)
+ new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit)
}
val sortedSys = TopologicalSort.run(sys)
val anno = TransitionSystemAnnotation(sortedSys)
- state.copy(circuit=circuit, annotations = afterPreset.annotations :+ anno )
+ state.copy(circuit = circuit, annotations = afterPreset.annotations :+ anno)
}
}
@@ -94,18 +113,23 @@ private object UnsupportedException {
}
private class ExtModuleException(s: String) extends PassException(s)
-private class AsyncResetException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering)
-private class MultiClockException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering)
-private class MissingFeatureException(s: String) extends PassException("Unfortunately the SMT backend does not yet support: " + s)
+private class AsyncResetException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering)
+private class MultiClockException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering)
+private class MissingFeatureException(s: String)
+ extends PassException("Unfortunately the SMT backend does not yet support: " + s)
private class ModuleToTransitionSystem extends LazyLogging {
- def run(m: ir.Module, presetRegs: Set[String] = Set(), memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = {
+ def run(
+ m: ir.Module,
+ presetRegs: Set[String] = Set(),
+ memInit: Map[String, MemoryInitValue] = Map()
+ ): TransitionSystem = {
// first pass over the module to convert expressions; discover state and I/O
val scan = new ModuleScanner(makeRandom)
m.foreachPort(scan.onPort)
// multi-clock support requires the StutteringClock transform to be run
- if(scan.clocks.size > 1) {
+ if (scan.clocks.size > 1) {
throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}")
}
m.foreachStmt(scan.onStatement)
@@ -115,14 +139,16 @@ private class ModuleToTransitionSystem extends LazyLogging {
val constraints = scan.assumes.toSet
val bad = scan.asserts.toSet
val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad
- val signals = scan.connects.filter{ case(name, _) => isSignal.contains(name) }
- .map { case (name, expr) => Signal(name, expr) }
+ val signals = scan.connects.filter { case (name, _) => isSignal.contains(name) }.map {
+ case (name, expr) => Signal(name, expr)
+ }
// turn registers and memories into states
val registers = scan.registers.map(r => r._1 -> r).toMap
- val regStates = scan.connects.filter(s => registers.contains(s._1)).map { case (name, nextExpr) =>
- val (_, width, resetExpr, initExpr) = registers(name)
- onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
+ val regStates = scan.connects.filter(s => registers.contains(s._1)).map {
+ case (name, nextExpr) =>
+ val (_, width, resetExpr, initExpr) = registers(name)
+ onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
}
// turn memories into state
val memoryEncoding = new MemoryEncoding(makeRandom)
@@ -135,16 +161,22 @@ private class ModuleToTransitionSystem extends LazyLogging {
} else { s }
}
// filter out any left-over self assignments (this happens when we have a registered read port)
- .filter(s => s match { case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false case _ => true })
+ .filter(s =>
+ s match {
+ case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false
+ case _ => true
+ }
+ )
val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1)
// generate comments from infos
val comments = mutable.HashMap[String, String]()
- scan.infos.foreach { case (name, info) =>
- serializeInfo(info).foreach { infoString =>
- if(comments.contains(name)) { comments(name) += InfoSeparator + infoString }
- else { comments(name) = InfoPrefix + infoString }
- }
+ scan.infos.foreach {
+ case (name, info) =>
+ serializeInfo(info).foreach { infoString =>
+ if (comments.contains(name)) { comments(name) += InfoSeparator + infoString }
+ else { comments(name) = InfoPrefix + infoString }
+ }
}
// inputs are original module inputs and any "random" signal we need for modelling
@@ -154,11 +186,28 @@ private class ModuleToTransitionSystem extends LazyLogging {
val header = serializeInfo(m.info).map(InfoPrefix + _).toArray
val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints
- TransitionSystem(m.name, inputs.toArray, states, signalsWithMem.toArray, outputs, constraints, bad, fair, comments.toMap, header)
+ TransitionSystem(
+ m.name,
+ inputs.toArray,
+ states,
+ signalsWithMem.toArray,
+ outputs,
+ constraints,
+ bad,
+ fair,
+ comments.toMap,
+ header
+ )
}
- private def onRegister(name: String, width: Int, resetExpr: BVExpr, initExpr: BVExpr,
- nextExpr: BVExpr, presetRegs: Set[String]): State = {
+ private def onRegister(
+ name: String,
+ width: Int,
+ resetExpr: BVExpr,
+ initExpr: BVExpr,
+ nextExpr: BVExpr,
+ presetRegs: Set[String]
+ ): State = {
assert(initExpr.width == width)
assert(nextExpr.width == width)
assert(resetExpr.width == 1)
@@ -166,9 +215,9 @@ private class ModuleToTransitionSystem extends LazyLogging {
val hasReset = initExpr != sym
val isPreset = presetRegs.contains(name)
assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!")
- if(hasReset) {
- val init = if(isPreset) Some(initExpr) else None
- val next = if(isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr)
+ if (hasReset) {
+ val init = if (isPreset) Some(initExpr) else None
+ val next = if (isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr)
State(sym, next = Some(next), init = init)
} else {
State(sym, next = Some(nextExpr), init = None)
@@ -179,10 +228,11 @@ private class ModuleToTransitionSystem extends LazyLogging {
private val InfoPrefix = "@ "
private def serializeInfo(info: ir.Info): Option[String] = info match {
case ir.NoInfo => None
- case f : ir.FileInfo => Some(f.escaped)
- case m : ir.MultiInfo =>
+ case f: ir.FileInfo => Some(f.escaped)
+ case m: ir.MultiInfo =>
val infos = m.flatten
- if(infos.isEmpty) { None } else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
+ if (infos.isEmpty) { None }
+ else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
}
private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
@@ -190,7 +240,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
// TODO: actually ensure that there cannot be any name clashes with other identifiers
val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
- val sym = BVSymbol(name, width)
+ val sym = BVSymbol(name, width)
randoms(name) = sym
sym
}
@@ -198,10 +248,16 @@ private class ModuleToTransitionSystem extends LazyLogging {
private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLogging {
type Connects = Iterable[(String, BVExpr)]
- def onMemory(defMem: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (Iterable[State], Connects) = {
+ def onMemory(
+ defMem: ir.DefMemory,
+ connects: Connects,
+ initValue: Option[MemoryInitValue]
+ ): (Iterable[State], Connects) = {
// we can only work on appropriately lowered memories
- assert(defMem.dataType.isInstanceOf[ir.GroundType],
- s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!")
+ assert(
+ defMem.dataType.isInstanceOf[ir.GroundType],
+ s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!"
+ )
assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
// collect all memory meta-data in a custom class
@@ -214,17 +270,19 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
val init = initValue.map(getInit(m, _))
// parse and check read and write ports
- val writers = defMem.writers.map( w => new WritePort(m, w, inputs))
- val readers = defMem.readers.map( r => new ReadPort(m, r, inputs))
+ val writers = defMem.writers.map(w => new WritePort(m, w, inputs))
+ val readers = defMem.readers.map(r => new ReadPort(m, r, inputs))
// derive next state from all write ports
assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.")
- val next: ArrayExpr = if(writers.isEmpty) { m.sym } else {
- if(writers.length > 2) {
+ val next: ArrayExpr = if (writers.isEmpty) { m.sym }
+ else {
+ if (writers.length > 2) {
throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})")
}
val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) }
- if(writers.length == 1) { validData } else {
+ if (writers.length == 1) { validData }
+ else {
assert(writers.length == 2)
val conflict = writers.head.doesConflict(writers.last)
val conflictData = writers.head.makeRandomData("_write_write_collision")
@@ -236,13 +294,13 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
// derive data signals from all read ports
assert(defMem.readLatency >= 0)
- if(defMem.readLatency > 1) {
+ if (defMem.readLatency > 1) {
throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})")
}
- val readPortSignals = if(defMem.readLatency == 0) {
+ val readPortSignals = if (defMem.readLatency == 0) {
readers.map { r =>
// combinatorial read
- if(defMem.readUnderWrite != ir.ReadUnderWrite.New) {
+ if (defMem.readUnderWrite != ir.ReadUnderWrite.New) {
//logger.warn(s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." +
// s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored.")
}
@@ -251,22 +309,25 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
r.data.name -> data
}
} else { Seq() }
- val readPortStates = if(defMem.readLatency == 1) {
+ val readPortStates = if (defMem.readLatency == 1) {
readers.map { r =>
// we create a register for the read port data
val next = defMem.readUnderWrite match {
case ir.ReadUnderWrite.New =>
- throw new UnsupportedFeatureException(s"registered read ports that return the new value (${m.name}.${r.name})")
- // the thing that makes this hard is to properly handle write conflicts
+ throw new UnsupportedFeatureException(
+ s"registered read ports that return the new value (${m.name}.${r.name})"
+ )
+ // the thing that makes this hard is to properly handle write conflicts
case ir.ReadUnderWrite.Undefined =>
val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r)))
- if(anyWriteToTheSameAddress == False) { r.readOld() } else {
+ if (anyWriteToTheSameAddress == False) { r.readOld() }
+ else {
val readUnderWriteData = r.makeRandomData("_read_under_write_undefined")
BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.readOld())
}
case ir.ReadUnderWrite.Old => r.readOld()
}
- State(r.data, init=None, next=Some(next))
+ State(r.data, init = None, next = Some(next))
}
} else { Seq() }
@@ -276,16 +337,20 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match {
case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth)
case MemoryArrayInit(values) =>
- assert(values.length == m.depth,
- s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!")
+ assert(
+ values.length == m.depth,
+ s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
+ )
// in order to get a more compact encoding try to find the most common values
val histogram = mutable.LinkedHashMap[BigInt, Int]()
values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
val baseValue = histogram.maxBy(_._2)._1
val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth)
- values.zipWithIndex.filterNot(_._1 == baseValue)
- .foldLeft[ArrayExpr](base) { case (array, (value, index)) =>
- ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth))
+ values.zipWithIndex
+ .filterNot(_._1 == baseValue)
+ .foldLeft[ArrayExpr](base) {
+ case (array, (value, index)) =>
+ ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth))
}
case other => throw new RuntimeException(s"Unsupported memory init option: $other")
}
@@ -295,19 +360,20 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
val depth = m.depth
// derrive the type of the memory from the dataType and depth
val dataWidth = getWidth(m.dataType)
- val indexWidth = Utils.getUIntWidth(m.depth - 1) max 1
+ val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
val sym = ArraySymbol(m.name, indexWidth, dataWidth)
val prefix = m.name + "."
val fullAddressRange = (BigInt(1) << indexWidth) == m.depth
lazy val depthBV = BVLiteral(m.depth, indexWidth)
def isValidAddress(addr: BVExpr): BVExpr = {
- if(fullAddressRange) { True } else {
+ if (fullAddressRange) { True }
+ else {
BVComparison(Compare.Greater, depthBV, addr, signed = false)
}
}
}
private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) {
- val en: BVSymbol = makeField("en", 1)
+ val en: BVSymbol = makeField("en", 1)
val data: BVSymbol = makeField("data", memory.dataWidth)
val addr: BVSymbol = makeField("addr", memory.indexWidth)
protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width)
@@ -321,11 +387,11 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
val canBeOutOfRange = !memory.fullAddressRange
val canBeDisabled = !enIsTrue
val data = ArrayRead(memory.sym, addr)
- val dataWithRangeCheck = if(canBeOutOfRange) {
+ val dataWithRangeCheck = if (canBeOutOfRange) {
val outOfRangeData = makeRandomData("_addr_out_of_range")
BVIte(memory.isValidAddress(addr), data, outOfRangeData)
} else { data }
- val dataWithEnabledCheck = if(canBeDisabled) {
+ val dataWithEnabledCheck = if (canBeDisabled) {
val disabledData = makeRandomData("_not_enabled")
BVIte(en, dataWithRangeCheck, disabledData)
} else { dataWithRangeCheck }
@@ -333,48 +399,49 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
}
}
private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {
+ extends MemPort(memory, name, inputs) {
assert(inputs(data.name).width == data.width)
val mask: BVSymbol = makeField("mask", 1)
assert(inputs(mask.name).width == mask.width)
val maskIsTrue: Boolean = inputs(mask.name) == True
val doWrite: BVExpr = (enIsTrue, maskIsTrue) match {
- case (true, true) => True
- case (true, false) => mask
- case (false, true) => en
+ case (true, true) => True
+ case (true, false) => mask
+ case (false, true) => en
case (false, false) => and(en, mask)
}
def doesConflict(r: ReadPort): BVExpr = {
val sameAddress = BVEqual(r.addr, addr)
- if(doWrite == True) { sameAddress } else { and(doWrite, sameAddress) }
+ if (doWrite == True) { sameAddress }
+ else { and(doWrite, sameAddress) }
}
def doesConflict(w: WritePort): BVExpr = {
val bothWrite = and(doWrite, w.doWrite)
val sameAddress = BVEqual(addr, w.addr)
- if(bothWrite == True) { sameAddress } else { and(doWrite, sameAddress) }
+ if (bothWrite == True) { sameAddress }
+ else { and(doWrite, sameAddress) }
}
def writeTo(array: ArrayExpr): ArrayExpr = {
- val doUpdate = if(memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr))
- val update = ArrayStore(array, index=addr, data=data)
- if(doUpdate == True) update else ArrayIte(doUpdate, update, array)
+ val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr))
+ val update = ArrayStore(array, index = addr, data = data)
+ if (doUpdate == True) update else ArrayIte(doUpdate, update, array)
}
}
private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {
- }
+ extends MemPort(memory, name, inputs) {}
- private def and(a: BVExpr, b: BVExpr): BVExpr = (a,b) match {
+ private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match {
case (True, True) => True
- case (True, x) => x
- case (x, True) => x
- case _ => BVOp(Op.And, a, b)
+ case (True, x) => x
+ case (x, True) => x
+ case _ => BVOp(Op.And, a, b)
}
private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b)
private val True = BVLiteral(1, 1)
private val False = BVLiteral(0, 1)
- private def all(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) False else b.reduce((a,b) => and(a,b))
- private def any(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) True else b.reduce((a,b) => or(a,b))
+ private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b))
+ private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b))
}
// performas a first pass over the module collecting all connections, wires, registers, input and outputs
@@ -399,13 +466,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]()
private[firrtl] def onPort(p: ir.Port): Unit = {
- if(isAsyncReset(p.tpe)) {
+ if (isAsyncReset(p.tpe)) {
throw new AsyncResetException(s"Found AsyncReset ${p.name}.")
}
infos.append(p.name -> p.info)
p.direction match {
case ir.Input =>
- if(isClock(p.tpe)) {
+ if (isClock(p.tpe)) {
clocks.add(p.name)
} else {
inputs.append(BVSymbol(p.name, getWidth(p.tpe)))
@@ -416,12 +483,12 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
private[firrtl] def onStatement(s: ir.Statement): Unit = s match {
case ir.DefWire(info, name, tpe) =>
- if(!isClock(tpe)) {
+ if (!isClock(tpe)) {
infos.append(name -> info)
wires.append(name)
}
case ir.DefNode(info, name, expr) =>
- if(!isClock(expr.tpe)) {
+ if (!isClock(expr.tpe)) {
insertDummyAssignsForMemoryOutputs(expr)
infos.append(name -> info)
val e = onExpression(expr, name)
@@ -436,7 +503,7 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
val resetExpr = onExpression(reset, 1, name + "_reset")
val initExpr = onExpression(init, width, name + "_init")
registers.append((name, width, resetExpr, initExpr))
- case m : ir.DefMemory =>
+ case m: ir.DefMemory =>
infos.append(m.name -> m.info)
val outputs = getMemOutputs(m)
(getMemInputs(m) ++ outputs).foreach(memSignals.append(_))
@@ -444,37 +511,39 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
outputs.foreach(name => unusedMemOutputs(name) = dataWidth)
memories.append(m)
case ir.Connect(info, loc, expr) =>
- if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
+ if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
val name = loc.serialize
insertDummyAssignsForMemoryOutputs(expr)
infos.append(name -> info)
connects.append((name, onExpression(expr, getWidth(loc.tpe), name)))
case ir.IsInvalid(info, loc) =>
- if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
+ if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
val name = loc.serialize
infos.append(name -> info)
connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe))))
case ir.DefInstance(info, name, module, tpe) =>
- if(!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
+ if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
// we treat all instances as blackboxes
- logger.warn(s"WARN: treating instance $name of $module as blackbox. " +
- "Please flatten your hierarchy if you want to include submodules in the formal model.")
+ logger.warn(
+ s"WARN: treating instance $name of $module as blackbox. " +
+ "Please flatten your hierarchy if you want to include submodules in the formal model."
+ )
val ports = tpe.asInstanceOf[ir.BundleType].fields
// skip clock and async reset ports
- ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe) ).foreach { p =>
- if(!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
+ ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe)).foreach { p =>
+ if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
val isOutput = p.flip == ir.Default
val pName = name + "." + p.name
infos.append(pName -> info)
// outputs of the submodule become inputs to our module
- if(isOutput) {
+ if (isOutput) {
inputs.append(BVSymbol(pName, getWidth(p.tpe)))
} else {
outputs.append(pName)
}
}
case s @ ir.Verification(op, info, _, pred, en, msg) =>
- if(op == ir.Formal.Cover) {
+ if (op == ir.Formal.Cover) {
logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}")
} else {
val name = msgToName(op.toString, msg.string)
@@ -483,22 +552,22 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
val e = BVImplies(enabled, predicate)
infos.append(name -> info)
connects.append(name -> e)
- if(op == ir.Formal.Assert) {
+ if (op == ir.Formal.Assert) {
asserts.append(name)
} else {
assumes.append(name)
}
}
- case s : ir.Conditionally =>
+ case s: ir.Conditionally =>
error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}")
- case s : ir.PartialConnect =>
+ case s: ir.PartialConnect =>
error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}")
- case s : ir.Attach =>
+ case s: ir.Attach =>
error(s"Analog wires are not supported in the SMT backend: ${s.serialize}")
- case s : ir.Stop =>
+ case s: ir.Stop =>
// we could wire up the stop condition as output for debug reasons
logger.warn(s"WARN: Stop statements are currently not supported. Ignoring: ${s.serialize}")
- case s : ir.Print =>
+ case s: ir.Print =>
logger.warn(s"WARN: Print statements are not supported. Ignoring: ${s.serialize}")
case other => other.foreachStmt(onStatement)
}
@@ -520,21 +589,22 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
// example:
// m.r.data <= m.r.data ; this is the dummy assign
// test <= m.r.data ; this is the first use of m.r.data
- private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if(unusedMemOutputs.nonEmpty) {
+ private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if (unusedMemOutputs.nonEmpty) {
implicit val uses = mutable.ArrayBuffer[String]()
findUnusedMemoryOutputUse(next)
- if(uses.nonEmpty) {
+ if (uses.nonEmpty) {
val useSet = uses.toSet
- unusedMemOutputs.foreach { case (name, width) =>
- if(useSet.contains(name)) connects.append(name -> BVSymbol(name, width))
+ unusedMemOutputs.foreach {
+ case (name, width) =>
+ if (useSet.contains(name)) connects.append(name -> BVSymbol(name, width))
}
useSet.foreach(name => unusedMemOutputs.remove(name))
}
}
private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match {
- case s : ir.SubField =>
+ case s: ir.SubField =>
val name = s.serialize
- if(unusedMemOutputs.contains(name)) uses.append(name)
+ if (unusedMemOutputs.contains(name)) uses.append(name)
case other => other.foreachExpr(findUnusedMemoryOutputUse)
}
@@ -555,17 +625,18 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
// TODO: ensure that we can generate unique names
prefix + "_" + msg.replace(" ", "_").replace("|", "")
}
- private def error(msg: String): Unit = throw new RuntimeException(msg)
+ private def error(msg: String): Unit = throw new RuntimeException(msg)
private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType]
- private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType
+ private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType
private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType
}
private object TopologicalSort {
+
/** Ensures that all signals in the resulting system are topologically sorted.
* This is necessary because [[firrtl.transforms.RemoveWires]] does
* not sort assignments to outputs, submodule inputs nor memory ports.
- * */
+ */
def run(sys: TransitionSystem): TransitionSystem = {
val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name)
val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates)
@@ -583,23 +654,24 @@ private object TopologicalSort {
val known = new mutable.HashSet[String]() ++ globalSignals
var needsReordering = false
val digraph = new MutableDiGraph[String]
- signals.foreach { case (name, expr) =>
- digraph.addVertex(name)
- val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
- uniqueDependencies.foreach { d =>
- if(!known.contains(d)) { needsReordering = true }
- digraph.addPairWithEdge(name, d)
- }
- known.add(name)
+ signals.foreach {
+ case (name, expr) =>
+ digraph.addVertex(name)
+ val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
+ uniqueDependencies.foreach { d =>
+ if (!known.contains(d)) { needsReordering = true }
+ digraph.addPairWithEdge(name, d)
+ }
+ known.add(name)
}
- if(needsReordering) {
+ if (needsReordering) {
Some(digraph.linearize.reverse)
} else { None }
}
private def findDependencies(expr: SMTExpr): List[String] = expr match {
- case BVSymbol(name, _) => List(name)
+ case BVSymbol(name, _) => List(name)
case ArraySymbol(name, _, _) => List(name)
- case other => other.children.flatMap(findDependencies)
+ case other => other.children.flatMap(findDependencies)
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
index 322b8961..1c7ea42f 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
@@ -11,8 +11,10 @@ import firrtl.options.Viewer.view
import firrtl.options.{CustomFileEmission, Dependency}
import firrtl.stage.FirrtlOptions
-
-private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform with Emitter with DependencyAPIMigration {
+private[firrtl] abstract class SMTEmitter private[firrtl] ()
+ extends Transform
+ with Emitter
+ with DependencyAPIMigration {
override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem))
override def invalidates(a: Transform): Boolean = false
@@ -30,16 +32,16 @@ private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform w
override protected def execute(state: CircuitState): CircuitState = {
val emitCircuit = state.annotations.exists {
- case EmitCircuitAnnotation(a) if this.getClass == a => true
+ case EmitCircuitAnnotation(a) if this.getClass == a => true
case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!")
- case _ => false
+ case _ => false
}
- if(!emitCircuit) { return state }
+ if (!emitCircuit) { return state }
logger.warn(BleedingEdgeWarning)
- val sys = state.annotations.collectFirst{ case TransitionSystemAnnotation(sys) => sys }.getOrElse {
+ val sys = state.annotations.collectFirst { case TransitionSystemAnnotation(sys) => sys }.getOrElse {
error("Could not find the transition system!")
}
state.copy(annotations = state.annotations :+ serialize(sys))
@@ -52,11 +54,12 @@ private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform w
}
case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String)
- extends NoTargetAnnotation with CustomFileEmission {
+ extends NoTargetAnnotation
+ with CustomFileEmission {
override protected def baseFileName(annotations: AnnotationSeq): String =
view[FirrtlOptions](annotations).outputFileName.getOrElse(name)
override protected def suffix: Option[String] = Some(outputSuffix)
- override def getBytes: Iterable[Byte] = src.getBytes
+ override def getBytes: Iterable[Byte] = src.getBytes
}
private[firrtl] class Btor2Emitter extends SMTEmitter {
@@ -72,14 +75,14 @@ private[firrtl] class SMTLibEmitter extends SMTEmitter {
override protected def serialize(sys: TransitionSystem): Annotation = {
val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr])
val logic = SMTLibSerializer.setLogic(hasMemory) + "\n"
- val header = if(hasMemory) {
+ val header = if (hasMemory) {
"; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" +
- "; see https://github.com/Z3Prover/z3/issues/1803\n" +
- "; for CVC4 you probably want to include the logic\n" +
- ";" + logic
+ "; see https://github.com/Z3Prover/z3/issues/1803\n" +
+ "; for CVC4 you probably want to include the logic\n" +
+ ";" + logic
} else { logic }
val smt = generatedHeader("SMT-LIBv2", sys.name) + header +
SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n"
EmittedSMTModelAnnotation(sys.name, smt, outputSuffix)
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
index 10a89e8d..ebb9e309 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
@@ -9,7 +9,7 @@ private sealed trait SMTExpr { def children: List[SMTExpr] }
private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String }
private object SMTSymbol {
def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match {
- case b: BVExpr => BVSymbol(name, b.width)
+ case b: BVExpr => BVSymbol(name, b.width)
case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth)
}
}
@@ -19,19 +19,19 @@ private sealed trait SMTNullaryExpr extends SMTExpr {
private sealed trait BVExpr extends SMTExpr { def width: Int }
private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr {
- private def minWidth = value.bitLength + (if(value <= 0) 1 else 0)
+ private def minWidth = value.bitLength + (if (value <= 0) 1 else 0)
assert(width > 0, "Zero or negative width literals are not allowed!")
assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.")
- override def toString: String = if(width <= 8) {
+ override def toString: String = if (width <= 8) {
width.toString + "'b" + value.toString(2)
} else { width.toString + "'x" + value.toString(16) }
}
private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol {
- assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
+ assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
assert(width > 0, "Zero width bit vectors are not supported!")
override def toString: String = name
- def toStringWithType: String = name + " : " + SMTExpr.serializeType(this)
+ def toStringWithType: String = name + " : " + SMTExpr.serializeType(this)
}
private sealed trait BVUnaryExpr extends BVExpr {
@@ -41,34 +41,35 @@ private sealed trait BVUnaryExpr extends BVExpr {
private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr {
assert(by >= 0, "Extension must be non-negative!")
override val width: Int = e.width + by
- override def toString: String = if(signed) { s"sext($e, $by)" } else { s"zext($e, $by)" }
+ override def toString: String = if (signed) { s"sext($e, $by)" }
+ else { s"zext($e, $by)" }
}
// also known as bit extract operation
private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr {
assert(lo >= 0, s"lo (lsb) must be non-negative!")
assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo")
assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi")
- override def width: Int = hi - lo + 1
- override def toString: String = if(hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]"
+ override def width: Int = hi - lo + 1
+ override def toString: String = if (hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]"
}
private case class BVNot(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
+ override val width: Int = e.width
override def toString: String = s"not($e)"
}
private case class BVNegate(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
+ override val width: Int = e.width
override def toString: String = s"neg($e)"
}
private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"redor($e)"
}
private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"redand($e)"
}
private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"redxor($e)"
}
@@ -79,12 +80,12 @@ private sealed trait BVBinaryExpr extends BVExpr {
}
private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!")
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"impl($a, $b)"
}
private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == b.width, s"Both argument need to be the same width!")
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"eq($a, $b)"
}
private object Compare extends Enumeration {
@@ -94,8 +95,8 @@ private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed:
assert(a.width == b.width, s"Both argument need to be the same width!")
override def width: Int = 1
override def toString: String = op match {
- case Compare.Greater => (if(signed) "sgt" else "ugt") + s"($a, $b)"
- case Compare.GreaterEqual => (if(signed) "sgeq" else "ugeq") + s"($a, $b)"
+ case Compare.Greater => (if (signed) "sgt" else "ugt") + s"($a, $b)"
+ case Compare.GreaterEqual => (if (signed) "sgeq" else "ugeq") + s"($a, $b)"
}
}
private object Op extends Enumeration {
@@ -116,81 +117,87 @@ private object Op extends Enumeration {
}
private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == b.width, s"Both argument need to be the same width!")
- override val width: Int = a.width
+ override val width: Int = a.width
override def toString: String = s"$op($a, $b)"
}
private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
- override val width: Int = a.width + b.width
+ override val width: Int = a.width + b.width
override def toString: String = s"concat($a, $b)"
}
private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr {
assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
- override val width: Int = array.dataWidth
+ override val width: Int = array.dataWidth
override def toString: String = s"$array[$index]"
override def children: List[SMTExpr] = List(array, index)
}
private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr {
assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}")
- override val width: Int = tru.width
+ override val width: Int = tru.width
override def toString: String = s"ite($cond, $tru, $fals)"
override def children: List[BVExpr] = List(cond, tru, fals)
}
private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int }
private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol {
- assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
+ assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
override def toString: String = name
- def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>"
+ def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>"
}
private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr {
assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
assert(array.dataWidth == data.width, "Data with does not match expected array data width!")
- override val dataWidth: Int = array.dataWidth
+ override val dataWidth: Int = array.dataWidth
override val indexWidth: Int = array.indexWidth
- override def toString: String = s"$array[$index := $data]"
- override def children: List[SMTExpr] = List(array, index, data)
+ override def toString: String = s"$array[$index := $data]"
+ override def children: List[SMTExpr] = List(array, index, data)
}
private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr {
assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
- assert(tru.indexWidth == fals.indexWidth,
- s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}")
- assert(tru.dataWidth == fals.dataWidth,
- s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}")
- override val dataWidth: Int = tru.dataWidth
+ assert(
+ tru.indexWidth == fals.indexWidth,
+ s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}"
+ )
+ assert(
+ tru.dataWidth == fals.dataWidth,
+ s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}"
+ )
+ override val dataWidth: Int = tru.dataWidth
override val indexWidth: Int = tru.indexWidth
- override def toString: String = s"ite($cond, $tru, $fals)"
- override def children: List[SMTExpr] = List(cond, tru, fals)
+ override def toString: String = s"ite($cond, $tru, $fals)"
+ override def children: List[SMTExpr] = List(cond, tru, fals)
}
private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr {
assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!")
assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!")
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"eq($a, $b)"
override def children: List[SMTExpr] = List(a, b)
}
private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr {
override val dataWidth: Int = e.width
- override def toString: String = s"([$e] x ${ (BigInt(1) << indexWidth) })"
- override def children: List[SMTExpr] = List(e)
+ override def toString: String = s"([$e] x ${(BigInt(1) << indexWidth)})"
+ override def children: List[SMTExpr] = List(e)
}
private object SMTEqual {
- def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a,b) match {
- case (ab : BVExpr, bb : BVExpr) => BVEqual(ab, bb)
- case (aa : ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba)
+ def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a, b) match {
+ case (ab: BVExpr, bb: BVExpr) => BVEqual(ab, bb)
+ case (aa: ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba)
case _ => throw new RuntimeException(s"Cannot compare $a and $b")
}
}
private object SMTExpr {
def serializeType(e: SMTExpr): String = e match {
- case b: BVExpr => s"bv<${b.width}>"
+ case b: BVExpr => s"bv<${b.width}>"
case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>"
}
}
// Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]]
private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr
-private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTNullaryExpr \ No newline at end of file
+private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int)
+ extends ArrayExpr
+ with SMTNullaryExpr
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
index 14e73253..defc787c 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
@@ -9,7 +9,7 @@ private object SMTExprVisitor {
type BVFun = BVExpr => BVExpr
def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match {
- case b: BVExpr => map(b, bv, ar).asInstanceOf[T]
+ case b: BVExpr => map(b, bv, ar).asInstanceOf[T]
case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T]
}
def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T =
@@ -17,57 +17,56 @@ private object SMTExprVisitor {
private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match {
// nullary
- case old : BVLiteral => bv(old)
- case old : BVSymbol => bv(old)
- case old : BVRawExpr => bv(old)
+ case old: BVLiteral => bv(old)
+ case old: BVSymbol => bv(old)
+ case old: BVRawExpr => bv(old)
// unary
- case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVExtend(n, by, signed))
- case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVSlice(n, hi, lo))
- case old @ BVNot(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNot(n))
- case old @ BVNegate(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNegate(n))
- case old @ BVReduceAnd(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceAnd(n))
- case old @ BVReduceOr(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceOr(n))
- case old @ BVReduceXor(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceXor(n))
+ case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVExtend(n, by, signed))
+ case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVSlice(n, hi, lo))
+ case old @ BVNot(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNot(n))
+ case old @ BVNegate(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNegate(n))
+ case old @ BVReduceAnd(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceAnd(n))
+ case old @ BVReduceOr(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceOr(n))
+ case old @ BVReduceXor(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceXor(n))
// binary
case old @ BVImplies(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB))
case old @ BVEqual(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB))
case old @ ArrayEqual(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB))
case old @ BVComparison(op, a, b, signed) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed))
case old @ BVOp(op, a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB))
case old @ BVConcat(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB))
case old @ ArrayRead(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB))
// ternary
case old @ BVIte(a, b, c) =>
val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC))
+ bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC))
}
-
private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match {
- case old : ArrayRawExpr => ar(old)
- case old : ArraySymbol => ar(old)
+ case old: ArrayRawExpr => ar(old)
+ case old: ArraySymbol => ar(old)
case old @ ArrayConstant(e, indexWidth) =>
- val n = map(e, bv, ar) ; ar(if(n.eq(e)) old else ArrayConstant(n, indexWidth))
+ val n = map(e, bv, ar); ar(if (n.eq(e)) old else ArrayConstant(n, indexWidth))
case old @ ArrayStore(a, b, c) =>
val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC))
+ ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC))
case old @ ArrayIte(a, b, c) =>
val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC))
+ ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC))
}
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
index 1993da87..bd5e4d8c 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
@@ -6,83 +6,87 @@ package firrtl.backends.experimental.smt
import scala.util.matching.Regex
/** Converts STM Expressions to a SMTLib compatible string representation.
- * See http://smtlib.cs.uiowa.edu/
- * Assumes well typed expression, so it is advisable to run the TypeChecker
- * before serializing!
- * Automatically converts 1-bit vectors to bool.
- */
+ * See http://smtlib.cs.uiowa.edu/
+ * Assumes well typed expression, so it is advisable to run the TypeChecker
+ * before serializing!
+ * Automatically converts 1-bit vectors to bool.
+ */
private object SMTLibSerializer {
- def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if(hasMem) "A" else "") + "UFBV)"
+ def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if (hasMem) "A" else "") + "UFBV)"
def serialize(e: SMTExpr): String = e match {
- case b : BVExpr => serialize(b)
- case a : ArrayExpr => serialize(a)
+ case b: BVExpr => serialize(b)
+ case a: ArrayExpr => serialize(a)
}
def serializeType(e: SMTExpr): String = e match {
- case b : BVExpr => serializeBitVectorType(b.width)
- case a : ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
+ case b: BVExpr => serializeBitVectorType(b.width)
+ case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
}
private def serialize(e: BVExpr): String = e match {
case BVLiteral(value, width) =>
- val mask = (BigInt(1) << width) - 1
- val twosComplement = if(value < 0) { ((~(-value)) & mask) + 1 } else value
- if(width == 1) {
- if(twosComplement == 1) "true" else "false"
+ val mask = (BigInt(1) << width) - 1
+ val twosComplement = if (value < 0) { ((~(-value)) & mask) + 1 }
+ else value
+ if (width == 1) {
+ if (twosComplement == 1) "true" else "false"
} else {
s"(_ bv$twosComplement $width)"
}
- case BVSymbol(name, _) => escapeIdentifier(name)
- case BVExtend(e, 0, _) => serialize(e)
+ case BVSymbol(name, _) => escapeIdentifier(name)
+ case BVExtend(e, 0, _) => serialize(e)
case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by))
case BVExtend(e, by, signed) =>
- val foo = if(signed) "sign_extend" else "zero_extend"
+ val foo = if (signed) "sign_extend" else "zero_extend"
s"((_ $foo $by) ${asBitVector(e)})"
case BVSlice(e, hi, lo) =>
- if(lo == 0 && hi == e.width - 1) { serialize(e)
- } else {
+ if (lo == 0 && hi == e.width - 1) { serialize(e) }
+ else {
val bits = s"((_ extract $hi $lo) ${asBitVector(e)})"
// 1-bit extracts need to be turned into a boolean
- if(lo == hi) { toBool(bits) } else { bits }
+ if (lo == hi) { toBool(bits) }
+ else { bits }
}
case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})"
- case BVNot(BVNot(e)) => serialize(e)
- case BVNot(e) => if(e.width == 1) { s"(not ${serialize(e)})" } else { s"(bvnot ${serialize(e)})" }
+ case BVNot(BVNot(e)) => serialize(e)
+ case BVNot(e) =>
+ if (e.width == 1) { s"(not ${serialize(e)})" }
+ else { s"(bvnot ${serialize(e)})" }
case BVNegate(e) => s"(bvneg ${asBitVector(e)})"
case r: BVReduceAnd => serialize(Expander.expand(r))
- case r: BVReduceOr => serialize(Expander.expand(r))
+ case r: BVReduceOr => serialize(Expander.expand(r))
case r: BVReduceXor => serialize(Expander.expand(r))
- case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b)
- case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})"
- case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
- case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
- case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})"
+ case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b)
+ case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})"
+ case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
+ case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
+ case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})"
case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})"
- case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
- case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
+ case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
+ case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
// boolean operations get a special treatment for 1-bit vectors aka bools
case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})"
- case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})"
+ case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})"
case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})"
- case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
- case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
- case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
- case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
- case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case BVRawExpr(serialized, _) => serialized
+ case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
+ case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
+ case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
+ case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
+ case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case BVRawExpr(serialized, _) => serialized
}
def serialize(e: ArrayExpr): String = e match {
- case ArraySymbol(name, _, _) => escapeIdentifier(name)
+ case ArraySymbol(name, _, _) => escapeIdentifier(name)
case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
- case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
+ case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
case ArrayRawExpr(serialized, _, _) => serialized
}
def serialize(c: SMTCommand): String = c match {
- case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
+ case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)"
case DefineFunction(name, args, e) =>
val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ")
@@ -95,23 +99,24 @@ private object SMTLibSerializer {
private def serializeArrayType(indexWidth: Int, dataWidth: Int): String =
s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})"
private def serializeBitVectorType(width: Int): String =
- if(width == 1) { "Bool" } else { assert(width > 1) ; s"(_ BitVec $width)" }
+ if (width == 1) { "Bool" }
+ else { assert(width > 1); s"(_ BitVec $width)" }
private def serialize(op: Op.Value): String = op match {
- case Op.And => "bvand"
- case Op.Or => "bvor"
- case Op.Xor => "bvxor"
+ case Op.And => "bvand"
+ case Op.Or => "bvor"
+ case Op.Xor => "bvxor"
case Op.ArithmeticShiftRight => "bvashr"
- case Op.ShiftRight => "bvlshr"
- case Op.ShiftLeft => "bvshl"
- case Op.Add => "bvadd"
- case Op.Mul => "bvmul"
- case Op.Sub => "bvsub"
- case Op.SignedDiv => "bvsdiv"
- case Op.UnsignedDiv => "bvudiv"
- case Op.SignedMod => "bvsmod"
- case Op.SignedRem => "bvsrem"
- case Op.UnsignedRem => "bvurem"
+ case Op.ShiftRight => "bvlshr"
+ case Op.ShiftLeft => "bvshl"
+ case Op.Add => "bvadd"
+ case Op.Mul => "bvmul"
+ case Op.Sub => "bvsub"
+ case Op.SignedDiv => "bvsdiv"
+ case Op.UnsignedDiv => "bvudiv"
+ case Op.SignedMod => "bvsmod"
+ case Op.SignedRem => "bvsrem"
+ case Op.UnsignedRem => "bvurem"
}
private def toBool(e: String): String = s"(= $e (_ bv1 1))"
@@ -119,33 +124,37 @@ private object SMTLibSerializer {
private val bvZero = "(_ bv0 1)"
private val bvOne = "(_ bv1 1)"
private def asBitVector(e: BVExpr): String =
- if(e.width > 1) { serialize(e) } else { s"(ite ${serialize(e)} $bvOne $bvZero)" }
+ if (e.width > 1) { serialize(e) }
+ else { s"(ite ${serialize(e)} $bvOne $bvZero)" }
// See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec
private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.\$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.\$$_~&\^<>@]*".r
def escapeIdentifier(name: String): String = name match {
case simple() => name
- case _ => if(name.startsWith("|") && name.endsWith("|")) name else s"|$name|"
+ case _ => if (name.startsWith("|") && name.endsWith("|")) name else s"|$name|"
}
}
/** Expands expressions that are not natively supported by SMTLib */
private object Expander {
def expand(r: BVReduceAnd): BVExpr = {
- if(r.e.width == 1) { r.e } else {
+ if (r.e.width == 1) { r.e }
+ else {
val allOnes = (BigInt(1) << r.e.width) - 1
BVEqual(r.e, BVLiteral(allOnes, r.e.width))
}
}
def expand(r: BVReduceOr): BVExpr = {
- if(r.e.width == 1) { r.e } else {
+ if (r.e.width == 1) { r.e }
+ else {
BVNot(BVEqual(r.e, BVLiteral(0, r.e.width)))
}
}
def expand(r: BVReduceXor): BVExpr = {
- if(r.e.width == 1) { r.e } else {
+ if (r.e.width == 1) { r.e }
+ else {
val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii))
- bits.reduce[BVExpr]((a,b) => BVOp(Op.Xor, a, b))
+ bits.reduce[BVExpr]((a, b) => BVOp(Op.Xor, a, b))
}
}
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
index e9acc05b..4c60a1b0 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
@@ -10,7 +10,7 @@ import scala.collection.mutable
* It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state
* prevents this encoding from working with boolector.
* For simplicity reasons, we do not support hierarchical designs (no `_h` function).
- * */
+ */
private object SMTTransitionSystemEncoder {
def encode(sys: TransitionSystem): Iterable[SMTCommand] = {
@@ -38,10 +38,10 @@ private object SMTTransitionSystemEncoder {
cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e))
}
sys.signals.foreach { signal =>
- val kind = if(sys.outputs.contains(signal.name)) { "output"
- } else if(sys.assumes.contains(signal.name)) { "assume"
- } else if(sys.asserts.contains(signal.name)) { "assert"
- } else { "wire" }
+ val kind = if (sys.outputs.contains(signal.name)) { "output" }
+ else if (sys.assumes.contains(signal.name)) { "assume" }
+ else if (sys.asserts.contains(signal.name)) { "assert" }
+ else { "wire" }
val sym = SMTSymbol.fromExpr(signal.name, signal.e)
cmds ++= toDescription(sym, kind, sys.comments.get)
define(sym, signal.e)
@@ -105,18 +105,18 @@ private object SMTTransitionSystemEncoder {
}
private def andReduce(e: Iterable[BVExpr]): BVExpr =
- if(e.isEmpty) BVLiteral(1, 1) else e.reduce((a,b) => BVOp(Op.And, a, b))
+ if (e.isEmpty) BVLiteral(1, 1) else e.reduce((a, b) => BVOp(Op.And, a, b))
// All signals are modelled with functions that need to be called with the state as argument,
// this replaces all Symbols with function applications to the state.
private def replaceSymbols(e: SMTExpr): SMTExpr = {
SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e)
}
- private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr]
+ private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr]
private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match {
- case BVSymbol(name, width) => BVRawExpr(s"(${id(name+suffix)} $arg)", width)
- case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name+suffix)} $arg)", indexWidth, dataWidth)
- case other => other
+ case BVSymbol(name, width) => BVRawExpr(s"(${id(name + suffix)} $arg)", width)
+ case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name + suffix)} $arg)", indexWidth, dataWidth)
+ case other => other
}
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
index d8e203f8..95db95ef 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
@@ -3,7 +3,7 @@
package firrtl.backends.experimental.smt
-import firrtl.{CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils, ir}
+import firrtl.{ir, CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils}
import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation}
import firrtl.ir.EmptyStmt
import firrtl.options.Dependency
@@ -32,16 +32,17 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// since this pass only runs on the main module, inlining needs to happen before
override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
-
override protected def execute(state: CircuitState): CircuitState = {
- if(state.circuit.modules.size > 1) {
- logger.warn("WARN: StutteringClockTransform currently only supports running on a single module.\n" +
- s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want.")
+ if (state.circuit.modules.size > 1) {
+ logger.warn(
+ "WARN: StutteringClockTransform currently only supports running on a single module.\n" +
+ s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want."
+ )
}
// get main module
val main = state.circuit.modules.find(_.name == state.circuit.main).get match {
- case m: ir.Module => m
+ case m: ir.Module => m
case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e")
}
mainName = main.name
@@ -64,19 +65,21 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// replace all other clocks with enable signals, unless they are the global clock
val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name)
- val clockToEnable = clocks.map{c =>
+ val clockToEnable = clocks.map { c =>
c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow)
}.toMap
val portsWithEnableSignals = portsWithGlobalClock.map { p =>
- if(clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } else { p }
+ if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) }
+ else { p }
}
// replace async reset with synchronous reset (since everything will we synchronous with the global clock)
// unless it is a preset reset
val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name)
- val isPresetReset = state.annotations.collect{ case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet
+ val isPresetReset = state.annotations.collect { case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet
val resetsToChange = asyncResets.filterNot(isPresetReset).toSet
val portsWithSyncReset = portsWithEnableSignals.map { p =>
- if(resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } else { p }
+ if (resetsToChange.contains(p.name)) { p.copy(tpe = Bool) }
+ else { p }
}
// discover clock and reset connections
@@ -85,8 +88,9 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// rename clocks to clock enable signals
val mRef = CircuitTarget(state.circuit.main).module(main.name)
val renameMap = RenameMap()
- scan.clockToEnable.foreach { case (clk, en) =>
- renameMap.record(mRef.ref(clk), mRef.ref(en.name))
+ scan.clockToEnable.foreach {
+ case (clk, en) =>
+ renameMap.record(mRef.ref(clk), mRef.ref(en.name))
}
// make changes
@@ -103,51 +107,58 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
s match {
// memory field connects
case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _)
- if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) =>
+ if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) =>
// replace clock with the global clock
- if(field == "clk") {
+ if (field == "clk") {
c.copy(expr = ctx.globalClock)
- } else if(field == "en") {
+ } else if (field == "en") {
val m = ctx.memInfo(mem)
val isWritePort = m.writers.contains(port)
assert(isWritePort || m.readers.contains(port))
// for write ports we guard the write enable with the clock enable signal, similar to registers
- if(isWritePort) {
+ if (isWritePort) {
val clockEn = ctx.memPortToClockEnable(mem + "." + port)
val guardedEnable = and(clockEn, c.expr)
c.copy(expr = guardedEnable)
} else { c }
- } else { c}
+ } else { c }
// register field connects
- case c @ ir.Connect(_, r : ir.Reference, next) if ctx.registerToEnable.contains(r.name) =>
+ case c @ ir.Connect(_, r: ir.Reference, next) if ctx.registerToEnable.contains(r.name) =>
val clockEnable = ctx.registerToEnable(r.name)
val guardedNext = mux(clockEnable, next, r)
c.copy(expr = guardedNext)
// remove other clock wires and nodes
case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt
- case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
- case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
+ case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
+ case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
// change async reset to synchronous reset
- case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) =>
- ir.Connect(info, loc.copy(tpe=Bool), expr.copy(tpe=Bool))
- case d @ ir.DefNode(_, name, value: ir.Reference) if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
- d.copy(value = value.copy(tpe=Bool))
- case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe=Bool)
+ case ir.Connect(info, loc: ir.Reference, expr: ir.Reference)
+ if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) =>
+ ir.Connect(info, loc.copy(tpe = Bool), expr.copy(tpe = Bool))
+ case d @ ir.DefNode(_, name, value: ir.Reference)
+ if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
+ d.copy(value = value.copy(tpe = Bool))
+ case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe = Bool)
// change memory clock and synchronize reset
case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) =>
val clockEnable = ctx.registerToEnable(name)
val newReset = reset match {
- case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe=Bool)
- case other => other
+ case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe = Bool)
+ case other => other
}
- val synchronizedReset = if(reset.tpe == ir.AsyncResetType) { newReset } else { and(newReset, clockEnable) }
+ val synchronizedReset = if (reset.tpe == ir.AsyncResetType) { newReset }
+ else { and(newReset, clockEnable) }
ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init)
case other => other.mapStmt(onStatement)
}
}
- private def scanClocks(m: ir.Module, initialClockToEnable: Map[String, ir.Reference], resetsToChange: Set[String]): ScanCtx = {
+ private def scanClocks(
+ m: ir.Module,
+ initialClockToEnable: Map[String, ir.Reference],
+ resetsToChange: Set[String]
+ ): ScanCtx = {
implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange)
m.foreachStmt(scanClocksAndResets)
ctx
@@ -162,9 +173,9 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
ctx.clockToEnable.get(expr.serialize).foreach { clockEn =>
ctx.clockToEnable(locName) = clockEn
// keep track of memory clocks
- if(loc.isInstanceOf[ir.SubField]) {
+ if (loc.isInstanceOf[ir.SubField]) {
val parts = locName.split('.')
- if(ctx.mems.contains(parts.head)) {
+ if (ctx.mems.contains(parts.head)) {
assert(parts.length == 3 && parts.last == "clk")
ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn)
}
@@ -182,11 +193,11 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
ctx.clockToEnable.get(clock.serialize).foreach { clockEnable =>
ctx.registerToEnable.append(name -> clockEnable)
}
- case m : ir.DefMemory =>
+ case m: ir.DefMemory =>
assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!")
assert(m.writeLatency == 1, "Only write-latency 1 is supported!")
- if(m.readers.nonEmpty && m.readLatency == 1) {
+ if (m.readers.nonEmpty && m.readLatency == 1) {
unsupportedError("Registers memory read ports are not properly implemented yet :(")
}
ctx.mems(m.name) = m
@@ -233,8 +244,8 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// memory enables which need to be guarded with clock enables
val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap
// keep track of memory names
- val isMem: String => Boolean = scanResults.mems.contains
- val memInfo: String => ir.DefMemory = scanResults.mems
+ val isMem: String => Boolean = scanResults.mems.contains
+ val memInfo: String => ir.DefMemory = scanResults.mems
val isResetToChange: String => Boolean = scanResults.resetsToChange.contains
}
@@ -250,4 +261,4 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
private val Bool = ir.UIntType(ir.IntWidth(1))
}
-private class UnsupportedFeatureException(s: String) extends PassException(s) \ No newline at end of file
+private class UnsupportedFeatureException(s: String) extends PassException(s)
diff --git a/src/main/scala/firrtl/checks/CheckResets.scala b/src/main/scala/firrtl/checks/CheckResets.scala
index 06bd5cba..a17e3e7b 100644
--- a/src/main/scala/firrtl/checks/CheckResets.scala
+++ b/src/main/scala/firrtl/checks/CheckResets.scala
@@ -14,8 +14,8 @@ import scala.collection.mutable
import scala.annotation.tailrec
object CheckResets {
- class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException(
- s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'")
+ class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String)
+ extends PassException(s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'")
// Map of Initialization Expression to check
private type RegCheckList = mutable.ListBuffer[(Expression, DefRegister)]
@@ -31,9 +31,11 @@ object CheckResets {
class CheckResets extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
- Dependency(firrtl.transforms.RemoveReset) ) ++ firrtl.stage.Forms.MidForm
+ Seq(
+ Dependency(passes.LowerTypes),
+ Dependency(passes.Legalize),
+ Dependency(firrtl.transforms.RemoveReset)
+ ) ++ firrtl.stage.Forms.MidForm
override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.CheckCombLoops])
@@ -45,10 +47,10 @@ class CheckResets extends Transform with DependencyAPIMigration {
private def onStmt(regCheck: RegCheckList, drivers: DirectDriverMap)(stmt: Statement): Unit = {
stmt match {
- case DefNode(_, name, expr) => drivers += we(WRef(name)) -> expr
- case Connect(_, lhs, rhs) => drivers += we(lhs) -> rhs
- case reg @ DefRegister(_, name, _,_,_, init) if weq(WRef(name), init) => // Self-reset, allowed!
- case reg @ DefRegister(_,_,_,_, reset, init) if reset.tpe == AsyncResetType =>
+ case DefNode(_, name, expr) => drivers += we(WRef(name)) -> expr
+ case Connect(_, lhs, rhs) => drivers += we(lhs) -> rhs
+ case reg @ DefRegister(_, name, _, _, _, init) if weq(WRef(name), init) => // Self-reset, allowed!
+ case reg @ DefRegister(_, _, _, _, reset, init) if reset.tpe == AsyncResetType =>
regCheck += init -> reg
case _ => // Do nothing
}
@@ -60,11 +62,12 @@ class CheckResets extends Transform with DependencyAPIMigration {
@tailrec
private def findDriver(drivers: DirectDriverMap)(expr: Expression): Expression = expr match {
case lit: Literal => lit
- case DoPrim(op, args, _,_) if isCast(op) => findDriver(drivers)(args.head)
- case other => drivers.get(we(other)) match {
- case Some(e) if wireOrNode(Utils.kind(other)) => findDriver(drivers)(e)
- case _ => other
- }
+ case DoPrim(op, args, _, _) if isCast(op) => findDriver(drivers)(args.head)
+ case other =>
+ drivers.get(we(other)) match {
+ case Some(e) if wireOrNode(Utils.kind(other)) => findDriver(drivers)(e)
+ case _ => other
+ }
}
private def onMod(errors: Errors)(mod: DefModule): Unit = {
diff --git a/src/main/scala/firrtl/constraint/Constraint.scala b/src/main/scala/firrtl/constraint/Constraint.scala
index 247593ee..1a3bc21a 100644
--- a/src/main/scala/firrtl/constraint/Constraint.scala
+++ b/src/main/scala/firrtl/constraint/Constraint.scala
@@ -12,7 +12,7 @@ trait Constraint {
/** Trait for constraints with more than one argument */
trait MultiAry extends Constraint {
- def op(a: IsKnown, b: IsKnown): IsKnown
+ def op(a: IsKnown, b: IsKnown): IsKnown
def merge(b1: Option[IsKnown], b2: Option[IsKnown]): Option[IsKnown] = (b1, b2) match {
case (Some(x), Some(y)) => Some(op(x, y))
case (_, y: Some[_]) => y
diff --git a/src/main/scala/firrtl/constraint/ConstraintSolver.scala b/src/main/scala/firrtl/constraint/ConstraintSolver.scala
index a421ae17..64271ae1 100644
--- a/src/main/scala/firrtl/constraint/ConstraintSolver.scala
+++ b/src/main/scala/firrtl/constraint/ConstraintSolver.scala
@@ -24,7 +24,6 @@ class ConstraintSolver {
type ConstraintMap = mutable.HashMap[String, (Constraint, Boolean)]
private val solvedConstraintMap = new ConstraintMap()
-
/** Clear all previously recorded/solved constraints */
def clear(): Unit = {
constraints.clear()
@@ -78,7 +77,7 @@ class ConstraintSolver {
def get(b: Constraint): Option[IsKnown] = {
val name = b match {
case IsVar(name) => name
- case x => ""
+ case x => ""
}
solvedConstraintMap.get(name) match {
case None => None
@@ -94,7 +93,7 @@ class ConstraintSolver {
def get(b: Width): Option[IsKnown] = {
val name = b match {
case IsVar(name) => name
- case x => ""
+ case x => ""
}
solvedConstraintMap.get(name) match {
case None => None
@@ -103,10 +102,8 @@ class ConstraintSolver {
}
}
-
private def add(c: Inequality) = constraints += c
-
/** Creates an Inequality given a variable name, constraint, and whether its >= or <=
* @param left
* @param right
@@ -114,7 +111,7 @@ class ConstraintSolver {
* @return
*/
private def genConst(left: String, right: Constraint, geq: Boolean): Inequality = geq match {
- case true => GreaterOrEqual(left, right)
+ case true => GreaterOrEqual(left, right)
case false => LesserOrEqual(left, right)
}
@@ -122,14 +119,13 @@ class ConstraintSolver {
def serializeConstraints: String = constraints.mkString("\n")
/** For debugging, can serialize the solved constraints */
- def serializeSolutions: String = solvedConstraintMap.map{
+ def serializeSolutions: String = solvedConstraintMap.map {
case (k, (v, true)) => s"$k >= ${v.serialize}"
case (k, (v, false)) => s"$k <= ${v.serialize}"
}.mkString("\n")
-
-
- /************* Constraint Solver Engine ****************/
+ /** *********** Constraint Solver Engine ***************
+ */
/** Merges constraints on the same variable
*
@@ -148,17 +144,16 @@ class ConstraintSolver {
private def mergeConstraints(constraints: Seq[Inequality]): Seq[Inequality] = {
val mergedMap = mutable.HashMap[String, Inequality]()
constraints.foreach {
- case c if c.geq && mergedMap.contains(c.left) =>
- mergedMap(c.left) = genConst(c.left, IsMax(mergedMap(c.left).right, c.right), true)
- case c if !c.geq && mergedMap.contains(c.left) =>
- mergedMap(c.left) = genConst(c.left, IsMin(mergedMap(c.left).right, c.right), false)
- case c =>
- mergedMap(c.left) = c
+ case c if c.geq && mergedMap.contains(c.left) =>
+ mergedMap(c.left) = genConst(c.left, IsMax(mergedMap(c.left).right, c.right), true)
+ case c if !c.geq && mergedMap.contains(c.left) =>
+ mergedMap(c.left) = genConst(c.left, IsMin(mergedMap(c.left).right, c.right), false)
+ case c =>
+ mergedMap(c.left) = c
}
mergedMap.values.toList
}
-
/** Attempts to substitute variables with their corresponding forward-solved constraints
* If no corresponding constraint has been visited yet, keep variable as is
*
@@ -167,15 +162,16 @@ class ConstraintSolver {
* @return Forward solved constraint
*/
private def forwardSubstitution(forwardSolved: ConstraintMap)(constraint: Constraint): Constraint = {
- val x = constraint map forwardSubstitution(forwardSolved)
+ val x = constraint.map(forwardSubstitution(forwardSolved))
x match {
- case isVar: IsVar => forwardSolved get isVar.name match {
- case None => isVar.asInstanceOf[Constraint]
- case Some((p, geq)) =>
- val newT = forwardSubstitution(forwardSolved)(p)
- forwardSolved(isVar.name) = (newT, geq)
- newT
- }
+ case isVar: IsVar =>
+ forwardSolved.get(isVar.name) match {
+ case None => isVar.asInstanceOf[Constraint]
+ case Some((p, geq)) =>
+ val newT = forwardSubstitution(forwardSolved)(p)
+ forwardSolved(isVar.name) = (newT, geq)
+ newT
+ }
case other => other
}
}
@@ -190,11 +186,12 @@ class ConstraintSolver {
*/
private def backwardSubstitution(backwardSolved: ConstraintMap)(constraint: Constraint): Constraint = {
constraint match {
- case isVar: IsVar => backwardSolved.get(isVar.name) match {
- case Some((p, geq)) => p
- case _ => isVar
- }
- case other => other map backwardSubstitution(backwardSolved)
+ case isVar: IsVar =>
+ backwardSolved.get(isVar.name) match {
+ case Some((p, geq)) => p
+ case _ => isVar
+ }
+ case other => other.map(backwardSubstitution(backwardSolved))
}
}
@@ -211,7 +208,7 @@ class ConstraintSolver {
* @return
*/
private def removeCycle(name: String, geq: Boolean)(constraint: Constraint): Constraint =
- if(geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint)
+ if (geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint)
/** Removes solvable cycles of <= inequalities
* @param name Name of the variable on left side of inequality
@@ -220,7 +217,7 @@ class ConstraintSolver {
*/
private def removeLeqCycle(name: String)(constraint: Constraint): Constraint = constraint match {
case x if greaterEqThan(name)(x) => VarCon(name)
- case isMin: IsMin => IsMin(isMin.children.filter{ c => !greaterEqThan(name)(c)})
+ case isMin: IsMin => IsMin(isMin.children.filter { c => !greaterEqThan(name)(c) })
case x => x
}
@@ -231,43 +228,48 @@ class ConstraintSolver {
*/
private def removeGeqCycle(name: String)(constraint: Constraint): Constraint = constraint match {
case x if lessEqThan(name)(x) => VarCon(name)
- case isMax: IsMax => IsMax(isMax.children.filter{c => !lessEqThan(name)(c)})
+ case isMax: IsMax => IsMax(isMax.children.filter { c => !lessEqThan(name)(c) })
case x => x
}
private def greaterEqThan(name: String)(constraint: Constraint): Boolean = constraint match {
case isMin: IsMin => isMin.children.map(greaterEqThan(name)).reduce(_ && _)
- case isAdd: IsAdd => isAdd.children match {
- case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true
- case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true
- case _ => false
- }
- case isMul: IsMul => isMul.children match {
- case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true
- case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true
- case _ => false
- }
+ case isAdd: IsAdd =>
+ isAdd.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true
+ case _ => false
+ }
+ case isMul: IsMul =>
+ isMul.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true
+ case _ => false
+ }
case isVar: IsVar if isVar.name == name => true
case _ => false
}
private def lessEqThan(name: String)(constraint: Constraint): Boolean = constraint match {
case isMax: IsMax => isMax.children.map(lessEqThan(name)).reduce(_ && _)
- case isAdd: IsAdd => isAdd.children match {
- case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
- case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
- case _ => false
- }
- case isMul: IsMul => isMul.children match {
- case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
- case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
- case _ => false
- }
+ case isAdd: IsAdd =>
+ isAdd.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
+ case _ => false
+ }
+ case isMul: IsMul =>
+ isMul.children match {
+ case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
+ case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
+ case _ => false
+ }
case isVar: IsVar if isVar.name == name => true
- case isNeg: IsNeg => isNeg.child match {
- case isVar: IsVar if isVar.name == name => true
- case _ => false
- }
+ case isNeg: IsNeg =>
+ isNeg.child match {
+ case isVar: IsVar if isVar.name == name => true
+ case _ => false
+ }
case _ => false
}
@@ -283,7 +285,7 @@ class ConstraintSolver {
case isVar: IsVar if isVar.name == name => has = true
case _ =>
}
- constraint map rec
+ constraint.map(rec)
}
rec(constraint)
has
@@ -300,7 +302,7 @@ class ConstraintSolver {
checkMap(c.left) = c
seq ++ Nil
case Some(x) if x.geq != c.geq => seq ++ Seq(x, c)
- case Some(x) => seq ++ Nil
+ case Some(x) => seq ++ Nil
}
}
}
diff --git a/src/main/scala/firrtl/constraint/Inequality.scala b/src/main/scala/firrtl/constraint/Inequality.scala
index 0fa1d2eb..a01b7c85 100644
--- a/src/main/scala/firrtl/constraint/Inequality.scala
+++ b/src/main/scala/firrtl/constraint/Inequality.scala
@@ -6,9 +6,9 @@ package firrtl.constraint
* Is passed to the constraint solver to resolve
*/
trait Inequality {
- def left: String
+ def left: String
def right: Constraint
- def geq: Boolean
+ def geq: Boolean
}
case class GreaterOrEqual(left: String, right: Constraint) extends Inequality {
@@ -20,5 +20,3 @@ case class LesserOrEqual(left: String, right: Constraint) extends Inequality {
val geq = false
override def toString: String = s"$left <= ${right.serialize}"
}
-
-
diff --git a/src/main/scala/firrtl/constraint/IsAdd.scala b/src/main/scala/firrtl/constraint/IsAdd.scala
index e177a8b9..9305db89 100644
--- a/src/main/scala/firrtl/constraint/IsAdd.scala
+++ b/src/main/scala/firrtl/constraint/IsAdd.scala
@@ -1,39 +1,38 @@
// See LICENSE for license details.
-
package firrtl.constraint
// Is case class because writing tests is easier due to equality is not object equality
-case class IsAdd private (known: Option[IsKnown],
- maxs: Vector[IsMax],
- mins: Vector[IsMin],
- others: Vector[Constraint]) extends Constraint with MultiAry {
+case class IsAdd private (known: Option[IsKnown], maxs: Vector[IsMax], mins: Vector[IsMin], others: Vector[Constraint])
+ extends Constraint
+ with MultiAry {
def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 + b2
lazy val children: Vector[Constraint] = {
- if(known.nonEmpty) known.get +: (maxs ++ mins ++ others) else maxs ++ mins ++ others
+ if (known.nonEmpty) known.get +: (maxs ++ mins ++ others) else maxs ++ mins ++ others
}
def addChild(x: Constraint): IsAdd = x match {
- case k: IsKnown => new IsAdd(merge(Some(k), known), maxs, mins, others)
- case add: IsAdd => new IsAdd(merge(known, add.known), maxs ++ add.maxs, mins ++ add.mins, others ++ add.others)
- case max: IsMax => new IsAdd(known, maxs :+ max, mins, others)
- case min: IsMin => new IsAdd(known, maxs, mins :+ min, others)
- case other => new IsAdd(known, maxs, mins, others :+ other)
+ case k: IsKnown => new IsAdd(merge(Some(k), known), maxs, mins, others)
+ case add: IsAdd => new IsAdd(merge(known, add.known), maxs ++ add.maxs, mins ++ add.mins, others ++ add.others)
+ case max: IsMax => new IsAdd(known, maxs :+ max, mins, others)
+ case min: IsMin => new IsAdd(known, maxs, mins :+ min, others)
+ case other => new IsAdd(known, maxs, mins, others :+ other)
}
override def serialize: String = "(" + children.map(_.serialize).mkString(" + ") + ")"
- override def map(f: Constraint=>Constraint): Constraint = IsAdd(children.map(f))
+ override def map(f: Constraint => Constraint): Constraint = IsAdd(children.map(f))
def reduce(): Constraint = {
- if(children.size == 1) children.head else {
+ if (children.size == 1) children.head
+ else {
(known, maxs, mins, others) match {
case (Some(k), _, _, _) if k.value == 0 => new IsAdd(None, maxs, mins, others).reduce()
case (Some(k), Vector(max), Vector(), Vector()) => max.map { o => IsAdd(k, o) }.reduce()
case (Some(k), Vector(), Vector(min), Vector()) => min.map { o => IsAdd(k, o) }.reduce()
- case _ => this
+ case _ => this
}
}
}
@@ -45,8 +44,10 @@ object IsAdd {
case _ => apply(Seq(left, right))
}
def apply(children: Seq[Constraint]): Constraint = {
- children.foldLeft(new IsAdd(None, Vector(), Vector(), Vector())) { (add, c) =>
- add.addChild(c)
- }.reduce()
+ children
+ .foldLeft(new IsAdd(None, Vector(), Vector(), Vector())) { (add, c) =>
+ add.addChild(c)
+ }
+ .reduce()
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/constraint/IsFloor.scala b/src/main/scala/firrtl/constraint/IsFloor.scala
index 5de4697e..60f049bb 100644
--- a/src/main/scala/firrtl/constraint/IsFloor.scala
+++ b/src/main/scala/firrtl/constraint/IsFloor.scala
@@ -10,13 +10,13 @@ case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint
override def reduce(): Constraint = child match {
case k: IsKnown => k.floor
- case x: IsAdd => this
- case x: IsMul => this
- case x: IsNeg => this
- case x: IsPow => this
+ case x: IsAdd => this
+ case x: IsMul => this
+ case x: IsNeg => this
+ case x: IsPow => this
// floor(max(a, b)) -> max(floor(a), floor(b))
- case x: IsMax => IsMax(x.children.map {b => IsFloor(b)})
- case x: IsMin => IsMin(x.children.map {b => IsFloor(b)})
+ case x: IsMax => IsMax(x.children.map { b => IsFloor(b) })
+ case x: IsMin => IsMin(x.children.map { b => IsFloor(b) })
case x: IsVar => this
// floor(floor(x)) -> floor(x)
case x: IsFloor => x
@@ -24,9 +24,7 @@ case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint
}
val children = Vector(child)
- override def map(f: Constraint=>Constraint): Constraint = IsFloor(f(child))
+ override def map(f: Constraint => Constraint): Constraint = IsFloor(f(child))
override def serialize: String = "floor(" + child.serialize + ")"
}
-
-
diff --git a/src/main/scala/firrtl/constraint/IsKnown.scala b/src/main/scala/firrtl/constraint/IsKnown.scala
index 5bd25f92..07e0531c 100644
--- a/src/main/scala/firrtl/constraint/IsKnown.scala
+++ b/src/main/scala/firrtl/constraint/IsKnown.scala
@@ -34,11 +34,9 @@ trait IsKnown extends Constraint {
/** Floor */
def floor: IsKnown
- override def map(f: Constraint=>Constraint): Constraint = this
+ override def map(f: Constraint => Constraint): Constraint = this
val children: Vector[Constraint] = Vector.empty[Constraint]
def reduce(): IsKnown = this
}
-
-
diff --git a/src/main/scala/firrtl/constraint/IsMax.scala b/src/main/scala/firrtl/constraint/IsMax.scala
index 3f24b7c0..0ba20c08 100644
--- a/src/main/scala/firrtl/constraint/IsMax.scala
+++ b/src/main/scala/firrtl/constraint/IsMax.scala
@@ -4,7 +4,7 @@ package firrtl.constraint
object IsMax {
def apply(left: Constraint, right: Constraint): Constraint = (left, right) match {
- case (l: IsKnown, r: IsKnown) => l max r
+ case (l: IsKnown, r: IsKnown) => l.max(r)
case _ => apply(Seq(left, right))
}
def apply(children: Seq[Constraint]): Constraint = {
@@ -15,33 +15,32 @@ object IsMax {
}
}
-case class IsMax private[constraint](known: Option[IsKnown],
- mins: Vector[IsMin],
- others: Vector[Constraint]
- ) extends MultiAry {
+case class IsMax private[constraint] (known: Option[IsKnown], mins: Vector[IsMin], others: Vector[Constraint])
+ extends MultiAry {
- def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 max b2
+ def op(b1: IsKnown, b2: IsKnown): IsKnown = b1.max(b2)
override def serialize: String = "max(" + children.map(_.serialize).mkString(", ") + ")"
- override def map(f: Constraint=>Constraint): Constraint = IsMax(children.map(f))
+ override def map(f: Constraint => Constraint): Constraint = IsMax(children.map(f))
lazy val children: Vector[Constraint] = {
- if(known.nonEmpty) known.get +: (mins ++ others) else mins ++ others
+ if (known.nonEmpty) known.get +: (mins ++ others) else mins ++ others
}
def reduce(): Constraint = {
- if(children.size == 1) children.head else {
+ if (children.size == 1) children.head
+ else {
(known, mins, others) match {
case (Some(IsKnown(a)), _, _) =>
// Eliminate minimums who have a known minimum value which is smaller than known maximum value
val filteredMins = mins.filter {
case IsMin(Some(IsKnown(i)), _, _) if i <= a => false
- case other => true
+ case other => true
}
// If a successful filter, rerun reduce
val newMax = new IsMax(known, filteredMins, others)
- if(filteredMins.size != mins.size) {
+ if (filteredMins.size != mins.size) {
newMax.reduce()
} else newMax
case _ => this
@@ -50,10 +49,9 @@ case class IsMax private[constraint](known: Option[IsKnown],
}
def addChild(x: Constraint): IsMax = x match {
- case k: IsKnown => new IsMax(known = merge(Some(k), known), mins, others)
- case max: IsMax => new IsMax(known = merge(known, max.known), max.mins ++ mins, others ++ max.others)
- case min: IsMin => new IsMax(known, mins :+ min, others)
- case other => new IsMax(known, mins, others :+ other)
+ case k: IsKnown => new IsMax(known = merge(Some(k), known), mins, others)
+ case max: IsMax => new IsMax(known = merge(known, max.known), max.mins ++ mins, others ++ max.others)
+ case min: IsMin => new IsMax(known, mins :+ min, others)
+ case other => new IsMax(known, mins, others :+ other)
}
}
-
diff --git a/src/main/scala/firrtl/constraint/IsMin.scala b/src/main/scala/firrtl/constraint/IsMin.scala
index ee97e298..2c5db14d 100644
--- a/src/main/scala/firrtl/constraint/IsMin.scala
+++ b/src/main/scala/firrtl/constraint/IsMin.scala
@@ -4,43 +4,44 @@ package firrtl.constraint
object IsMin {
def apply(left: Constraint, right: Constraint): Constraint = (left, right) match {
- case (l: IsKnown, r: IsKnown) => l min r
+ case (l: IsKnown, r: IsKnown) => l.min(r)
case _ => apply(Seq(left, right))
}
def apply(children: Seq[Constraint]): Constraint = {
- children.foldLeft(new IsMin(None, Vector(), Vector())) { (add, c) =>
- add.addChild(c)
- }.reduce()
+ children
+ .foldLeft(new IsMin(None, Vector(), Vector())) { (add, c) =>
+ add.addChild(c)
+ }
+ .reduce()
}
}
-case class IsMin private[constraint](known: Option[IsKnown],
- maxs: Vector[IsMax],
- others: Vector[Constraint]
- ) extends MultiAry {
+case class IsMin private[constraint] (known: Option[IsKnown], maxs: Vector[IsMax], others: Vector[Constraint])
+ extends MultiAry {
- def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 min b2
+ def op(b1: IsKnown, b2: IsKnown): IsKnown = b1.min(b2)
override def serialize: String = "min(" + children.map(_.serialize).mkString(", ") + ")"
- override def map(f: Constraint=>Constraint): Constraint = IsMin(children.map(f))
+ override def map(f: Constraint => Constraint): Constraint = IsMin(children.map(f))
lazy val children: Vector[Constraint] = {
- if(known.nonEmpty) known.get +: (maxs ++ others) else maxs ++ others
+ if (known.nonEmpty) known.get +: (maxs ++ others) else maxs ++ others
}
def reduce(): Constraint = {
- if(children.size == 1) children.head else {
+ if (children.size == 1) children.head
+ else {
(known, maxs, others) match {
case (Some(IsKnown(i)), _, _) =>
// Eliminate maximums who have a known maximum value which is larger than known minimum value
val filteredMaxs = maxs.filter {
case IsMax(Some(IsKnown(a)), _, _) if a >= i => false
- case other => true
+ case other => true
}
// If a successful filter, rerun reduce
val newMin = new IsMin(known, filteredMaxs, others)
- if(filteredMaxs.size != maxs.size) {
+ if (filteredMaxs.size != maxs.size) {
newMin.reduce()
} else newMin
case _ => this
@@ -49,9 +50,9 @@ case class IsMin private[constraint](known: Option[IsKnown],
}
def addChild(x: Constraint): IsMin = x match {
- case k: IsKnown => new IsMin(merge(Some(k), known), maxs, others)
- case max: IsMax => new IsMin(known, maxs :+ max, others)
- case min: IsMin => new IsMin(merge(min.known, known), maxs ++ min.maxs, others ++ min.others)
- case other => new IsMin(known, maxs, others :+ other)
+ case k: IsKnown => new IsMin(merge(Some(k), known), maxs, others)
+ case max: IsMax => new IsMin(known, maxs :+ max, others)
+ case min: IsMin => new IsMin(merge(min.known, known), maxs ++ min.maxs, others ++ min.others)
+ case other => new IsMin(known, maxs, others :+ other)
}
}
diff --git a/src/main/scala/firrtl/constraint/IsMul.scala b/src/main/scala/firrtl/constraint/IsMul.scala
index 3f637d75..a4acd74c 100644
--- a/src/main/scala/firrtl/constraint/IsMul.scala
+++ b/src/main/scala/firrtl/constraint/IsMul.scala
@@ -10,9 +10,11 @@ object IsMul {
case _ => apply(Seq(left, right))
}
def apply(children: Seq[Constraint]): Constraint = {
- children.foldLeft(new IsMul(None, Vector())) { (add, c) =>
- add.addChild(c)
- }.reduce()
+ children
+ .foldLeft(new IsMul(None, Vector())) { (add, c) =>
+ add.addChild(c)
+ }
+ .reduce()
}
}
@@ -20,19 +22,20 @@ case class IsMul private (known: Option[IsKnown], others: Vector[Constraint]) ex
def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 * b2
- lazy val children: Vector[Constraint] = if(known.nonEmpty) known.get +: others else others
+ lazy val children: Vector[Constraint] = if (known.nonEmpty) known.get +: others else others
def addChild(x: Constraint): IsMul = x match {
- case k: IsKnown => new IsMul(known = merge(Some(k), known), others)
- case mul: IsMul => new IsMul(merge(known, mul.known), others ++ mul.others)
- case other => new IsMul(known, others :+ other)
+ case k: IsKnown => new IsMul(known = merge(Some(k), known), others)
+ case mul: IsMul => new IsMul(merge(known, mul.known), others ++ mul.others)
+ case other => new IsMul(known, others :+ other)
}
override def reduce(): Constraint = {
- if(children.size == 1) children.head else {
+ if (children.size == 1) children.head
+ else {
(known, others) match {
- case (Some(Closed(x)), _) if x == BigDecimal(1) => new IsMul(None, others).reduce()
- case (Some(Closed(x)), _) if x == BigDecimal(0) => Closed(0)
+ case (Some(Closed(x)), _) if x == BigDecimal(1) => new IsMul(None, others).reduce()
+ case (Some(Closed(x)), _) if x == BigDecimal(0) => Closed(0)
case (Some(Closed(x)), Vector(m: IsMax)) if x > 0 =>
IsMax(m.children.map { c => IsMul(Closed(x), c) })
case (Some(Closed(x)), Vector(m: IsMax)) if x < 0 =>
@@ -46,7 +49,7 @@ case class IsMul private (known: Option[IsKnown], others: Vector[Constraint]) ex
}
}
- override def map(f: Constraint=>Constraint): Constraint = IsMul(children.map(f))
+ override def map(f: Constraint => Constraint): Constraint = IsMul(children.map(f))
override def serialize: String = "(" + children.map(_.serialize).mkString(" * ") + ")"
}
diff --git a/src/main/scala/firrtl/constraint/IsNeg.scala b/src/main/scala/firrtl/constraint/IsNeg.scala
index 46f739c6..574cfd47 100644
--- a/src/main/scala/firrtl/constraint/IsNeg.scala
+++ b/src/main/scala/firrtl/constraint/IsNeg.scala
@@ -11,10 +11,10 @@ object IsNeg {
case class IsNeg private (child: Constraint, dummyArg: Int) extends Constraint {
override def reduce(): Constraint = child match {
case k: IsKnown => k.neg
- case x: IsAdd => IsAdd(x.children.map { b => IsNeg(b) })
- case x: IsMul => IsMul(Seq(IsNeg(x.children.head)) ++ x.children.tail)
- case x: IsNeg => x.child
- case x: IsPow => this
+ case x: IsAdd => IsAdd(x.children.map { b => IsNeg(b) })
+ case x: IsMul => IsMul(Seq(IsNeg(x.children.head)) ++ x.children.tail)
+ case x: IsNeg => x.child
+ case x: IsPow => this
// -[max(a, b)] -> min[-a, -b]
case x: IsMax => IsMin(x.children.map { b => IsNeg(b) })
case x: IsMin => IsMax(x.children.map { b => IsNeg(b) })
@@ -24,9 +24,7 @@ case class IsNeg private (child: Constraint, dummyArg: Int) extends Constraint {
lazy val children = Vector(child)
- override def map(f: Constraint=>Constraint): Constraint = IsNeg(f(child))
+ override def map(f: Constraint => Constraint): Constraint = IsNeg(f(child))
override def serialize: String = "(-" + child.serialize + ")"
}
-
-
diff --git a/src/main/scala/firrtl/constraint/IsPow.scala b/src/main/scala/firrtl/constraint/IsPow.scala
index 54a06bf8..2a1fb14a 100644
--- a/src/main/scala/firrtl/constraint/IsPow.scala
+++ b/src/main/scala/firrtl/constraint/IsPow.scala
@@ -12,22 +12,20 @@ case class IsPow private (child: Constraint, dummyArg: Int) extends Constraint {
override def reduce(): Constraint = child match {
case k: IsKnown => k.pow
// 2^(a + b) -> 2^a * 2^b
- case x: IsAdd => IsMul(x.children.map { b => IsPow(b)})
+ case x: IsAdd => IsMul(x.children.map { b => IsPow(b) })
case x: IsMul => this
case x: IsNeg => this
case x: IsPow => this
// 2^(max(a, b)) -> max(2^a, 2^b) since two is always positive, so a, b control magnitude
- case x: IsMax => IsMax(x.children.map {b => IsPow(b)})
- case x: IsMin => IsMin(x.children.map {b => IsPow(b)})
+ case x: IsMax => IsMax(x.children.map { b => IsPow(b) })
+ case x: IsMin => IsMin(x.children.map { b => IsPow(b) })
case x: IsVar => this
case _ => this
}
val children = Vector(child)
- override def map(f: Constraint=>Constraint): Constraint = IsPow(f(child))
+ override def map(f: Constraint => Constraint): Constraint = IsPow(f(child))
override def serialize: String = "(2^" + child.serialize + ")"
}
-
-
diff --git a/src/main/scala/firrtl/constraint/IsVar.scala b/src/main/scala/firrtl/constraint/IsVar.scala
index 98396fa0..18fb53b2 100644
--- a/src/main/scala/firrtl/constraint/IsVar.scala
+++ b/src/main/scala/firrtl/constraint/IsVar.scala
@@ -16,7 +16,7 @@ trait IsVar extends Constraint {
override def serialize: String = name
- override def map(f: Constraint=>Constraint): Constraint = this
+ override def map(f: Constraint => Constraint): Constraint = this
override def reduce() = this
@@ -24,4 +24,3 @@ trait IsVar extends Constraint {
}
case class VarCon(name: String) extends IsVar
-
diff --git a/src/main/scala/firrtl/features/LetterCaseTransform.scala b/src/main/scala/firrtl/features/LetterCaseTransform.scala
index a6cd270a..8610d7b1 100644
--- a/src/main/scala/firrtl/features/LetterCaseTransform.scala
+++ b/src/main/scala/firrtl/features/LetterCaseTransform.scala
@@ -8,14 +8,15 @@ import firrtl.transforms.ManipulateNames
import scala.reflect.ClassTag
/** Parent of transforms that do change the letter case of names in a FIRRTL circuit */
-abstract class LetterCaseTransform[A <: ManipulateNames[_] : ClassTag] extends ManipulateNames[A] {
+abstract class LetterCaseTransform[A <: ManipulateNames[_]: ClassTag] extends ManipulateNames[A] {
protected def newName: String => String
- final def manipulate = (a: String, ns: Namespace) => newName(a) match {
- case `a` => None
- case b => Some(ns.newName(b))
- }
+ final def manipulate = (a: String, ns: Namespace) =>
+ newName(a) match {
+ case `a` => None
+ case b => Some(ns.newName(b))
+ }
}
/** Convert all FIRRTL names to lowercase */
diff --git a/src/main/scala/firrtl/graph/DiGraph.scala b/src/main/scala/firrtl/graph/DiGraph.scala
index 32bcac5f..7720028c 100644
--- a/src/main/scala/firrtl/graph/DiGraph.scala
+++ b/src/main/scala/firrtl/graph/DiGraph.scala
@@ -2,7 +2,7 @@
package firrtl.graph
-import scala.collection.{Map, Set, mutable}
+import scala.collection.{mutable, Map, Set}
import scala.collection.mutable.{LinkedHashMap, LinkedHashSet}
/** An exception that is raised when an assumed DAG has a cycle */
@@ -13,6 +13,7 @@ class PathNotFoundException extends Exception("Unreachable node")
/** A companion to create DiGraphs from mutable data */
object DiGraph {
+
/** Create a DiGraph from a MutableDigraph, representing the same graph */
def apply[T](mdg: MutableDiGraph[T]): DiGraph[T] = mdg
@@ -33,7 +34,8 @@ object DiGraph {
}
/** Represents common behavior of all directed graphs */
-class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) {
+class DiGraph[T](private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]]) {
+
/** Check whether the graph contains vertex v */
def contains(v: T): Boolean = edges.contains(v)
@@ -74,8 +76,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
try {
foundPath = path(vertex, node, blacklist = Set.empty)
true
- }
- catch {
+ } catch {
case _: PathNotFoundException =>
foundPath = Seq.empty[T]
false
@@ -138,7 +139,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
* @return a Map[T,T] from each visited node to its predecessor in the
* traversal
*/
- def BFS(root: T): Map[T,T] = BFS(root, Set.empty[T])
+ def BFS(root: T): Map[T, T] = BFS(root, Set.empty[T])
/** Performs breadth-first search on the directed graph, with a blacklist of nodes
*
@@ -147,8 +148,8 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
* @return a Map[T,T] from each visited node to its predecessor in the
* traversal
*/
- def BFS(root: T, blacklist: Set[T]): Map[T,T] = {
- val prev = new mutable.LinkedHashMap[T,T]
+ def BFS(root: T, blacklist: Set[T]): Map[T, T] = {
+ val prev = new mutable.LinkedHashMap[T, T]
val queue = new mutable.Queue[T]
queue.enqueue(root)
while (queue.nonEmpty) {
@@ -181,7 +182,9 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
* @param blacklist list of nodes to stop searching, if encountered
* @return a Set[T] of nodes reachable from `root`
*/
- def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({ case (k, v) => k })
+ def reachableFrom(root: T, blacklist: Set[T]): LinkedHashSet[T] = new LinkedHashSet[T] ++ BFS(root, blacklist).map({
+ case (k, v) => k
+ })
/** Finds a path (if one exists) from one node to another
*
@@ -238,7 +241,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
val callStack = new mutable.Stack[StrongConnectFrame[T]]
for (node <- getVertices) {
- callStack.push(new StrongConnectFrame(node,getEdges(node).iterator))
+ callStack.push(new StrongConnectFrame(node, getEdges(node).iterator))
while (!callStack.isEmpty) {
val frame = callStack.top
val v = frame.v
@@ -257,7 +260,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
val w = frame.edgeIter.next
if (!indices.contains(w)) {
frame.childCall = Some(w)
- callStack.push(new StrongConnectFrame(w,getEdges(w).iterator))
+ callStack.push(new StrongConnectFrame(w, getEdges(w).iterator))
} else if (onstack.contains(w)) {
lowlinks(v) = lowlinks(v).min(indices(w))
}
@@ -269,8 +272,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
val w = stack.pop
onstack -= w
scc += w
- }
- while (scc.last != v);
+ } while (scc.last != v);
sccs.append(scc.toSeq)
}
callStack.pop
@@ -291,7 +293,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
* @param start the node to start at
* @return a Map[T,Seq[Seq[T]]] where the value associated with v is the Seq of all paths from start to v
*/
- def pathsInDAG(start: T): LinkedHashMap[T,Seq[Seq[T]]] = {
+ def pathsInDAG(start: T): LinkedHashMap[T, Seq[Seq[T]]] = {
// paths(v) holds the set of paths from start to v
val paths = new LinkedHashMap[T, mutable.Set[Seq[T]]]
val queue = new mutable.Queue[T]
@@ -299,7 +301,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
def addBinding(n: T, p: Seq[T]): Unit = {
paths.getOrElseUpdate(n, new LinkedHashSet[Seq[T]]) += p
}
- addBinding(start,Seq(start))
+ addBinding(start, Seq(start))
queue += start
queue ++= linearize.filter(reachable.contains(_))
while (!queue.isEmpty) {
@@ -310,22 +312,25 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
}
}
}
- paths.map({ case (k,v) => (k,v.toSeq) })
+ paths.map({ case (k, v) => (k, v.toSeq) })
}
/** Returns a graph with all edges reversed */
def reverse: DiGraph[T] = {
val mdg = new MutableDiGraph[T]
edges.foreach({ case (u, edges) => mdg.addVertex(u) })
- edges.foreach({ case (u, edges) =>
- edges.foreach(v => mdg.addEdge(v,u))
+ edges.foreach({
+ case (u, edges) =>
+ edges.foreach(v => mdg.addEdge(v, u))
})
DiGraph(mdg)
}
private def filterEdges(vprime: Set[T]): LinkedHashMap[T, LinkedHashSet[T]] = {
- def filterNodeSet(s: LinkedHashSet[T]): LinkedHashSet[T] = s.filter({ case (k) => vprime.contains(k) })
- def filterAdjacencyLists(m: LinkedHashMap[T, LinkedHashSet[T]]): LinkedHashMap[T, LinkedHashSet[T]] = m.map({ case (k, v) => (k, filterNodeSet(v)) })
+ def filterNodeSet(s: LinkedHashSet[T]): LinkedHashSet[T] = s.filter({ case (k) => vprime.contains(k) })
+ def filterAdjacencyLists(m: LinkedHashMap[T, LinkedHashSet[T]]): LinkedHashMap[T, LinkedHashSet[T]] = m.map({
+ case (k, v) => (k, filterNodeSet(v))
+ })
val eprime: LinkedHashMap[T, LinkedHashSet[T]] = edges.filter({ case (k, v) => vprime.contains(k) })
filterAdjacencyLists(eprime)
}
@@ -354,7 +359,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
*/
def simplify(vprime: Set[T]): DiGraph[T] = {
require(vprime.subsetOf(edges.keySet))
- val pathEdges = vprime.map(v => (v, reachableFrom(v) & (vprime-v)) )
+ val pathEdges = vprime.map(v => (v, reachableFrom(v) & (vprime - v)))
new DiGraph(new LinkedHashMap[T, LinkedHashSet[T]] ++ pathEdges)
}
@@ -384,6 +389,7 @@ class DiGraph[T] (private[graph] val edges: LinkedHashMap[T, LinkedHashSet[T]])
}
class MutableDiGraph[T] extends DiGraph[T](new LinkedHashMap[T, LinkedHashSet[T]]) {
+
/** Add vertex v to the graph
* @return v, the added vertex
*/
diff --git a/src/main/scala/firrtl/graph/EdgeData.scala b/src/main/scala/firrtl/graph/EdgeData.scala
index 16990de0..6a63c3b9 100644
--- a/src/main/scala/firrtl/graph/EdgeData.scala
+++ b/src/main/scala/firrtl/graph/EdgeData.scala
@@ -6,11 +6,10 @@ import scala.collection.mutable
/**
* An exception that indicates that an edge cannot be found in a graph with edge data.
- *
+ *
* @note the vertex type is not captured as a type parameter, as it would be erased.
*/
-class EdgeNotFoundException(u: Any, v: Any)
- extends IllegalArgumentException(s"Edge (${u}, ${v}) does not exist!")
+class EdgeNotFoundException(u: Any, v: Any) extends IllegalArgumentException(s"Edge (${u}, ${v}) does not exist!")
/**
* Mixing this trait into a DiGraph indicates that each edge may be associated with an optional
diff --git a/src/main/scala/firrtl/graph/EulerTour.scala b/src/main/scala/firrtl/graph/EulerTour.scala
index 2d8a17e2..5e075ae2 100644
--- a/src/main/scala/firrtl/graph/EulerTour.scala
+++ b/src/main/scala/firrtl/graph/EulerTour.scala
@@ -6,6 +6,7 @@ import scala.collection.mutable
/** Euler Tour companion object */
object EulerTour {
+
/** Create an Euler Tour of a `DiGraph[T]` */
def apply[T](diGraph: DiGraph[T], start: T): EulerTour[Seq[T]] = {
val r = mutable.Map[Seq[T], Int]()
@@ -66,8 +67,8 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
* the index of that minimum in each block, b.
*/
private lazy val blocks = (h ++ (1 to (m - n % m))).grouped(m).toArray
- private lazy val a = blocks map (_.min)
- private lazy val b = blocks map (b => b.indexOf(b.min))
+ private lazy val a = blocks.map(_.min)
+ private lazy val b = blocks.map(b => b.indexOf(b.min))
/** Construct a Sparse Table (ST) representation for the minimum index
* of a sequence of integers. Data in the returned array is indexed
@@ -75,7 +76,10 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
*/
private def constructSparseTable(x: Seq[Int]): Array[Array[Int]] = {
val tmp = Array.ofDim[Int](x.size + 1, math.ceil(lg(x.size)).toInt)
- for (i <- 0 to x.size - 1; j <- 0 to math.ceil(lg(x.size)).toInt - 1) {
+ for {
+ i <- 0 to x.size - 1
+ j <- 0 to math.ceil(lg(x.size)).toInt - 1
+ } {
tmp(i)(j) = -1
}
@@ -86,11 +90,11 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
} else {
val (a, b, c) = (base, base + (1 << (size - 1)), size - 1)
- val l = if (tmp(a)(c) != -1) { tmp(a)(c) }
- else { tableRecursive(a, c) }
+ val l = if (tmp(a)(c) != -1) { tmp(a)(c) }
+ else { tableRecursive(a, c) }
- val r = if (tmp(b)(c) != -1) { tmp(b)(c) }
- else { tableRecursive(b, c) }
+ val r = if (tmp(b)(c) != -1) { tmp(b)(c) }
+ else { tableRecursive(b, c) }
val min = if (x(l) < x(r)) l else r
tmp(base)(size) = min
@@ -99,9 +103,11 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
}
}
- for (i <- (0 to x.size - 1);
- j <- (0 to math.ceil(lg(x.size)).toInt - 1);
- if i + (1 << j) - 1 < x.size) {
+ for {
+ i <- (0 to x.size - 1)
+ j <- (0 to math.ceil(lg(x.size)).toInt - 1)
+ if i + (1 << j) - 1 < x.size
+ } {
tableRecursive(i, j)
}
tmp
@@ -117,16 +123,26 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
}
val size = m - 1
- val out = Seq.fill(size)(Seq(-1, 1))
- .flatten.combinations(m - 1).flatMap(_.permutations).toList
+ val out = Seq
+ .fill(size)(Seq(-1, 1))
+ .flatten
+ .combinations(m - 1)
+ .flatMap(_.permutations)
+ .toList
.sortWith(sortSeqSeq)
.map(_.foldLeft(Seq(0))((h, pm) => (h.head + pm) +: h).reverse)
- .map{ a =>
+ .map { a =>
val tmp = Array.ofDim[Int](m, m)
- for (i <- 0 to size; j <- i to size) yield {
+ for {
+ i <- 0 to size
+ j <- i to size
+ } yield {
val window = a.slice(i, j + 1)
- tmp(i)(j) = window.indexOf(window.min) + i }
- tmp }.toArray
+ tmp(i)(j) = window.indexOf(window.min) + i
+ }
+ tmp
+ }
+ .toArray
out
}
private lazy val tables = constructTableLookups(m)
@@ -167,7 +183,7 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
// Compute block and word indices
val (block_i, block_j) = (i / m, j / m)
- val (word_i, word_j) = (i % m, j % m)
+ val (word_i, word_j) = (i % m, j % m)
/** Up to four possible minimum indices are then computed based on the
* following conditions:
@@ -187,12 +203,12 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(word_j)
Seq(min_i)
case (bi, bj) if (block_i == block_j - 1) =>
- val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1)
- val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j)
+ val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(m - 1)
+ val min_j = block_j * m + tables(tableIdx(block_j))(0)(word_j)
Seq(min_i, min_j)
case _ =>
- val min_i = block_i * m + tables(tableIdx(block_i))(word_i)( m - 1)
- val min_j = block_j * m + tables(tableIdx(block_j))( 0)(word_j)
+ val min_i = block_i * m + tables(tableIdx(block_i))(word_i)(m - 1)
+ val min_j = block_j * m + tables(tableIdx(block_j))(0)(word_j)
val (min_between_l, min_between_r) = {
val range = math.floor(lg(block_j - block_i - 1)).toInt
val base_0 = block_i + 1
@@ -200,7 +216,8 @@ class EulerTour[T](r: Map[T, Int], e: Seq[T], h: Seq[Int]) {
val (idx_0, idx_1) = (st(base_0)(range), st(base_1)(range))
val (min_0, min_1) = (b(idx_0) + idx_0 * m, b(idx_1) + idx_1 * m)
- (min_0, min_1) }
+ (min_0, min_1)
+ }
Seq(min_i, min_between_l, min_between_r, min_j)
}
diff --git a/src/main/scala/firrtl/graph/RenderDiGraph.scala b/src/main/scala/firrtl/graph/RenderDiGraph.scala
index b3c1373c..45be3a8f 100644
--- a/src/main/scala/firrtl/graph/RenderDiGraph.scala
+++ b/src/main/scala/firrtl/graph/RenderDiGraph.scala
@@ -16,7 +16,6 @@ import scala.collection.mutable
*/
class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankDir: String = "LR") {
-
/**
* override this to change the default way a node is displayed. Default is toString surrounded by double quotes
* This example changes the double quotes to brackets
@@ -38,8 +37,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
try {
diGraph.linearize
- }
- catch {
+ } catch {
case cyclicException: CyclicException =>
val node = cyclicException.node.asInstanceOf[T]
path = diGraph.findLoopAtNode(node)
@@ -61,31 +59,29 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
val loop = findOneLoop
- if(loop.nonEmpty) {
+ if (loop.nonEmpty) {
// Find all the children of the nodes in the loop
val childrenFound = diGraph.getEdgeMap.flatMap {
case (node, children) if loop.contains(node) => children
- case _ => Seq.empty
+ case _ => Seq.empty
}.toSet
// Create a new DiGraph containing only loop and direct children or parents
val edgeData = diGraph.getEdgeMap
- val newEdgeData = edgeData.flatMap { case (node, children) =>
- if(loop.contains(node)) {
- Some(node -> children)
- }
- else if(childrenFound.contains(node)) {
- Some(node -> children.intersect(loop))
- }
- else {
- val newChildren = children.intersect(loop)
- if(newChildren.nonEmpty) {
- Some(node -> newChildren)
- }
- else {
- None
- }
+ val newEdgeData = edgeData.flatMap {
+ case (node, children) =>
+ if (loop.contains(node)) {
+ Some(node -> children)
+ } else if (childrenFound.contains(node)) {
+ Some(node -> children.intersect(loop))
+ } else {
+ val newChildren = children.intersect(loop)
+ if (newChildren.nonEmpty) {
+ Some(node -> newChildren)
+ } else {
+ None
+ }
}
}
@@ -96,8 +92,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
}
}
newRenderer.toDotWithLoops(loop, getRankedNodes)
- }
- else {
+ } else {
""
}
}
@@ -114,10 +109,11 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
val edges = diGraph.getEdgeMap
- edges.foreach { case (parent, children) =>
- children.foreach { child =>
- s.append(s""" ${renderNode(parent)} -> ${renderNode(child)};""" + "\n")
- }
+ edges.foreach {
+ case (parent, children) =>
+ children.foreach { child =>
+ s.append(s""" ${renderNode(parent)} -> ${renderNode(child)};""" + "\n")
+ }
}
s.append("}\n")
s.toString
@@ -137,24 +133,25 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
val edges = diGraph.getEdgeMap
- edges.foreach { case (parent, children) =>
- allNodes += parent
- allNodes ++= children
+ edges.foreach {
+ case (parent, children) =>
+ allNodes += parent
+ allNodes ++= children
- children.foreach { child =>
- val highlight = if(loopedNodes.contains(parent) && loopedNodes.contains(child)) {
- "[color=red,penwidth=3.0]"
- }
- else {
- ""
+ children.foreach { child =>
+ val highlight = if (loopedNodes.contains(parent) && loopedNodes.contains(child)) {
+ "[color=red,penwidth=3.0]"
+ } else {
+ ""
+ }
+ s.append(s""" ${renderNode(parent)} -> ${renderNode(child)}$highlight;""" + "\n")
}
- s.append(s""" ${renderNode(parent)} -> ${renderNode(child)}$highlight;""" + "\n")
- }
}
val paredRankedNodes = rankedNodes.flatMap { nodes =>
val newNodes = nodes.filter(allNodes.contains)
- if(newNodes.nonEmpty) { Some(newNodes) } else { None }
+ if (newNodes.nonEmpty) { Some(newNodes) }
+ else { None }
}
paredRankedNodes.foreach { nodesAtRank =>
@@ -183,7 +180,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
diGraph.getEdges(node)
}.filterNot(alreadyVisited.contains).distinct
- if(nextNodes.nonEmpty) {
+ if (nextNodes.nonEmpty) {
walkByRank(nextNodes, rankNumber + 1)
}
}
@@ -191,6 +188,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
walkByRank(diGraph.findSources.toSeq)
rankNodes
}
+
/**
* Convert this graph into input for the graphviz dot program.
* It tries to align nodes in columns based
@@ -216,7 +214,7 @@ class RenderDiGraph[T <: Any](diGraph: DiGraph[T], graphName: String = "", rankD
children
}.filterNot(alreadyVisited.contains).distinct
- if(nextNodes.nonEmpty) {
+ if (nextNodes.nonEmpty) {
walkByRank(nextNodes, rankNumber + 1)
}
}
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index 5263d9c0..2536a77e 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -39,41 +39,50 @@ case class FileInfo(escaped: String) extends Info {
object FileInfo {
@deprecated("Use FileInfo.fromUnEscaped instead. FileInfo.apply will be removed in FIRRTL 1.5.", "FIRRTL 1.4")
- def apply(info: StringLit): FileInfo = new FileInfo(escape(info.string))
- def fromEscaped(s: String): FileInfo = new FileInfo(s)
- def fromUnescaped(s: String): FileInfo = new FileInfo(escape(s))
+ def apply(info: StringLit): FileInfo = new FileInfo(escape(info.string))
+ def fromEscaped(s: String): FileInfo = new FileInfo(s)
+ def fromUnescaped(s: String): FileInfo = new FileInfo(escape(s))
+
/** prepends a `\` to: `\`, `\n`, `\t` and `]` */
def escape(s: String): String = EscapeFirrtl.translate(s)
+
/** removes the `\` in front of `\`, `\n`, `\t` and `]` */
def unescape(s: String): String = UnescapeFirrtl.translate(s)
+
/** take an already escaped String and do the additional escaping needed for Verilog comment */
def escapedToVerilog(s: String) = EscapedToVerilog.translate(s)
// custom `CharSequenceTranslator` for FIRRTL Info String escaping
type CharMap = (CharSequence, CharSequence)
- private val EscapeFirrtl = new LookupTranslator(Seq[CharMap](
- "\\" -> "\\\\",
- "\n" -> "\\n",
- "\t" -> "\\t",
- "]" -> "\\]"
- ).toMap.asJava)
- private val UnescapeFirrtl = new LookupTranslator(Seq[CharMap](
- "\\\\" -> "\\",
- "\\n" -> "\n",
- "\\t" -> "\t",
- "\\]" -> "]"
- ).toMap.asJava)
+ private val EscapeFirrtl = new LookupTranslator(
+ Seq[CharMap](
+ "\\" -> "\\\\",
+ "\n" -> "\\n",
+ "\t" -> "\\t",
+ "]" -> "\\]"
+ ).toMap.asJava
+ )
+ private val UnescapeFirrtl = new LookupTranslator(
+ Seq[CharMap](
+ "\\\\" -> "\\",
+ "\\n" -> "\n",
+ "\\t" -> "\t",
+ "\\]" -> "]"
+ ).toMap.asJava
+ )
// EscapeFirrtl + EscapedToVerilog essentially does the same thing as running StringEscapeUtils.unescapeJava
private val EscapedToVerilog = new AggregateTranslator(
- new LookupTranslator(Seq[CharMap](
- // ] is the one character that firrtl needs to be escaped that does not need to be escaped in
- "\\]" -> "]",
- "\"" -> "\\\"",
- // \n and \t are already escaped
- "\b" -> "\\b",
- "\f" -> "\\f",
- "\r" -> "\\r"
- ).toMap.asJava),
+ new LookupTranslator(
+ Seq[CharMap](
+ // ] is the one character that firrtl needs to be escaped that does not need to be escaped in
+ "\\]" -> "]",
+ "\"" -> "\\\"",
+ // \n and \t are already escaped
+ "\b" -> "\\b",
+ "\f" -> "\\f",
+ "\r" -> "\\r"
+ ).toMap.asJava
+ ),
JavaUnicodeEscaper.outsideOf(32, 0x7f)
)
@@ -81,9 +90,9 @@ object FileInfo {
case class MultiInfo(infos: Seq[Info]) extends Info {
private def collectStrings(info: Info): Seq[String] = info match {
- case f : FileInfo => Seq(f.escaped)
- case MultiInfo(seq) => seq flatMap collectStrings
- case NoInfo => Seq.empty
+ case f: FileInfo => Seq(f.escaped)
+ case MultiInfo(seq) => seq.flatMap(collectStrings)
+ case NoInfo => Seq.empty
}
override def toString: String = {
val parts = collectStrings(this)
@@ -107,12 +116,12 @@ object MultiInfo {
// TODO should this be made into an API?
private[firrtl] def demux(info: Info): (Info, Info, Info) = info match {
case MultiInfo(infos) if infos.lengthCompare(3) == 0 => (infos(0), infos(1), infos(2))
- case other => (other, NoInfo, NoInfo) // if not exactly 3, we don't know what to do
+ case other => (other, NoInfo, NoInfo) // if not exactly 3, we don't know what to do
}
-
+
private def flattenInfo(infos: Seq[Info]): Seq[FileInfo] = infos.flatMap {
case NoInfo => Seq()
- case f : FileInfo => Seq(f)
+ case f: FileInfo => Seq(f)
case MultiInfo(infos) => flattenInfo(infos)
}
}
@@ -127,6 +136,7 @@ trait IsDeclaration extends HasName with HasInfo
case class StringLit(string: String) extends FirrtlNode {
import org.apache.commons.text.StringEscapeUtils
+
/** Returns an escaped and quoted String */
def escape: String = {
"\"" + serialize + "\""
@@ -137,26 +147,28 @@ case class StringLit(string: String) extends FirrtlNode {
def verilogFormat: StringLit = {
StringLit(string.replaceAll("%x", "%h"))
}
+
/** Returns an escaped and quoted String */
def verilogEscape: String = {
// normalize to turn things like ö into o
import java.text.Normalizer
val normalized = Normalizer.normalize(string, Normalizer.Form.NFD)
- val ascii = normalized flatMap StringLit.toASCII
+ val ascii = normalized.flatMap(StringLit.toASCII)
ascii.mkString("\"", "", "\"")
}
}
object StringLit {
import org.apache.commons.text.StringEscapeUtils
+
/** Maps characters to ASCII for Verilog emission */
private def toASCII(char: Char): List[Char] = char match {
case nonASCII if !nonASCII.isValidByte => List('?')
- case '"' => List('\\', '"')
- case '\\' => List('\\', '\\')
- case c if c >= ' ' && c <= '~' => List(c)
- case '\n' => List('\\', 'n')
- case '\t' => List('\\', 't')
- case _ => List('?')
+ case '"' => List('\\', '"')
+ case '\\' => List('\\', '\\')
+ case c if c >= ' ' && c <= '~' => List(c)
+ case '\n' => List('\\', 'n')
+ case '\t' => List('\\', 't')
+ case _ => List('?')
}
/** Create a StringLit from a raw parsed String */
@@ -175,8 +187,8 @@ abstract class PrimOp extends FirrtlNode {
def apply(args: Any*): DoPrim = {
val groups = args.groupBy {
case x: Expression => "exp"
- case x: BigInt => "int"
- case x: Int => "int"
+ case x: BigInt => "int"
+ case x: Int => "int"
case other => "other"
}
val exprs = groups.getOrElse("exp", Nil).collect {
@@ -185,11 +197,11 @@ abstract class PrimOp extends FirrtlNode {
val consts = groups.getOrElse("int", Nil).map {
_ match {
case i: BigInt => i
- case i: Int => BigInt(i)
+ case i: Int => BigInt(i)
}
}
groups.get("other") match {
- case None =>
+ case None =>
case Some(x) => sys.error(s"Shouldn't be here: $x")
}
DoPrim(this, exprs, consts, UnknownType)
@@ -198,12 +210,12 @@ abstract class PrimOp extends FirrtlNode {
abstract class Expression extends FirrtlNode {
def tpe: Type
- def mapExpr(f: Expression => Expression): Expression
- def mapType(f: Type => Type): Expression
- def mapWidth(f: Width => Width): Expression
- def foreachExpr(f: Expression => Unit): Unit
- def foreachType(f: Type => Unit): Unit
- def foreachWidth(f: Width => Unit): Unit
+ def mapExpr(f: Expression => Expression): Expression
+ def mapType(f: Type => Type): Expression
+ def mapWidth(f: Width => Width): Expression
+ def foreachExpr(f: Expression => Unit): Unit
+ def foreachType(f: Type => Unit): Unit
+ def foreachWidth(f: Width => Unit): Unit
}
/** Represents reference-like expression nodes: SubField, SubIndex, SubAccess and Reference
@@ -215,75 +227,92 @@ abstract class Expression extends FirrtlNode {
sealed trait RefLikeExpression extends Expression { def flow: Flow }
object Reference {
+
/** Creates a Reference from a Wire */
def apply(wire: DefWire): Reference = Reference(wire.name, wire.tpe, WireKind, UnknownFlow)
+
/** Creates a Reference from a Register */
def apply(reg: DefRegister): Reference = Reference(reg.name, reg.tpe, RegKind, UnknownFlow)
+
/** Creates a Reference from a Node */
def apply(node: DefNode): Reference = Reference(node.name, node.value.tpe, NodeKind, SourceFlow)
+
/** Creates a Reference from a Port */
def apply(port: Port): Reference = Reference(port.name, port.tpe, PortKind, UnknownFlow)
+
/** Creates a Reference from a DefInstance */
def apply(i: DefInstance): Reference = Reference(i.name, i.tpe, InstanceKind, UnknownFlow)
+
/** Creates a Reference from a DefMemory */
def apply(mem: DefMemory): Reference = Reference(mem.name, passes.MemPortUtils.memType(mem), MemKind, UnknownFlow)
}
case class Reference(name: String, tpe: Type = UnknownType, kind: Kind = UnknownKind, flow: Flow = UnknownFlow)
- extends Expression with HasName with UseSerializer with RefLikeExpression {
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachWidth(f: Width => Unit): Unit = ()
+ extends Expression
+ with HasName
+ with UseSerializer
+ with RefLikeExpression {
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = ()
}
case class SubField(expr: Expression, name: String, tpe: Type = UnknownType, flow: Flow = UnknownFlow)
- extends Expression with HasName with UseSerializer with RefLikeExpression {
- def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
- def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = f(expr)
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachWidth(f: Width => Unit): Unit = ()
+ extends Expression
+ with HasName
+ with UseSerializer
+ with RefLikeExpression {
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = ()
}
case class SubIndex(expr: Expression, value: Int, tpe: Type, flow: Flow = UnknownFlow)
- extends Expression with UseSerializer with RefLikeExpression {
- def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
- def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = f(expr)
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachWidth(f: Width => Unit): Unit = ()
+ extends Expression
+ with UseSerializer
+ with RefLikeExpression {
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = ()
}
case class SubAccess(expr: Expression, index: Expression, tpe: Type, flow: Flow = UnknownFlow)
- extends Expression with UseSerializer with RefLikeExpression {
- def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr), index = f(index))
- def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
- def mapWidth(f: Width => Width): Expression = this
+ extends Expression
+ with UseSerializer
+ with RefLikeExpression {
+ def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr), index = f(index))
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
def foreachExpr(f: Expression => Unit): Unit = { f(expr); f(index) }
- def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachWidth(f: Width => Unit): Unit = ()
}
case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type = UnknownType)
- extends Expression with UseSerializer {
- def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe)
- def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
- def mapWidth(f: Width => Width): Expression = this
+ extends Expression
+ with UseSerializer {
+ def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe)
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(tval); f(fval) }
- def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachWidth(f: Width => Unit): Unit = ()
}
case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression with UseSerializer {
- def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe)
- def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
- def mapWidth(f: Width => Width): Expression = this
+ def mapExpr(f: Expression => Expression): Expression = ValidIf(f(cond), f(value), tpe)
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
def foreachExpr(f: Expression => Unit): Unit = { f(cond); f(value) }
- def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachWidth(f: Width => Unit): Unit = ()
}
abstract class Literal extends Expression {
@@ -292,16 +321,16 @@ abstract class Literal extends Expression {
}
case class UIntLiteral(value: BigInt, width: Width) extends Literal with UseSerializer {
def tpe = UIntType(width)
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width))
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = f(width)
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width))
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
object UIntLiteral {
def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1))
- def apply(value: BigInt): UIntLiteral = new UIntLiteral(value, minWidth(value))
+ def apply(value: BigInt): UIntLiteral = new UIntLiteral(value, minWidth(value))
/** Utility to construct UIntLiterals masked by the width
*
@@ -314,78 +343,82 @@ object UIntLiteral {
}
case class SIntLiteral(value: BigInt, width: Width) extends Literal with UseSerializer {
def tpe = SIntType(width)
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width))
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = f(width)
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width))
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
object SIntLiteral {
def minWidth(value: BigInt): Width = IntWidth(value.bitLength + 1)
- def apply(value: BigInt): SIntLiteral = new SIntLiteral(value, minWidth(value))
+ def apply(value: BigInt): SIntLiteral = new SIntLiteral(value, minWidth(value))
}
case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Literal with UseSerializer {
def tpe = FixedType(width, point)
- def mapExpr(f: Expression => Expression): Expression = this
- def mapType(f: Type => Type): Expression = this
- def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point))
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Expression = this
+ def mapType(f: Type => Type): Expression = this
+ def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point))
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) }
}
case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type)
- extends Expression with UseSerializer {
- def mapExpr(f: Expression => Expression): Expression = this.copy(args = args map f)
- def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
- def mapWidth(f: Width => Width): Expression = this
- def foreachExpr(f: Expression => Unit): Unit = args.foreach(f)
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachWidth(f: Width => Unit): Unit = ()
+ extends Expression
+ with UseSerializer {
+ def mapExpr(f: Expression => Expression): Expression = this.copy(args = args.map(f))
+ def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
+ def mapWidth(f: Width => Width): Expression = this
+ def foreachExpr(f: Expression => Unit): Unit = args.foreach(f)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachWidth(f: Width => Unit): Unit = ()
}
abstract class Statement extends FirrtlNode {
- def mapStmt(f: Statement => Statement): Statement
- def mapExpr(f: Expression => Expression): Statement
- def mapType(f: Type => Type): Statement
- def mapString(f: String => String): Statement
- def mapInfo(f: Info => Info): Statement
- def foreachStmt(f: Statement => Unit): Unit
- def foreachExpr(f: Expression => Unit): Unit
- def foreachType(f: Type => Unit): Unit
- def foreachString(f: String => Unit): Unit
- def foreachInfo(f: Info => Unit): Unit
+ def mapStmt(f: Statement => Statement): Statement
+ def mapExpr(f: Expression => Expression): Statement
+ def mapType(f: Type => Type): Statement
+ def mapString(f: String => String): Statement
+ def mapInfo(f: Info => Info): Statement
+ def foreachStmt(f: Statement => Unit): Unit
+ def foreachExpr(f: Expression => Unit): Unit
+ def foreachType(f: Type => Unit): Unit
+ def foreachString(f: String => Unit): Unit
+ def foreachInfo(f: Info => Unit): Unit
}
case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = this
- def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe))
- def mapString(f: String => String): Statement = DefWire(info, f(name), tpe)
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = DefWire(info, name, f(tpe))
+ def mapString(f: String => String): Statement = DefWire(info, f(name), tpe)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class DefRegister(
- info: Info,
- name: String,
- tpe: Type,
- clock: Expression,
- reset: Expression,
- init: Expression) extends Statement with IsDeclaration with UseSerializer {
+ info: Info,
+ name: String,
+ tpe: Type,
+ clock: Expression,
+ reset: Expression,
+ init: Expression)
+ extends Statement
+ with IsDeclaration
+ with UseSerializer {
def mapStmt(f: Statement => Statement): Statement = this
def mapExpr(f: Expression => Expression): Statement =
DefRegister(info, name, tpe, f(clock), f(reset), f(init))
- def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
def foreachStmt(f: Statement => Unit): Unit = ()
def foreachExpr(f: Expression => Unit): Unit = { f(clock); f(reset); f(init) }
- def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
object DefInstance {
@@ -393,17 +426,19 @@ object DefInstance {
}
case class DefInstance(info: Info, name: String, module: String, tpe: Type = UnknownType)
- extends Statement with IsDeclaration with UseSerializer {
- def mapExpr(f: Expression => Expression): Statement = this
- def mapStmt(f: Statement => Statement): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def mapInfo(f: Info => Info): Statement = this.copy(f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(tpe)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ extends Statement
+ with IsDeclaration
+ with UseSerializer {
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
object ReadUnderWrite extends Enumeration {
@@ -413,56 +448,64 @@ object ReadUnderWrite extends Enumeration {
}
case class DefMemory(
- info: Info,
- name: String,
- dataType: Type,
- depth: BigInt,
- writeLatency: Int,
- readLatency: Int,
- readers: Seq[String],
- writers: Seq[String],
- readwriters: Seq[String],
- // TODO: handle read-under-write
- readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined)
- extends Statement with IsDeclaration with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(dataType)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
-}
-case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value))
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = DefNode(info, f(name), value)
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = f(value)
- def foreachType(f: Type => Unit): Unit = ()
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ info: Info,
+ name: String,
+ dataType: Type,
+ depth: BigInt,
+ writeLatency: Int,
+ readLatency: Int,
+ readers: Seq[String],
+ writers: Seq[String],
+ readwriters: Seq[String],
+ // TODO: handle read-under-write
+ readUnderWrite: ReadUnderWrite.Value = ReadUnderWrite.Undefined)
+ extends Statement
+ with IsDeclaration
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(dataType)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
+}
+case class DefNode(info: Info, name: String, value: Expression)
+ extends Statement
+ with IsDeclaration
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = DefNode(info, name, f(value))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = DefNode(info, f(name), value)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = f(value)
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Conditionally(
- info: Info,
- pred: Expression,
- conseq: Statement,
- alt: Statement) extends Statement with HasInfo with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = Conditionally(info, pred, f(conseq), f(alt))
- def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt)
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ info: Info,
+ pred: Expression,
+ conseq: Statement,
+ alt: Statement)
+ extends Statement
+ with HasInfo
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = Conditionally(info, pred, f(conseq), f(alt))
+ def mapExpr(f: Expression => Expression): Statement = Conditionally(info, f(pred), conseq, alt)
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
def foreachStmt(f: Statement => Unit): Unit = { f(conseq); f(alt) }
- def foreachExpr(f: Expression => Unit): Unit = f(pred)
- def foreachType(f: Type => Unit): Unit = ()
- def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def foreachExpr(f: Expression => Unit): Unit = f(pred)
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachString(f: String => Unit): Unit = ()
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
object Block {
@@ -489,94 +532,101 @@ case class Block(stmts: Seq[Statement]) extends Statement with UseSerializer {
}
Block(res.toSeq)
}
- def mapExpr(f: Expression => Expression): Statement = this
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this
- def foreachStmt(f: Statement => Unit): Unit = stmts.foreach(f)
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = ()
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this
+ def foreachStmt(f: Statement => Unit): Unit = stmts.foreach(f)
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachString(f: String => Unit): Unit = ()
+ def foreachInfo(f: Info => Unit): Unit = ()
}
case class PartialConnect(info: Info, loc: Expression, expr: Expression)
- extends Statement with HasInfo with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr))
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
+ extends Statement
+ with HasInfo
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = PartialConnect(info, f(loc), f(expr))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) }
- def foreachType(f: Type => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
-}
-case class Connect(info: Info, loc: Expression, expr: Expression)
- extends Statement with HasInfo with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr))
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachInfo(f: Info => Unit): Unit = f(info)
+}
+case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = Connect(info, f(loc), f(expr))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
def foreachExpr(f: Expression => Unit): Unit = { f(loc); f(expr) }
- def foreachType(f: Type => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr))
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = f(expr)
- def foreachType(f: Type => Unit): Unit = ()
- def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = IsInvalid(info, f(expr))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = f(expr)
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachString(f: String => Unit): Unit = ()
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Attach(info: Info, exprs: Seq[Expression]) extends Statement with HasInfo with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = Attach(info, exprs map f)
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(f)
- def foreachType(f: Type => Unit): Unit = ()
- def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
-}
-case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en))
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = Attach(info, exprs.map(f))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = exprs.foreach(f)
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachString(f: String => Unit): Unit = ()
+ def foreachInfo(f: Info => Unit): Unit = f(info)
+}
+case class Stop(info: Info, ret: Int, clk: Expression, en: Expression)
+ extends Statement
+ with HasInfo
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = Stop(info, ret, f(clk), f(en))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(en) }
- def foreachType(f: Type => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Print(
- info: Info,
- string: StringLit,
- args: Seq[Expression],
- clk: Expression,
- en: Expression) extends Statement with HasInfo with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = Print(info, string, args map f, f(clk), f(en))
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
+ info: Info,
+ string: StringLit,
+ args: Seq[Expression],
+ clk: Expression,
+ en: Expression)
+ extends Statement
+ with HasInfo
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = Print(info, string, args.map(f), f(clk), f(en))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
def foreachExpr(f: Expression => Unit): Unit = { args.foreach(f); f(clk); f(en) }
- def foreachType(f: Type => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
// formal
@@ -587,38 +637,40 @@ object Formal extends Enumeration {
}
case class Verification(
- op: Formal.Value,
+ op: Formal.Value,
info: Info,
- clk: Expression,
+ clk: Expression,
pred: Expression,
- en: Expression,
- msg: StringLit
-) extends Statement with HasInfo with UseSerializer {
+ en: Expression,
+ msg: StringLit)
+ extends Statement
+ with HasInfo
+ with UseSerializer {
def mapStmt(f: Statement => Statement): Statement = this
def mapExpr(f: Expression => Expression): Statement =
copy(clk = f(clk), pred = f(pred), en = f(en))
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = copy(info = f(info))
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = copy(info = f(info))
def foreachStmt(f: Statement => Unit): Unit = ()
def foreachExpr(f: Expression => Unit): Unit = { f(clk); f(pred); f(en); }
- def foreachType(f: Type => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
// end formal
case object EmptyStmt extends Statement with UseSerializer {
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = this
- def mapType(f: Type => Type): Statement = this
- def mapString(f: String => String): Statement = this
- def mapInfo(f: Info => Info): Statement = this
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = ()
- def foreachString(f: String => Unit): Unit = ()
- def foreachInfo(f: Info => Unit): Unit = ()
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this
+ def mapString(f: String => String): Statement = this
+ def mapInfo(f: Info => Info): Statement = this
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachString(f: String => Unit): Unit = ()
+ def foreachInfo(f: Info => Unit): Unit = ()
}
abstract class Width extends FirrtlNode {
@@ -631,14 +683,15 @@ abstract class Width extends FirrtlNode {
case _ => UnknownWidth
}
def max(x: Width): Width = (this, x) match {
- case (a: IntWidth, b: IntWidth) => IntWidth(a.width max b.width)
+ case (a: IntWidth, b: IntWidth) => IntWidth(a.width.max(b.width))
case _ => UnknownWidth
}
def min(x: Width): Width = (this, x) match {
- case (a: IntWidth, b: IntWidth) => IntWidth(a.width min b.width)
+ case (a: IntWidth, b: IntWidth) => IntWidth(a.width.min(b.width))
case _ => UnknownWidth
}
}
+
/** Positive Integer Bit Width of a [[GroundType]] */
object IntWidth {
private val maxCached = 1024
@@ -665,7 +718,7 @@ class IntWidth(val width: BigInt) extends Width with Product with UseSerializer
override def hashCode = width.toInt
override def productPrefix = "IntWidth"
override def toString = s"$productPrefix($width)"
- def copy(width: BigInt = width) = IntWidth(width)
+ def copy(width: BigInt = width) = IntWidth(width)
def canEqual(that: Any) = that.isInstanceOf[Width]
def productArity = 1
def productElement(int: Int) = int match {
@@ -693,19 +746,18 @@ case object Flip extends Orientation {
/** Field of [[BundleType]] */
case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode with HasName with UseSerializer
-
/** Bounds of [[IntervalType]] */
trait Bound extends Constraint
case object UnknownBound extends Bound {
def serialize: String = Serializer.serialize(this)
- def map(f: Constraint=>Constraint): Constraint = this
+ def map(f: Constraint => Constraint): Constraint = this
override def reduce(): Constraint = this
val children = Vector()
}
case class CalcBound(arg: Constraint) extends Bound {
def serialize: String = Serializer.serialize(this)
- def map(f: Constraint=>Constraint): Constraint = f(arg)
+ def map(f: Constraint => Constraint): Constraint = f(arg)
override def reduce(): Constraint = arg
val children = Vector(arg)
}
@@ -727,58 +779,60 @@ case class Open(value: BigDecimal) extends IsKnown with Bound {
def +(that: IsKnown): IsKnown = Open(value + that.value)
def *(that: IsKnown): IsKnown = that match {
case Closed(x) if x == 0 => Closed(x)
- case _ => Open(value * that.value)
+ case _ => Open(value * that.value)
}
- def min(that: IsKnown): IsKnown = if(value < that.value) this else that
- def max(that: IsKnown): IsKnown = if(value > that.value) this else that
- def neg: IsKnown = Open(-value)
- def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR))
- def pow: IsKnown = if(value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
+ def min(that: IsKnown): IsKnown = if (value < that.value) this else that
+ def max(that: IsKnown): IsKnown = if (value > that.value) this else that
+ def neg: IsKnown = Open(-value)
+ def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR))
+ def pow: IsKnown =
+ if (value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
}
case class Closed(value: BigDecimal) extends IsKnown with Bound {
def serialize: String = Serializer.serialize(this)
def +(that: IsKnown): IsKnown = that match {
- case Open(x) => Open(value + x)
+ case Open(x) => Open(value + x)
case Closed(x) => Closed(value + x)
}
def *(that: IsKnown): IsKnown = that match {
case IsKnown(x) if value == BigInt(0) => Closed(0)
- case Open(x) => Open(value * x)
- case Closed(x) => Closed(value * x)
+ case Open(x) => Open(value * x)
+ case Closed(x) => Closed(value * x)
}
- def min(that: IsKnown): IsKnown = if(value <= that.value) this else that
- def max(that: IsKnown): IsKnown = if(value >= that.value) this else that
- def neg: IsKnown = Closed(-value)
+ def min(that: IsKnown): IsKnown = if (value <= that.value) this else that
+ def max(that: IsKnown): IsKnown = if (value >= that.value) this else that
+ def neg: IsKnown = Closed(-value)
def floor: IsKnown = Closed(value.setScale(0, BigDecimal.RoundingMode.FLOOR))
- def pow: IsKnown = if(value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
+ def pow: IsKnown =
+ if (value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
}
/** Types of [[FirrtlNode]] */
abstract class Type extends FirrtlNode {
- def mapType(f: Type => Type): Type
- def mapWidth(f: Width => Width): Type
- def foreachType(f: Type => Unit): Unit
- def foreachWidth(f: Width => Unit): Unit
+ def mapType(f: Type => Type): Type
+ def mapWidth(f: Width => Width): Type
+ def foreachType(f: Type => Unit): Unit
+ def foreachWidth(f: Width => Unit): Unit
}
abstract class GroundType extends Type {
val width: Width
- def mapType(f: Type => Type): Type = this
+ def mapType(f: Type => Type): Type = this
def foreachType(f: Type => Unit): Unit = ()
}
object GroundType {
def unapply(ground: GroundType): Option[Width] = Some(ground.width)
}
abstract class AggregateType extends Type {
- def mapWidth(f: Width => Width): Type = this
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapWidth(f: Width => Width): Type = this
+ def foreachWidth(f: Width => Unit): Unit = ()
}
case class UIntType(width: Width) extends GroundType with UseSerializer {
- def mapWidth(f: Width => Width): Type = UIntType(f(width))
- def foreachWidth(f: Width => Unit): Unit = f(width)
+ def mapWidth(f: Width => Width): Type = UIntType(f(width))
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
case class SIntType(width: Width) extends GroundType with UseSerializer {
- def mapWidth(f: Width => Width): Type = SIntType(f(width))
- def foreachWidth(f: Width => Unit): Unit = f(width)
+ def mapWidth(f: Width => Width): Type = SIntType(f(width))
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
case class FixedType(width: Width, point: Width) extends GroundType with UseSerializer {
def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point))
@@ -790,21 +844,21 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground
case Open(l) => s"(${dec2string(l)}, "
case Closed(l) => s"[${dec2string(l)}, "
case UnknownBound => s"[?, "
- case _ => s"[?, "
+ case _ => s"[?, "
}
val upperString = upper match {
case Open(u) => s"${dec2string(u)})"
case Closed(u) => s"${dec2string(u)}]"
case UnknownBound => s"?]"
- case _ => s"?]"
+ case _ => s"?]"
}
val bounds = (lower, upper) match {
case (k1: IsKnown, k2: IsKnown) => lowerString + upperString
case _ => ""
}
val pointString = point match {
- case IntWidth(i) => "." + i.toString
- case _ => ""
+ case IntWidth(i) => "." + i.toString
+ case _ => ""
}
"Interval" + bounds + pointString
}
@@ -813,35 +867,43 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground
private def precision: Option[BigDecimal] = point match {
case IntWidth(width) =>
val bp = width.toInt
- if(bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp))
+ if (bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp))
case other => None
}
def min: Option[BigDecimal] = (lower, precision) match {
- case (Open(a), Some(prec)) => a / prec match {
- case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0
- case x => Some(x.setScale(0, CEILING) * prec) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l
- }
+ case (Open(a), Some(prec)) =>
+ a / prec match {
+ case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0
+ case x =>
+ Some(
+ x.setScale(0, CEILING) * prec
+ ) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l
+ }
case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, CEILING) * prec)
- case other => None
+ case other => None
}
def max: Option[BigDecimal] = (upper, precision) match {
- case (Open(a), Some(prec)) => a / prec match {
- case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound
- case x => Some(x.setScale(0, FLOOR) * prec)
- }
+ case (Open(a), Some(prec)) =>
+ a / prec match {
+ case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound
+ case x => Some(x.setScale(0, FLOOR) * prec)
+ }
case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, FLOOR) * prec)
}
def minAdjusted: Option[BigInt] = min.map(_ * BigDecimal(BigInt(1) << bp) match {
case x if trim(x).isWhole | x.doubleValue == 0.0 => x.toBigInt
- case x => sys.error(s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}.")
+ case x =>
+ sys.error(
+ s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}."
+ )
})
def maxAdjusted: Option[BigInt] = max.map(_ * BigDecimal(BigInt(1) << bp) match {
case x if trim(x).isWhole => x.toBigInt
- case x => sys.error(s"MaxAdjusted should be a whole number: $x")
+ case x => sys.error(s"MaxAdjusted should be a whole number: $x")
})
/** If bounds are known, calculates the width, otherwise returns UnknownWidth */
@@ -854,48 +916,48 @@ case class IntervalType(lower: Bound, upper: Bound, point: Width) extends Ground
/** If bounds are known, returns a sequence of all possible values inside this interval */
lazy val range: Option[Seq[BigDecimal]] = (lower, upper, point) match {
case (l: IsKnown, u: IsKnown, p: IntWidth) =>
- if(min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get))
+ if (min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get))
case _ => None
}
- override def mapWidth(f: Width => Width): Type = this.copy(point = f(point))
- override def foreachWidth(f: Width => Unit): Unit = f(point)
+ override def mapWidth(f: Width => Width): Type = this.copy(point = f(point))
+ override def foreachWidth(f: Width => Unit): Unit = f(point)
}
case class BundleType(fields: Seq[Field]) extends AggregateType with UseSerializer {
def mapType(f: Type => Type): Type =
- BundleType(fields map (x => x.copy(tpe = f(x.tpe))))
- def foreachType(f: Type => Unit): Unit = fields.foreach{ x => f(x.tpe) }
+ BundleType(fields.map(x => x.copy(tpe = f(x.tpe))))
+ def foreachType(f: Type => Unit): Unit = fields.foreach { x => f(x.tpe) }
}
case class VectorType(tpe: Type, size: Int) extends AggregateType with UseSerializer {
- def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe))
+ def mapType(f: Type => Type): Type = this.copy(tpe = f(tpe))
def foreachType(f: Type => Unit): Unit = f(tpe)
}
case object ClockType extends GroundType with UseSerializer {
val width = IntWidth(1)
- def mapWidth(f: Width => Width): Type = this
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapWidth(f: Width => Width): Type = this
+ def foreachWidth(f: Width => Unit): Unit = ()
}
/* Abstract reset, will be inferred to UInt<1> or AsyncReset */
case object ResetType extends GroundType with UseSerializer {
val width = IntWidth(1)
- def mapWidth(f: Width => Width): Type = this
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapWidth(f: Width => Width): Type = this
+ def foreachWidth(f: Width => Unit): Unit = ()
}
case object AsyncResetType extends GroundType with UseSerializer {
val width = IntWidth(1)
- def mapWidth(f: Width => Width): Type = this
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapWidth(f: Width => Width): Type = this
+ def foreachWidth(f: Width => Unit): Unit = ()
}
case class AnalogType(width: Width) extends GroundType with UseSerializer {
- def mapWidth(f: Width => Width): Type = AnalogType(f(width))
- def foreachWidth(f: Width => Unit): Unit = f(width)
+ def mapWidth(f: Width => Width): Type = AnalogType(f(width))
+ def foreachWidth(f: Width => Unit): Unit = f(width)
}
case object UnknownType extends Type with UseSerializer {
- def mapType(f: Type => Type): Type = this
- def mapWidth(f: Width => Width): Type = this
- def foreachType(f: Type => Unit): Unit = ()
- def foreachWidth(f: Width => Unit): Unit = ()
+ def mapType(f: Type => Type): Type = this
+ def mapWidth(f: Width => Width): Type = this
+ def foreachType(f: Type => Unit): Unit = ()
+ def foreachWidth(f: Width => Unit): Unit = ()
}
/** [[Port]] Direction */
@@ -909,11 +971,14 @@ case object Output extends Direction {
/** [[DefModule]] Port */
case class Port(
- info: Info,
- name: String,
- direction: Direction,
- tpe: Type) extends FirrtlNode with IsDeclaration with UseSerializer {
- def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe))
+ info: Info,
+ name: String,
+ direction: Direction,
+ tpe: Type)
+ extends FirrtlNode
+ with IsDeclaration
+ with UseSerializer {
+ def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe))
def mapString(f: String => String): Port = Port(info, f(name), direction, tpe)
}
@@ -921,12 +986,16 @@ case class Port(
sealed abstract class Param extends FirrtlNode {
def name: String
}
+
/** Integer (of any width) Parameter */
case class IntParam(name: String, value: BigInt) extends Param with UseSerializer
+
/** IEEE Double Precision Parameter (for Verilog real) */
case class DoubleParam(name: String, value: Double) extends Param with UseSerializer
+
/** String Parameter */
case class StringParam(name: String, value: StringLit) extends Param with UseSerializer
+
/** Raw String Parameter
* Useful for Verilog type parameters
* @note Firrtl doesn't guarantee anything about this String being legal in any backend
@@ -935,59 +1004,65 @@ case class RawStringParam(name: String, value: String) extends Param with UseSer
/** Base class for modules */
abstract class DefModule extends FirrtlNode with IsDeclaration {
- val info : Info
- val name : String
- val ports : Seq[Port]
- def mapStmt(f: Statement => Statement): DefModule
- def mapPort(f: Port => Port): DefModule
- def mapString(f: String => String): DefModule
- def mapInfo(f: Info => Info): DefModule
- def foreachStmt(f: Statement => Unit): Unit
- def foreachPort(f: Port => Unit): Unit
- def foreachString(f: String => Unit): Unit
- def foreachInfo(f: Info => Unit): Unit
+ val info: Info
+ val name: String
+ val ports: Seq[Port]
+ def mapStmt(f: Statement => Statement): DefModule
+ def mapPort(f: Port => Port): DefModule
+ def mapString(f: String => String): DefModule
+ def mapInfo(f: Info => Info): DefModule
+ def foreachStmt(f: Statement => Unit): Unit
+ def foreachPort(f: Port => Unit): Unit
+ def foreachString(f: String => Unit): Unit
+ def foreachInfo(f: Info => Unit): Unit
}
+
/** Internal Module
*
* An instantiable hardware block
*/
case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) extends DefModule with UseSerializer {
- def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body))
- def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
- def mapString(f: String => String): DefModule = this.copy(name = f(name))
- def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
- def foreachStmt(f: Statement => Unit): Unit = f(body)
- def foreachPort(f: Port => Unit): Unit = ports.foreach(f)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def mapStmt(f: Statement => Statement): DefModule = this.copy(body = f(body))
+ def mapPort(f: Port => Port): DefModule = this.copy(ports = ports.map(f))
+ def mapString(f: String => String): DefModule = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = f(body)
+ def foreachPort(f: Port => Unit): Unit = ports.foreach(f)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
+
/** External Module
*
* Generally used for Verilog black boxes
* @param defname Defined name of the external module (ie. the name Firrtl will emit)
*/
case class ExtModule(
- info: Info,
- name: String,
- ports: Seq[Port],
- defname: String,
- params: Seq[Param]) extends DefModule with UseSerializer {
- def mapStmt(f: Statement => Statement): DefModule = this
- def mapPort(f: Port => Port): DefModule = this.copy(ports = ports map f)
- def mapString(f: String => String): DefModule = this.copy(name = f(name))
- def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachPort(f: Port => Unit): Unit = ports.foreach(f)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ info: Info,
+ name: String,
+ ports: Seq[Port],
+ defname: String,
+ params: Seq[Param])
+ extends DefModule
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): DefModule = this
+ def mapPort(f: Port => Port): DefModule = this.copy(ports = ports.map(f))
+ def mapString(f: String => String): DefModule = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): DefModule = this.copy(f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachPort(f: Port => Unit): Unit = ports.foreach(f)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
case class Circuit(info: Info, modules: Seq[DefModule], main: String)
- extends FirrtlNode with HasInfo with UseSerializer {
- def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules map f)
- def mapString(f: String => String): Circuit = this.copy(main = f(main))
- def mapInfo(f: Info => Info): Circuit = this.copy(f(info))
- def foreachModule(f: DefModule => Unit): Unit = modules foreach f
- def foreachString(f: String => Unit): Unit = f(main)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ extends FirrtlNode
+ with HasInfo
+ with UseSerializer {
+ def mapModule(f: DefModule => DefModule): Circuit = this.copy(modules = modules.map(f))
+ def mapString(f: String => String): Circuit = this.copy(main = f(main))
+ def mapInfo(f: Info => Info): Circuit = this.copy(f(info))
+ def foreachModule(f: DefModule => Unit): Unit = modules.foreach(f)
+ def foreachString(f: String => Unit): Unit = f(main)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala
index ea304cf3..bf9a57c1 100644
--- a/src/main/scala/firrtl/ir/Serializer.scala
+++ b/src/main/scala/firrtl/ir/Serializer.scala
@@ -13,19 +13,19 @@ object Serializer {
val builder = new StringBuilder()
val indent = 0
node match {
- case n : Info => s(n)(builder, indent)
- case n : StringLit => s(n)(builder, indent)
- case n : Expression => s(n)(builder, indent)
- case n : Statement => s(n)(builder, indent)
- case n : Width => s(n)(builder, indent)
- case n : Orientation => s(n)(builder, indent)
- case n : Field => s(n)(builder, indent)
- case n : Type => s(n)(builder, indent)
- case n : Direction => s(n)(builder, indent)
- case n : Port => s(n)(builder, indent)
- case n : Param => s(n)(builder, indent)
- case n : DefModule => s(n)(builder, indent)
- case n : Circuit => s(n)(builder, indent)
+ case n: Info => s(n)(builder, indent)
+ case n: StringLit => s(n)(builder, indent)
+ case n: Expression => s(n)(builder, indent)
+ case n: Statement => s(n)(builder, indent)
+ case n: Width => s(n)(builder, indent)
+ case n: Orientation => s(n)(builder, indent)
+ case n: Field => s(n)(builder, indent)
+ case n: Type => s(n)(builder, indent)
+ case n: Direction => s(n)(builder, indent)
+ case n: Port => s(n)(builder, indent)
+ case n: Param => s(n)(builder, indent)
+ case n: DefModule => s(n)(builder, indent)
+ case n: Circuit => s(n)(builder, indent)
}
builder.toString()
}
@@ -39,16 +39,16 @@ object Serializer {
private def flattenInfo(infos: Seq[Info]): Seq[FileInfo] = infos.flatMap {
case NoInfo => Seq()
- case f : FileInfo => Seq(f)
+ case f: FileInfo => Seq(f)
case MultiInfo(infos) => flattenInfo(infos)
}
private def s(node: Info)(implicit b: StringBuilder, indent: Int): Unit = node match {
- case f : FileInfo => b ++= " @[" ; b ++= f.escaped ; b ++= "]"
+ case f: FileInfo => b ++= " @["; b ++= f.escaped; b ++= "]"
case NoInfo => // empty string
- case m : MultiInfo =>
+ case m: MultiInfo =>
val infos = m.flatten
- if(infos.nonEmpty) {
+ if (infos.nonEmpty) {
val lastId = infos.length - 1
b ++= " @["
infos.zipWithIndex.foreach { case (f, i) => b ++= f.escaped; if (i < lastId) b += ' ' }
@@ -61,103 +61,113 @@ object Serializer {
private def s(node: Expression)(implicit b: StringBuilder, indent: Int): Unit = node match {
case Reference(name, _, _, _) => b ++= name
case DoPrim(op, args, consts, _) =>
- b ++= op.toString ; b += '(' ; s(args, ", ", consts.isEmpty) ; s(consts, ", ") ; b += ')'
+ b ++= op.toString; b += '('; s(args, ", ", consts.isEmpty); s(consts, ", "); b += ')'
case UIntLiteral(value, width) =>
- b ++= "UInt" ; s(width) ; b ++= "(\"h" ; b ++= value.toString(16) ; b ++= "\")"
- case SubField(expr, name, _, _) => s(expr) ; b += '.' ; b ++= name
- case SubIndex(expr, value, _, _) => s(expr) ; b += '[' ; b ++= value.toString ; b += ']'
- case SubAccess(expr, index, _, _) => s(expr) ; b += '[' ; s(index) ; b += ']'
+ b ++= "UInt"; s(width); b ++= "(\"h"; b ++= value.toString(16); b ++= "\")"
+ case SubField(expr, name, _, _) => s(expr); b += '.'; b ++= name
+ case SubIndex(expr, value, _, _) => s(expr); b += '['; b ++= value.toString; b += ']'
+ case SubAccess(expr, index, _, _) => s(expr); b += '['; s(index); b += ']'
case Mux(cond, tval, fval, _) =>
- b ++= "mux(" ; s(cond) ; b ++= ", " ; s(tval) ; b ++= ", " ; s(fval) ; b += ')'
- case ValidIf(cond, value, _) => b ++= "validif(" ; s(cond) ; b ++= ", " ; s(value) ; b += ')'
+ b ++= "mux("; s(cond); b ++= ", "; s(tval); b ++= ", "; s(fval); b += ')'
+ case ValidIf(cond, value, _) => b ++= "validif("; s(cond); b ++= ", "; s(value); b += ')'
case SIntLiteral(value, width) =>
- b ++= "SInt" ; s(width) ; b ++= "(\"h" ; b ++= value.toString(16) ; b ++= "\")"
+ b ++= "SInt"; s(width); b ++= "(\"h"; b ++= value.toString(16); b ++= "\")"
case FixedLiteral(value, width, point) =>
- b ++= "Fixed" ; s(width) ; sPoint(point)
- b ++= "(\"h" ; b ++= value.toString(16) ; b ++= "\")"
+ b ++= "Fixed"; s(width); sPoint(point)
+ b ++= "(\"h"; b ++= value.toString(16); b ++= "\")"
// WIR
- case firrtl.WVoid => b ++= "VOID"
- case firrtl.WInvalid => b ++= "INVALID"
+ case firrtl.WVoid => b ++= "VOID"
+ case firrtl.WInvalid => b ++= "INVALID"
case firrtl.EmptyExpression => b ++= "EMPTY"
}
private def s(node: Statement)(implicit b: StringBuilder, indent: Int): Unit = node match {
- case DefNode(info, name, value) => b ++= "node " ; b ++= name ; b ++= " = " ; s(value) ; s(info)
- case Connect(info, loc, expr) => s(loc) ; b ++= " <= " ; s(expr) ; s(info)
+ case DefNode(info, name, value) => b ++= "node "; b ++= name; b ++= " = "; s(value); s(info)
+ case Connect(info, loc, expr) => s(loc); b ++= " <= "; s(expr); s(info)
case Conditionally(info, pred, conseq, alt) =>
- b ++= "when " ; s(pred) ; b ++= " :" ; s(info)
- newLineAndIndent(1) ; s(conseq)(b, indent + 1)
- if(alt != EmptyStmt) {
- newLineAndIndent() ; b ++= "else :"
- newLineAndIndent(1) ; s(alt)(b, indent + 1)
+ b ++= "when "; s(pred); b ++= " :"; s(info)
+ newLineAndIndent(1); s(conseq)(b, indent + 1)
+ if (alt != EmptyStmt) {
+ newLineAndIndent(); b ++= "else :"
+ newLineAndIndent(1); s(alt)(b, indent + 1)
}
- case EmptyStmt => b ++= "skip"
+ case EmptyStmt => b ++= "skip"
case Block(Seq()) => b ++= "skip"
case Block(stmts) =>
val it = stmts.iterator
- while(it.hasNext) {
+ while (it.hasNext) {
s(it.next)
- if(it.hasNext) newLineAndIndent()
+ if (it.hasNext) newLineAndIndent()
}
case Stop(info, ret, clk, en) =>
- b ++= "stop(" ; s(clk) ; b ++= ", " ; s(en) ; b ++= ", " ; b ++= ret.toString ; b += ')' ; s(info)
+ b ++= "stop("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= ret.toString; b += ')'; s(info)
case Print(info, string, args, clk, en) =>
- b ++= "printf(" ; s(clk) ; b ++= ", " ; s(en) ; b ++= ", " ; b ++= string.escape
- if(args.nonEmpty) b ++= ", " ; s(args, ", ") ; b += ')' ; s(info)
- case IsInvalid(info, expr) => s(expr) ; b ++= " is invalid" ; s(info)
- case DefWire(info, name, tpe) => b ++= "wire " ; b ++= name ; b ++= " : " ; s(tpe) ; s(info)
+ b ++= "printf("; s(clk); b ++= ", "; s(en); b ++= ", "; b ++= string.escape
+ if (args.nonEmpty) b ++= ", "; s(args, ", "); b += ')'; s(info)
+ case IsInvalid(info, expr) => s(expr); b ++= " is invalid"; s(info)
+ case DefWire(info, name, tpe) => b ++= "wire "; b ++= name; b ++= " : "; s(tpe); s(info)
case DefRegister(info, name, tpe, clock, reset, init) =>
- b ++= "reg " ; b ++= name ; b ++= " : " ; s(tpe) ; b ++= ", " ; s(clock) ; b ++= " with :" ; newLineAndIndent(1)
- b ++= "reset => (" ; s(reset) ; b ++= ", " ; s(init) ; b += ')' ; s(info)
- case DefInstance(info, name, module, _) => b ++= "inst " ; b ++= name ; b ++= " of " ; b ++= module ; s(info)
- case DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers,
- readwriters, readUnderWrite) =>
- b ++= "mem " ; b ++= name ; b ++= " :" ; s(info) ; newLineAndIndent(1)
- b ++= "data-type => " ; s(dataType) ; newLineAndIndent(1)
- b ++= "depth => " ; b ++= depth.toString() ; newLineAndIndent(1)
- b ++= "read-latency => " ; b ++= readLatency.toString ; newLineAndIndent(1)
- b ++= "write-latency => " ; b ++= writeLatency.toString ; newLineAndIndent(1)
- readers.foreach{ r => b ++= "reader => " ; b ++= r ; newLineAndIndent(1) }
- writers.foreach{ w => b ++= "writer => " ; b ++= w ; newLineAndIndent(1) }
- readwriters.foreach{ r => b ++= "readwriter => " ; b ++= r ; newLineAndIndent(1) }
- b ++= "read-under-write => " ; b ++= readUnderWrite.toString
- case PartialConnect(info, loc, expr) => s(loc) ; b ++= " <- " ; s(expr) ; s(info)
- case Attach(info, exprs) =>
+ b ++= "reg "; b ++= name; b ++= " : "; s(tpe); b ++= ", "; s(clock); b ++= " with :"; newLineAndIndent(1)
+ b ++= "reset => ("; s(reset); b ++= ", "; s(init); b += ')'; s(info)
+ case DefInstance(info, name, module, _) => b ++= "inst "; b ++= name; b ++= " of "; b ++= module; s(info)
+ case DefMemory(
+ info,
+ name,
+ dataType,
+ depth,
+ writeLatency,
+ readLatency,
+ readers,
+ writers,
+ readwriters,
+ readUnderWrite
+ ) =>
+ b ++= "mem "; b ++= name; b ++= " :"; s(info); newLineAndIndent(1)
+ b ++= "data-type => "; s(dataType); newLineAndIndent(1)
+ b ++= "depth => "; b ++= depth.toString(); newLineAndIndent(1)
+ b ++= "read-latency => "; b ++= readLatency.toString; newLineAndIndent(1)
+ b ++= "write-latency => "; b ++= writeLatency.toString; newLineAndIndent(1)
+ readers.foreach { r => b ++= "reader => "; b ++= r; newLineAndIndent(1) }
+ writers.foreach { w => b ++= "writer => "; b ++= w; newLineAndIndent(1) }
+ readwriters.foreach { r => b ++= "readwriter => "; b ++= r; newLineAndIndent(1) }
+ b ++= "read-under-write => "; b ++= readUnderWrite.toString
+ case PartialConnect(info, loc, expr) => s(loc); b ++= " <- "; s(expr); s(info)
+ case Attach(info, exprs) =>
// exprs should never be empty since the attach statement takes *at least* two signals according to the spec
- b ++= "attach (" ; s(exprs, ", ") ; b += ')' ; s(info)
+ b ++= "attach ("; s(exprs, ", "); b += ')'; s(info)
case Verification(op, info, clk, pred, en, msg) =>
- b ++= op.toString ; b += '(' ; s(List(clk, pred, en), ", ", false) ; b ++= msg.escape
- b += ')' ; s(info)
+ b ++= op.toString; b += '('; s(List(clk, pred, en), ", ", false); b ++= msg.escape
+ b += ')'; s(info)
// WIR
case firrtl.CDefMemory(info, name, tpe, size, seq, readUnderWrite) =>
- if(seq) b ++= "smem " else b ++= "cmem "
- b ++= name ; b ++= " : " ; s(tpe) ; b ++= " [" ; b ++= size.toString() ; b += ']' ; s(info)
+ if (seq) b ++= "smem " else b ++= "cmem "
+ b ++= name; b ++= " : "; s(tpe); b ++= " ["; b ++= size.toString(); b += ']'; s(info)
case firrtl.CDefMPort(info, name, _, mem, exps, direction) =>
- b ++= direction.serialize ; b ++= " mport " ; b ++= name ; b ++= " = " ; b ++= mem
- b += '[' ; s(exps.head) ; b ++= "], " ; s(exps(1)) ; s(info)
+ b ++= direction.serialize; b ++= " mport "; b ++= name; b ++= " = "; b ++= mem
+ b += '['; s(exps.head); b ++= "], "; s(exps(1)); s(info)
case firrtl.WDefInstanceConnector(info, name, module, tpe, portCons) =>
- b ++= "inst " ; b ++= name ; b ++= " of " ; b ++= module ; b ++= " with " ; s(tpe) ; b ++= " connected to ("
- s(portCons.map(_._2), ", ") ; b += ')' ; s(info)
+ b ++= "inst "; b ++= name; b ++= " of "; b ++= module; b ++= " with "; s(tpe); b ++= " connected to ("
+ s(portCons.map(_._2), ", "); b += ')'; s(info)
}
private def s(node: Width)(implicit b: StringBuilder, indent: Int): Unit = node match {
case IntWidth(width) => b += '<'; b ++= width.toString(); b += '>'
- case UnknownWidth => // empty string
- case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')'
- case VarWidth(name) => b += '<'; b ++= name; b += '>'
+ case UnknownWidth => // empty string
+ case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')'
+ case VarWidth(name) => b += '<'; b ++= name; b += '>'
}
private def sPoint(node: Width)(implicit b: StringBuilder, indent: Int): Unit = node match {
case IntWidth(width) => b ++= "<<"; b ++= width.toString(); b ++= ">>"
- case UnknownWidth => // empty string
- case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')'
- case VarWidth(name) => b ++= "<<"; b ++= name; b ++= ">>"
+ case UnknownWidth => // empty string
+ case CalcWidth(arg) => b ++= "calcw("; s(arg); b += ')'
+ case VarWidth(name) => b ++= "<<"; b ++= name; b ++= ">>"
}
private def s(node: Orientation)(implicit b: StringBuilder, indent: Int): Unit = node match {
case Default => // empty string
- case Flip => b ++= "flip "
+ case Flip => b ++= "flip "
}
private def s(node: Field)(implicit b: StringBuilder, indent: Int): Unit = node match {
@@ -169,19 +179,19 @@ object Serializer {
case UIntType(width: Width) => b ++= "UInt"; s(width)
case SIntType(width: Width) => b ++= "SInt"; s(width)
case FixedType(width, point) => b ++= "Fixed"; s(width); sPoint(point)
- case BundleType(fields) => b ++= "{ "; sField(fields, ", "); b += '}'
- case VectorType(tpe, size) => s(tpe); b += '['; b ++= size.toString; b += ']'
- case ClockType => b ++= "Clock"
- case ResetType => b ++= "Reset"
- case AsyncResetType => b ++= "AsyncReset"
- case AnalogType(width) => b ++= "Analog"; s(width)
- case UnknownType => b += '?'
+ case BundleType(fields) => b ++= "{ "; sField(fields, ", "); b += '}'
+ case VectorType(tpe, size) => s(tpe); b += '['; b ++= size.toString; b += ']'
+ case ClockType => b ++= "Clock"
+ case ResetType => b ++= "Reset"
+ case AsyncResetType => b ++= "AsyncReset"
+ case AnalogType(width) => b ++= "Analog"; s(width)
+ case UnknownType => b += '?'
// the IntervalType has a complicated custom serialization method which does not recurse
case i: IntervalType => b ++= i.serialize
}
private def s(node: Direction)(implicit b: StringBuilder, indent: Int): Unit = node match {
- case Input => b ++= "input"
+ case Input => b ++= "input"
case Output => b ++= "output"
}
@@ -191,50 +201,50 @@ object Serializer {
}
private def s(node: Param)(implicit b: StringBuilder, indent: Int): Unit = node match {
- case IntParam(name, value) => b ++= "parameter " ; b ++= name ; b ++= " = " ; b ++= value.toString
- case DoubleParam(name, value) => b ++= "parameter " ; b ++= name ; b ++= " = " ; b ++= value.toString
- case StringParam(name, value) => b ++= "parameter " ; b ++= name ; b ++= " = " ; b ++= value.escape
+ case IntParam(name, value) => b ++= "parameter "; b ++= name; b ++= " = "; b ++= value.toString
+ case DoubleParam(name, value) => b ++= "parameter "; b ++= name; b ++= " = "; b ++= value.toString
+ case StringParam(name, value) => b ++= "parameter "; b ++= name; b ++= " = "; b ++= value.escape
case RawStringParam(name, value) =>
- b ++= "parameter " ; b ++= name ; b ++= " = "
- b += '\'' ; b ++= value.replace("'", "\\'") ; b += '\''
+ b ++= "parameter "; b ++= name; b ++= " = "
+ b += '\''; b ++= value.replace("'", "\\'"); b += '\''
}
private def s(node: DefModule)(implicit b: StringBuilder, indent: Int): Unit = node match {
case Module(info, name, ports, body) =>
- b ++= "module " ; b ++= name ; b ++= " :" ; s(info)
- ports.foreach{ p => newLineAndIndent(1) ; s(p) }
+ b ++= "module "; b ++= name; b ++= " :"; s(info)
+ ports.foreach { p => newLineAndIndent(1); s(p) }
newLineNoIndent() // add a new line between port declaration and body
- newLineAndIndent(1) ; s(body)(b, indent + 1)
+ newLineAndIndent(1); s(body)(b, indent + 1)
case ExtModule(info, name, ports, defname, params) =>
- b ++= "extmodule " ; b ++= name ; b ++= " :" ; s(info)
- ports.foreach{ p => newLineAndIndent(1) ; s(p) }
- newLineAndIndent(1) ; b ++= "defname = " ; b ++= defname
- params.foreach{ p => newLineAndIndent(1) ; s(p) }
+ b ++= "extmodule "; b ++= name; b ++= " :"; s(info)
+ ports.foreach { p => newLineAndIndent(1); s(p) }
+ newLineAndIndent(1); b ++= "defname = "; b ++= defname
+ params.foreach { p => newLineAndIndent(1); s(p) }
}
private def s(node: Circuit)(implicit b: StringBuilder, indent: Int): Unit = node match {
case Circuit(info, modules, main) =>
- b ++= "circuit " ; b ++= main ; b ++= " :" ; s(info)
- if(modules.nonEmpty) {
- newLineAndIndent(1) ; s(modules.head)(b, indent + 1)
- modules.drop(1).foreach{m => newLineNoIndent(); newLineAndIndent(1) ; s(m)(b, indent + 1) }
+ b ++= "circuit "; b ++= main; b ++= " :"; s(info)
+ if (modules.nonEmpty) {
+ newLineAndIndent(1); s(modules.head)(b, indent + 1)
+ modules.drop(1).foreach { m => newLineNoIndent(); newLineAndIndent(1); s(m)(b, indent + 1) }
}
}
// serialize constraints
private def s(const: Constraint)(implicit b: StringBuilder): Unit = const match {
// Bounds
- case UnknownBound => b += '?'
- case CalcBound(arg) => b ++= "calcb(" ; s(arg) ; b += ')'
+ case UnknownBound => b += '?'
+ case CalcBound(arg) => b ++= "calcb("; s(arg); b += ')'
case VarBound(name) => b ++= name
- case Open(value) => b ++ "o(" ; b ++= value.toString ; b += ')'
- case Closed(value) => b ++ "c(" ; b ++= value.toString ; b += ')'
- case other => other.serialize
+ case Open(value) => b ++ "o("; b ++= value.toString; b += ')'
+ case Closed(value) => b ++ "c("; b ++= value.toString; b += ')'
+ case other => other.serialize
}
/** create a new line with the appropriate indent */
private def newLineAndIndent(inc: Int = 0)(implicit b: StringBuilder, indent: Int): Unit = {
- b += NewLine ; doIndent(inc)
+ b += NewLine; doIndent(inc)
}
private def newLineNoIndent()(implicit b: StringBuilder): Unit = b += NewLine
@@ -245,32 +255,37 @@ object Serializer {
}
/** serialize firrtl Expression nodes with a custom separator and the option to include the separator at the end */
- private def s(nodes: Iterable[Expression], sep: String, noFinalSep: Boolean = true)
- (implicit b: StringBuilder, indent: Int): Unit = {
+ private def s(
+ nodes: Iterable[Expression],
+ sep: String,
+ noFinalSep: Boolean = true
+ )(
+ implicit b: StringBuilder,
+ indent: Int
+ ): Unit = {
val it = nodes.iterator
- while(it.hasNext) {
+ while (it.hasNext) {
s(it.next())
- if(!noFinalSep || it.hasNext) b ++= sep
+ if (!noFinalSep || it.hasNext) b ++= sep
}
}
/** serialize firrtl Field nodes with a custom separator and the option to include the separator at the end */
@inline
- private def sField(nodes: Iterable[Field], sep: String)
- (implicit b: StringBuilder, indent: Int): Unit = {
+ private def sField(nodes: Iterable[Field], sep: String)(implicit b: StringBuilder, indent: Int): Unit = {
val it = nodes.iterator
- while(it.hasNext) {
+ while (it.hasNext) {
s(it.next())
- if(it.hasNext) b ++= sep
+ if (it.hasNext) b ++= sep
}
}
/** serialize BigInts with a custom separator */
private def s(consts: Iterable[BigInt], sep: String)(implicit b: StringBuilder): Unit = {
val it = consts.iterator
- while(it.hasNext) {
+ while (it.hasNext) {
b ++= it.next().toString()
- if(it.hasNext) b ++= sep
+ if (it.hasNext) b ++= sep
}
}
}
diff --git a/src/main/scala/firrtl/ir/StructuralHash.scala b/src/main/scala/firrtl/ir/StructuralHash.scala
index 1b38dec1..f1ed91f3 100644
--- a/src/main/scala/firrtl/ir/StructuralHash.scala
+++ b/src/main/scala/firrtl/ir/StructuralHash.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
* of the same circuit and thus all modules referred to in DefInstance are the same.
*
* @author Kevin Laeufer <laeufer@cs.berkeley.edu>
- * */
+ */
object StructuralHash {
def sha256(node: DefModule, moduleRename: String => String = identity): HashCode = {
val m = MessageDigest.getInstance(SHA256)
@@ -59,19 +59,19 @@ object StructuralHash {
private val SHA256 = "SHA-256"
private def hash(node: FirrtlNode, h: Hasher, rename: String => String): Unit = node match {
- case n : Expression => new StructuralHash(h, rename).hash(n)
- case n : Statement => new StructuralHash(h, rename).hash(n)
- case n : Type => new StructuralHash(h, rename).hash(n)
- case n : Width => new StructuralHash(h, rename).hash(n)
- case n : Orientation => new StructuralHash(h, rename).hash(n)
- case n : Field => new StructuralHash(h, rename).hash(n)
- case n : Direction => new StructuralHash(h, rename).hash(n)
- case n : Port => new StructuralHash(h, rename).hash(n)
- case n : Param => new StructuralHash(h, rename).hash(n)
- case _ : Info => throw new RuntimeException("The structural hash of Info is meaningless.")
- case n : DefModule => new StructuralHash(h, rename).hash(n)
- case n : Circuit => hashCircuit(n, h, rename)
- case n : StringLit => h.update(n.toString)
+ case n: Expression => new StructuralHash(h, rename).hash(n)
+ case n: Statement => new StructuralHash(h, rename).hash(n)
+ case n: Type => new StructuralHash(h, rename).hash(n)
+ case n: Width => new StructuralHash(h, rename).hash(n)
+ case n: Orientation => new StructuralHash(h, rename).hash(n)
+ case n: Field => new StructuralHash(h, rename).hash(n)
+ case n: Direction => new StructuralHash(h, rename).hash(n)
+ case n: Port => new StructuralHash(h, rename).hash(n)
+ case n: Param => new StructuralHash(h, rename).hash(n)
+ case _: Info => throw new RuntimeException("The structural hash of Info is meaningless.")
+ case n: DefModule => new StructuralHash(h, rename).hash(n)
+ case n: Circuit => hashCircuit(n, h, rename)
+ case n: StringLit => h.update(n.toString)
}
private def hashModuleAndPortNames(m: DefModule, h: Hasher, rename: String => String): Unit = {
@@ -85,9 +85,9 @@ object StructuralHash {
}
private def hashPortTypeName(tpe: Type, h: String => Unit): Unit = tpe match {
- case BundleType(fields) => fields.foreach{ f => h(f.name) ; hashPortTypeName(f.tpe, h) }
- case VectorType(vt, _) => hashPortTypeName(vt, h)
- case _ => // ignore ground types since they do not have field names nor sub-types
+ case BundleType(fields) => fields.foreach { f => h(f.name); hashPortTypeName(f.tpe, h) }
+ case VectorType(vt, _) => hashPortTypeName(vt, h)
+ case _ => // ignore ground types since they do not have field names nor sub-types
}
private def hashCircuit(c: Circuit, h: Hasher, rename: String => String): Unit = {
@@ -101,8 +101,8 @@ object StructuralHash {
}
}
- private val primOpToId = PrimOps.builtinPrimOps.zipWithIndex.map{ case (op, i) => op -> (-i -1).toByte }.toMap
- assert(primOpToId.values.max == -1, "PrimOp nodes use ids -1 ... -50")
+ private val primOpToId = PrimOps.builtinPrimOps.zipWithIndex.map { case (op, i) => op -> (-i - 1).toByte }.toMap
+ assert(primOpToId.values.max == -1, "PrimOp nodes use ids -1 ... -50")
assert(primOpToId.values.min >= -50, "PrimOp nodes use ids -1 ... -50")
private def primOp(p: PrimOp): Byte = primOpToId(p)
@@ -110,7 +110,7 @@ object StructuralHash {
private def verificationOp(op: Formal.Value): Byte = op match {
case Formal.Assert => 0
case Formal.Assume => 1
- case Formal.Cover => 2
+ case Formal.Cover => 2
}
}
@@ -129,14 +129,14 @@ private class MDHashCode(code: Array[Byte]) extends HashCode {
/** Generic hashing interface which allows us to use different backends to trade of speed and collision resistance */
private trait Hasher {
- def update(b: Byte): Unit
- def update(i: Int): Unit
- def update(l: Long): Unit
- def update(s: String): Unit
+ def update(b: Byte): Unit
+ def update(i: Int): Unit
+ def update(l: Long): Unit
+ def update(s: String): Unit
def update(b: Array[Byte]): Unit
def update(d: Double): Unit = update(java.lang.Double.doubleToRawLongBits(d))
- def update(i: BigInt): Unit = update(i.toByteArray)
- def update(b: Boolean): Unit = if(b) update(1.toByte) else update(0.toByte)
+ def update(i: BigInt): Unit = update(i.toByteArray)
+ def update(b: Boolean): Unit = if (b) update(1.toByte) else update(0.toByte)
def update(i: BigDecimal): Unit = {
// this might be broken, tried to borrow some code from BigDecimal.computeHashCode
val temp = i.bigDecimal.stripTrailingZeros()
@@ -149,14 +149,14 @@ private trait Hasher {
private class MessageDigestHasher(m: MessageDigest) extends Hasher {
override def update(b: Byte): Unit = m.update(b)
override def update(i: Int): Unit = {
- m.update(((i >> 0) & 0xff).toByte)
- m.update(((i >> 8) & 0xff).toByte)
+ m.update(((i >> 0) & 0xff).toByte)
+ m.update(((i >> 8) & 0xff).toByte)
m.update(((i >> 16) & 0xff).toByte)
m.update(((i >> 24) & 0xff).toByte)
}
override def update(l: Long): Unit = {
- m.update(((l >> 0) & 0xff).toByte)
- m.update(((l >> 8) & 0xff).toByte)
+ m.update(((l >> 0) & 0xff).toByte)
+ m.update(((l >> 8) & 0xff).toByte)
m.update(((l >> 16) & 0xff).toByte)
m.update(((l >> 24) & 0xff).toByte)
m.update(((l >> 32) & 0xff).toByte)
@@ -165,42 +165,47 @@ private class MessageDigestHasher(m: MessageDigest) extends Hasher {
m.update(((l >> 56) & 0xff).toByte)
}
// the encoding of the bytes should not matter as long as we are on the same platform
- override def update(s: String): Unit = m.update(s.getBytes())
+ override def update(s: String): Unit = m.update(s.getBytes())
override def update(b: Array[Byte]): Unit = m.update(b)
}
-class StructuralHash private(h: Hasher, renameModule: String => String) {
+class StructuralHash private (h: Hasher, renameModule: String => String) {
// replace identifiers with incrementing integers
private val nameToInt = mutable.HashMap[String, Int]()
private var nameCounter: Int = 0
- @inline private def n(name: String): Unit = hash(nameToInt.getOrElseUpdate(name, {
- val ii = nameCounter
- nameCounter = nameCounter + 1
- ii
- }))
+ @inline private def n(name: String): Unit = hash(
+ nameToInt.getOrElseUpdate(
+ name, {
+ val ii = nameCounter
+ nameCounter = nameCounter + 1
+ ii
+ }
+ )
+ )
// internal convenience methods
- @inline private def id(b: Byte): Unit = h.update(b)
- @inline private def hash(i: Int): Unit = h.update(i)
- @inline private def hash(b: Boolean): Unit = h.update(b)
- @inline private def hash(d: Double): Unit = h.update(d)
- @inline private def hash(i: BigInt): Unit = h.update(i)
+ @inline private def id(b: Byte): Unit = h.update(b)
+ @inline private def hash(i: Int): Unit = h.update(i)
+ @inline private def hash(b: Boolean): Unit = h.update(b)
+ @inline private def hash(d: Double): Unit = h.update(d)
+ @inline private def hash(i: BigInt): Unit = h.update(i)
@inline private def hash(i: BigDecimal): Unit = h.update(i)
- @inline private def hash(s: String): Unit = h.update(s)
+ @inline private def hash(s: String): Unit = h.update(s)
private def hash(node: Expression): Unit = node match {
- case Reference(name, _, _, _) => id(0) ; n(name)
+ case Reference(name, _, _, _) => id(0); n(name)
case DoPrim(op, args, consts, _) =>
// no need to hash the number of arguments or constants since that is implied by the op
- id(1) ; h.update(StructuralHash.primOp(op)) ; args.foreach(hash) ; consts.foreach(hash)
- case UIntLiteral(value, width) => id(2) ; hash(value) ; hash(width)
+ id(1); h.update(StructuralHash.primOp(op)); args.foreach(hash); consts.foreach(hash)
+ case UIntLiteral(value, width) => id(2); hash(value); hash(width)
// We hash bundles as if fields are accessed by their index.
// Thus we need to also hash field accesses that way.
// This has the side-effect that `x.y` might hash to the same value as `z.r`, for example if the
// types are `x: {y: UInt<1>, ...}` and `z: {r: UInt<1>, ...}` respectively.
// They do not hash to the same value if the type of `z` is e.g., `z: {..., r: UInt<1>, ...}`
// as that would have the `r` field at a different index.
- case SubField(expr, name, _, _) => id(3) ; hash(expr)
+ case SubField(expr, name, _, _) =>
+ id(3); hash(expr)
// find field index and hash that instead of the field name
val fields = expr.tpe match {
case b: BundleType => b.fields
@@ -209,93 +214,115 @@ class StructuralHash private(h: Hasher, renameModule: String => String) {
}
val index = fields.zipWithIndex.find(_._1.name == name).map(_._2).get
hash(index)
- case SubIndex(expr, value, _, _) => id(4) ; hash(expr) ; hash(value)
- case SubAccess(expr, index, _, _) => id(5) ; hash(expr) ; hash(index)
- case Mux(cond, tval, fval, _) => id(6) ; hash(cond) ; hash(tval) ; hash(fval)
- case ValidIf(cond, value, _) => id(7) ; hash(cond) ; hash(value)
- case SIntLiteral(value, width) => id(8) ; hash(value) ; hash(width)
- case FixedLiteral(value, width, point) => id(9) ; hash(value) ; hash(width) ; hash(point)
+ case SubIndex(expr, value, _, _) => id(4); hash(expr); hash(value)
+ case SubAccess(expr, index, _, _) => id(5); hash(expr); hash(index)
+ case Mux(cond, tval, fval, _) => id(6); hash(cond); hash(tval); hash(fval)
+ case ValidIf(cond, value, _) => id(7); hash(cond); hash(value)
+ case SIntLiteral(value, width) => id(8); hash(value); hash(width)
+ case FixedLiteral(value, width, point) => id(9); hash(value); hash(width); hash(point)
// WIR
- case firrtl.WVoid => id(10)
- case firrtl.WInvalid => id(11)
+ case firrtl.WVoid => id(10)
+ case firrtl.WInvalid => id(11)
case firrtl.EmptyExpression => id(12)
// VRandom is used in the Emitter
- case firrtl.VRandom(width) => id(13) ; hash(width)
+ case firrtl.VRandom(width) => id(13); hash(width)
// ids 14 ... 19 are reserved for future Expression nodes
}
private def hash(node: Statement): Unit = node match {
// all info fields are ignore
- case DefNode(_, name, value) => id(20) ; n(name) ; hash(value)
- case Connect(_, loc, expr) => id(21) ; hash(loc) ; hash(expr)
+ case DefNode(_, name, value) => id(20); n(name); hash(value)
+ case Connect(_, loc, expr) => id(21); hash(loc); hash(expr)
// we place the unique id 23 between conseq and alt to distinguish between them in case conseq is empty
// we place the unique id 24 after alt to distinguish between alt and the next statement in case alt is empty
- case Conditionally(_, pred, conseq, alt) => id(22) ; hash(pred) ; hash(conseq) ; id(23) ; hash(alt) ; id(24)
- case EmptyStmt => // empty statements are ignored
- case Block(stmts) => stmts.foreach(hash) // block structure is ignored
- case Stop(_, ret, clk, en) => id(25) ; hash(ret) ; hash(clk) ; hash(en)
- case Print(_, string, args, clk, en) =>
+ case Conditionally(_, pred, conseq, alt) => id(22); hash(pred); hash(conseq); id(23); hash(alt); id(24)
+ case EmptyStmt => // empty statements are ignored
+ case Block(stmts) => stmts.foreach(hash) // block structure is ignored
+ case Stop(_, ret, clk, en) => id(25); hash(ret); hash(clk); hash(en)
+ case Print(_, string, args, clk, en) =>
// the string is part of the side effect and thus part of the circuit behavior
- id(26) ; hash(string.string) ; hash(args.length) ; args.foreach(hash) ; hash(clk) ; hash(en)
- case IsInvalid(_, expr) => id(27) ; hash(expr)
- case DefWire(_, name, tpe) => id(28) ; n(name) ; hash(tpe)
+ id(26); hash(string.string); hash(args.length); args.foreach(hash); hash(clk); hash(en)
+ case IsInvalid(_, expr) => id(27); hash(expr)
+ case DefWire(_, name, tpe) => id(28); n(name); hash(tpe)
case DefRegister(_, name, tpe, clock, reset, init) =>
- id(29) ; n(name) ; hash(tpe) ; hash(clock) ; hash(reset) ; hash(init)
+ id(29); n(name); hash(tpe); hash(clock); hash(reset); hash(init)
case DefInstance(_, name, module, _) =>
// Module is in the global namespace which is why we cannot replace it with a numeric id.
// However, it might have been renamed as part of the dedup consolidation.
- id(30) ; n(name) ; hash(renameModule(module))
+ id(30); n(name); hash(renameModule(module))
// descriptions on statements are ignores
case firrtl.DescribedStmt(_, stmt) => hash(stmt)
- case DefMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers,
- readwriters, readUnderWrite) =>
- id(30) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency)
- hash(readers.length) ; readers.foreach(hash)
- hash(writers.length) ; writers.foreach(hash)
- hash(readwriters.length) ; readwriters.foreach(hash)
+ case DefMemory(
+ _,
+ name,
+ dataType,
+ depth,
+ writeLatency,
+ readLatency,
+ readers,
+ writers,
+ readwriters,
+ readUnderWrite
+ ) =>
+ id(30); n(name); hash(dataType); hash(depth); hash(writeLatency); hash(readLatency)
+ hash(readers.length); readers.foreach(hash)
+ hash(writers.length); writers.foreach(hash)
+ hash(readwriters.length); readwriters.foreach(hash)
hash(readUnderWrite)
- case PartialConnect(_, loc, expr) => id(31) ; hash(loc) ; hash(expr)
- case Attach(_, exprs) => id(32) ; hash(exprs.length) ; exprs.foreach(hash)
+ case PartialConnect(_, loc, expr) => id(31); hash(loc); hash(expr)
+ case Attach(_, exprs) => id(32); hash(exprs.length); exprs.foreach(hash)
// WIR
case firrtl.CDefMemory(_, name, tpe, size, seq, readUnderWrite) =>
- id(33) ; n(name) ; hash(tpe); hash(size) ; hash(seq) ; hash(readUnderWrite)
+ id(33); n(name); hash(tpe); hash(size); hash(seq); hash(readUnderWrite)
case firrtl.CDefMPort(_, name, _, mem, exps, direction) =>
// the type of the MPort depends only on the memory (in well types firrtl) and can thus be ignored
- id(34) ; n(name) ; n(mem) ; hash(exps.length) ; exps.foreach(hash) ; hash(direction)
+ id(34); n(name); n(mem); hash(exps.length); exps.foreach(hash); hash(direction)
// DefAnnotatedMemory from MemIR.scala
- case firrtl.passes.memlib.DefAnnotatedMemory(_, name, dataType, depth, writeLatency, readLatency, readers, writers,
- readwriters, readUnderWrite, maskGran, memRef) =>
- id(35) ; n(name) ; hash(dataType) ; hash(depth) ; hash(writeLatency) ; hash(readLatency)
- hash(readers.length) ; readers.foreach(hash)
- hash(writers.length) ; writers.foreach(hash)
- hash(readwriters.length) ; readwriters.foreach(hash)
+ case firrtl.passes.memlib.DefAnnotatedMemory(
+ _,
+ name,
+ dataType,
+ depth,
+ writeLatency,
+ readLatency,
+ readers,
+ writers,
+ readwriters,
+ readUnderWrite,
+ maskGran,
+ memRef
+ ) =>
+ id(35); n(name); hash(dataType); hash(depth); hash(writeLatency); hash(readLatency)
+ hash(readers.length); readers.foreach(hash)
+ hash(writers.length); writers.foreach(hash)
+ hash(readwriters.length); readwriters.foreach(hash)
hash(readUnderWrite.toString)
- hash(maskGran.size) ; maskGran.foreach(hash)
- hash(memRef.size) ; memRef.foreach{ case (a, b) => hash(a) ; hash(b) }
+ hash(maskGran.size); maskGran.foreach(hash)
+ hash(memRef.size); memRef.foreach { case (a, b) => hash(a); hash(b) }
case Verification(op, _, clk, pred, en, msg) =>
- id(36) ; hash(StructuralHash.verificationOp(op)) ; hash(clk) ; hash(pred) ; hash(en) ; hash(msg.string)
+ id(36); hash(StructuralHash.verificationOp(op)); hash(clk); hash(pred); hash(en); hash(msg.string)
// ids 37 ... 39 are reserved for future Statement nodes
}
// ReadUnderWrite is never used in place of a FirrtlNode and thus we can start a new id namespace
private def hash(ruw: ReadUnderWrite.Value): Unit = ruw match {
- case ReadUnderWrite.New => id(0)
- case ReadUnderWrite.Old => id(1)
+ case ReadUnderWrite.New => id(0)
+ case ReadUnderWrite.Old => id(1)
case ReadUnderWrite.Undefined => id(2)
}
private def hash(node: Width): Unit = node match {
- case IntWidth(width) => id(40) ; hash(width)
- case UnknownWidth => id(41)
- case CalcWidth(arg) => id(42) ; hash(arg)
+ case IntWidth(width) => id(40); hash(width)
+ case UnknownWidth => id(41)
+ case CalcWidth(arg) => id(42); hash(arg)
// we are hashing the name of the `VarWidth` instead of using `n` since these Vars exist in a different namespace
- case VarWidth(name) => id(43) ; hash(name)
+ case VarWidth(name) => id(43); hash(name)
// ids 44 + 45 are reserved for future Width nodes
}
private def hash(node: Orientation): Unit = node match {
case Default => id(46)
- case Flip => id(47)
+ case Flip => id(47)
}
private def hash(node: Field): Unit = {
@@ -306,81 +333,81 @@ class StructuralHash private(h: Hasher, renameModule: String => String) {
// has been used in the Dedup pass for a long time.
// This position-based notion of equality requires us to replace field names with field indexes when hashing
// SubField accesses.
- id(48) ; hash(node.flip) ; hash(node.tpe)
+ id(48); hash(node.flip); hash(node.tpe)
}
private def hash(node: Type): Unit = node match {
// Types
- case UIntType(width: Width) => id(50) ; hash(width)
- case SIntType(width: Width) => id(51) ; hash(width)
- case FixedType(width, point) => id(52) ; hash(width) ; hash(point)
- case BundleType(fields) => id(53) ; hash(fields.length) ; fields.foreach(hash)
- case VectorType(tpe, size) => id(54) ; hash(tpe) ; hash(size)
- case ClockType => id(55)
- case ResetType => id(56)
- case AsyncResetType => id(57)
- case AnalogType(width) => id(58) ; hash(width)
- case UnknownType => id(59)
- case IntervalType(lower, upper, point) => id(60) ; hash(lower) ; hash(upper) ; hash(point)
+ case UIntType(width: Width) => id(50); hash(width)
+ case SIntType(width: Width) => id(51); hash(width)
+ case FixedType(width, point) => id(52); hash(width); hash(point)
+ case BundleType(fields) => id(53); hash(fields.length); fields.foreach(hash)
+ case VectorType(tpe, size) => id(54); hash(tpe); hash(size)
+ case ClockType => id(55)
+ case ResetType => id(56)
+ case AsyncResetType => id(57)
+ case AnalogType(width) => id(58); hash(width)
+ case UnknownType => id(59)
+ case IntervalType(lower, upper, point) => id(60); hash(lower); hash(upper); hash(point)
// ids 61 ... 65 are reserved for future Type nodes
}
private def hash(node: Direction): Unit = node match {
- case Input => id(66)
+ case Input => id(66)
case Output => id(67)
}
private def hash(node: Port): Unit = {
- id(68) ; n(node.name) ; hash(node.direction) ; hash(node.tpe)
+ id(68); n(node.name); hash(node.direction); hash(node.tpe)
}
private def hash(node: Param): Unit = node match {
- case IntParam(name, value) => id(70) ; n(name) ; hash(value)
- case DoubleParam(name, value) => id(71) ; n(name) ; hash(value)
- case StringParam(name, value) => id(72) ; n(name) ; hash(value.string)
- case RawStringParam(name, value) => id(73) ; n(name) ; hash(value)
+ case IntParam(name, value) => id(70); n(name); hash(value)
+ case DoubleParam(name, value) => id(71); n(name); hash(value)
+ case StringParam(name, value) => id(72); n(name); hash(value.string)
+ case RawStringParam(name, value) => id(73); n(name); hash(value)
// id 74 is reserved for future use
}
private def hash(node: DefModule): Unit = node match {
// the module name is ignored since it does not affect module functionality
case Module(_, _name, ports, body) =>
- id(75) ; hash(ports.length) ; ports.foreach(hash) ; hash(body)
+ id(75); hash(ports.length); ports.foreach(hash); hash(body)
// the module name is ignored since it does not affect module functionality
case ExtModule(_, name, ports, defname, params) =>
- id(76) ; hash(ports.length) ; ports.foreach(hash) ; hash(defname)
- hash(params.length) ; params.foreach(hash)
+ id(76); hash(ports.length); ports.foreach(hash); hash(defname)
+ hash(params.length); params.foreach(hash)
}
// id 127 is reserved for Circuit nodes
private def hash(d: firrtl.MPortDir): Unit = d match {
- case firrtl.MInfer => id(-70)
- case firrtl.MRead => id(-71)
- case firrtl.MWrite => id(-72)
+ case firrtl.MInfer => id(-70)
+ case firrtl.MRead => id(-71)
+ case firrtl.MWrite => id(-72)
case firrtl.MReadWrite => id(-73)
}
private def hash(c: firrtl.constraint.Constraint): Unit = c match {
case b: Bound => hash(b) /* uses ids -80 ... -84 */
case firrtl.constraint.IsAdd(known, maxs, mins, others) =>
- id(-85) ; hash(known.nonEmpty) ; known.foreach(hash)
- hash(maxs.length) ; maxs.foreach(hash)
- hash(mins.length) ; mins.foreach(hash)
- hash(others.length) ; others.foreach(hash)
- case firrtl.constraint.IsFloor(child, dummyArg) => id(-86) ; hash(child) ; hash(dummyArg)
- case firrtl.constraint.IsKnown(decimal) => id(-87) ; hash(decimal)
- case firrtl.constraint.IsNeg(child, dummyArg) => id(-88) ; hash(child) ; hash(dummyArg)
- case firrtl.constraint.IsPow(child, dummyArg) => id(-89) ; hash(child) ; hash(dummyArg)
- case firrtl.constraint.IsVar(str) => id(-90) ; n(str)
+ id(-85); hash(known.nonEmpty); known.foreach(hash)
+ hash(maxs.length); maxs.foreach(hash)
+ hash(mins.length); mins.foreach(hash)
+ hash(others.length); others.foreach(hash)
+ case firrtl.constraint.IsFloor(child, dummyArg) => id(-86); hash(child); hash(dummyArg)
+ case firrtl.constraint.IsKnown(decimal) => id(-87); hash(decimal)
+ case firrtl.constraint.IsNeg(child, dummyArg) => id(-88); hash(child); hash(dummyArg)
+ case firrtl.constraint.IsPow(child, dummyArg) => id(-89); hash(child); hash(dummyArg)
+ case firrtl.constraint.IsVar(str) => id(-90); n(str)
}
private def hash(b: Bound): Unit = b match {
case UnknownBound => id(-80)
- case CalcBound(arg) => id(-81) ; hash(arg)
+ case CalcBound(arg) => id(-81); hash(arg)
// we are hashing the name of the `VarBound` instead of using `n` since these Vars exist in a different namespace
- case VarBound(name) => id(-82) ; hash(name)
- case Open(value) => id(-83) ; hash(value)
- case Closed(value) => id(-84) ; hash(value)
+ case VarBound(name) => id(-82); hash(name)
+ case Open(value) => id(-83); hash(value)
+ case Closed(value) => id(-84); hash(value)
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/options/DependencyManager.scala b/src/main/scala/firrtl/options/DependencyManager.scala
index ee6a7404..561e32ab 100644
--- a/src/main/scala/firrtl/options/DependencyManager.scala
+++ b/src/main/scala/firrtl/options/DependencyManager.scala
@@ -3,7 +3,7 @@
package firrtl.options
import firrtl.AnnotationSeq
-import firrtl.graph.{DiGraph, CyclicException}
+import firrtl.graph.{CyclicException, DiGraph}
import scala.collection.Set
import scala.collection.immutable.{Set => ISet}
@@ -22,7 +22,6 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
override def prerequisites = currentState
-
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf = Seq.empty
@@ -34,13 +33,13 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
*/
def targets: Seq[Dependency[B]]
private lazy val _targets: LinkedHashSet[Dependency[B]] = targets
- .foldLeft(new LinkedHashSet[Dependency[B]]()){ case (a, b) => a += b }
+ .foldLeft(new LinkedHashSet[Dependency[B]]()) { case (a, b) => a += b }
/** A sequence of [[firrtl.Transform]]s that have been run. Internally, this will be converted to an ordered set.
*/
def currentState: Seq[Dependency[B]]
private lazy val _currentState: LinkedHashSet[Dependency[B]] = currentState
- .foldLeft(new LinkedHashSet[Dependency[B]]()){ case (a, b) => a += b }
+ .foldLeft(new LinkedHashSet[Dependency[B]]()) { case (a, b) => a += b }
/** Existing transform objects that have already been constructed */
def knownObjects: Set[B]
@@ -64,9 +63,10 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
* requirements. This is used to solve sub-problems arising from invalidations.
*/
protected def copy(
- targets: Seq[Dependency[B]],
+ targets: Seq[Dependency[B]],
currentState: Seq[Dependency[B]],
- knownObjects: ISet[B] = dependencyToObject.values.toSet): B
+ knownObjects: ISet[B] = dependencyToObject.values.toSet
+ ): B
/** Implicit conversion from Dependency to B */
private implicit def dToO(d: Dependency[B]): B = dependencyToObject.getOrElseUpdate(d, d.getObject())
@@ -77,14 +77,16 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
/** Modified breadth-first search that supports multiple starting nodes and a custom extractor that can be used to
* generate/filter the edges to explore. Additionally, this will include edges to previously discovered nodes.
*/
- private def bfs( start: LinkedHashSet[Dependency[B]],
- blacklist: LinkedHashSet[Dependency[B]],
- extractor: B => Set[Dependency[B]] ): LinkedHashMap[B, LinkedHashSet[B]] = {
+ private def bfs(
+ start: LinkedHashSet[Dependency[B]],
+ blacklist: LinkedHashSet[Dependency[B]],
+ extractor: B => Set[Dependency[B]]
+ ): LinkedHashMap[B, LinkedHashSet[B]] = {
val (queue, edges) = {
- val a: Queue[Dependency[B]] = Queue(start.toSeq:_*)
- val b: LinkedHashMap[B, LinkedHashSet[B]] = LinkedHashMap[B, LinkedHashSet[B]](
- start.map((dToO(_) -> LinkedHashSet[B]())).toSeq:_*)
+ val a: Queue[Dependency[B]] = Queue(start.toSeq: _*)
+ val b: LinkedHashMap[B, LinkedHashSet[B]] =
+ LinkedHashMap[B, LinkedHashSet[B]](start.map((dToO(_) -> LinkedHashSet[B]())).toSeq: _*)
(a, b)
}
@@ -117,7 +119,8 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
val edges = bfs(
start = _targets &~ _currentState,
blacklist = _currentState,
- extractor = (p: B) => p._prerequisites &~ _currentState)
+ extractor = (p: B) => p._prerequisites &~ _currentState
+ )
DiGraph(edges)
}
@@ -144,11 +147,14 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
val edges = {
val x = new LinkedHashMap ++ _targets
.map(dependencyToObject)
- .map{ a => a -> prerequisiteGraph.getVertices.filter(a._optionalPrerequisiteOf(_)) }
- x
- .values
+ .map { a => a -> prerequisiteGraph.getVertices.filter(a._optionalPrerequisiteOf(_)) }
+ x.values
.reduce(_ ++ _)
- .foldLeft(x){ case (xx, y) => if (xx.contains(y)) { xx } else { xx ++ Map(y -> Set.empty[B]) } }
+ .foldLeft(x) {
+ case (xx, y) =>
+ if (xx.contains(y)) { xx }
+ else { xx ++ Map(y -> Set.empty[B]) }
+ }
}
DiGraph(edges).reverse
}
@@ -165,23 +171,26 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
bfs(
start = v.map(oToD(_)),
blacklist = _currentState,
-
/* Explore all invalidated transforms **EXCEPT** the current transform! */
extractor = (p: B) => {
val filtered = new LinkedHashSet[Dependency[B]]
filtered ++= v.filter(p.invalidates).map(oToD(_))
filtered -= oToD(p)
filtered
- })
+ }
+ )
).reverse
}
/** Wrap a possible [[CyclicException]] thrown by a thunk in a [[DependencyManagerException]] */
- private def cyclePossible[A](a: String, diGraph: DiGraph[_])(thunk: => A): A = try { thunk } catch {
+ private def cyclePossible[A](a: String, diGraph: DiGraph[_])(thunk: => A): A = try { thunk }
+ catch {
case e: CyclicException =>
throw new DependencyManagerException(
s"""|No transform ordering possible due to cyclic dependency in $a with cycles:
- |${diGraph.findSCCs.filter(_.size > 1).mkString(" - ", "\n - ", "")}""".stripMargin, e)
+ |${diGraph.findSCCs.filter(_.size > 1).mkString(" - ", "\n - ", "")}""".stripMargin,
+ e
+ )
}
/** An ordering of [[firrtl.options.TransformLike TransformLike]]s that causes the requested [[DependencyManager.targets
@@ -198,38 +207,39 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
*/
val sorted = {
val edges = {
- val v = cyclePossible("invalidates", invalidateGraph){ invalidateGraph.linearize }.reverse
+ val v = cyclePossible("invalidates", invalidateGraph) { invalidateGraph.linearize }.reverse
/* A comparison function that will sort vertices based on the topological sort of the invalidation graph */
val cmp =
- (l: B, r: B) => v.foldLeft((Map.empty[B, Dependency[B] => Boolean], Set.empty[Dependency[B]])){
- case ((m, s), r) => (m + (r -> ((a: Dependency[B]) => !s(a))), s + r) }._1(l)(r)
+ (l: B, r: B) =>
+ v.foldLeft((Map.empty[B, Dependency[B] => Boolean], Set.empty[Dependency[B]])) {
+ case ((m, s), r) => (m + (r -> ((a: Dependency[B]) => !s(a))), s + r)
+ }._1(l)(r)
new LinkedHashMap() ++
v.map(vv => vv -> (new LinkedHashSet() ++ (dependencyGraph.getEdges(vv).toSeq.sortWith(cmp))))
}
cyclePossible("prerequisites", dependencyGraph) {
- DiGraph(edges)
- .linearize
- .reverse
+ DiGraph(edges).linearize.reverse
.dropWhile(b => _currentState.contains(b))
}
}
/* [todo] Seq is inefficient here, but Array has ClassTag problems. Use something else? */
- val (s, l) = sorted.foldLeft((_currentState, Seq[B]())){ case ((state, out), in) =>
- val prereqs = in._prerequisites ++
- dependencyGraph.getEdges(in).toSeq.map(oToD) ++
- otherPrerequisites.getEdges(in).toSeq.map(oToD)
- val preprocessing: Option[B] = {
- if ((prereqs -- state).nonEmpty) { Some(this.copy(prereqs.toSeq, state.toSeq)) }
- else { None }
- }
- /* "in" is added *after* invalidation because a transform my not invalidate itself! */
- ((state ++ prereqs).map(dToO).filterNot(in.invalidates).map(oToD) + in, out ++ preprocessing :+ in)
+ val (s, l) = sorted.foldLeft((_currentState, Seq[B]())) {
+ case ((state, out), in) =>
+ val prereqs = in._prerequisites ++
+ dependencyGraph.getEdges(in).toSeq.map(oToD) ++
+ otherPrerequisites.getEdges(in).toSeq.map(oToD)
+ val preprocessing: Option[B] = {
+ if ((prereqs -- state).nonEmpty) { Some(this.copy(prereqs.toSeq, state.toSeq)) }
+ else { None }
+ }
+ /* "in" is added *after* invalidation because a transform my not invalidate itself! */
+ ((state ++ prereqs).map(dToO).filterNot(in.invalidates).map(oToD) + in, out ++ preprocessing :+ in)
}
val postprocessing: Option[B] = {
if ((_targets -- s).nonEmpty) { Some(this.copy(_targets.toSeq, s.toSeq)) }
- else { None }
+ else { None }
}
l ++ postprocessing
}
@@ -252,20 +262,21 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
* applied while tracking the state of the underlying A. If the state ever disagrees with a prerequisite, then this
* throws an exception.
*/
- flattenedTransformOrder
- .map{ t =>
- val w = wrappers.foldLeft(t){ case (tx, wrapper) => wrapper(tx) }
- wrapperToClass += (w -> t)
- w
- }.foldLeft((annotations, _currentState)){ case ((a, state), t) =>
- if (!t.prerequisites.toSet.subsetOf(state)) {
- throw new DependencyManagerException(
- s"""|Tried to execute '$t' for which run-time prerequisites were not satisfied:
- | state: ${state.mkString("\n -", "\n -", "")}
- | prerequisites: ${prerequisites.mkString("\n -", "\n -", "")}""".stripMargin)
- }
- (t.transform(a), ((state + wrapperToClass(t)).map(dToO).filterNot(t.invalidates).map(oToD)))
- }._1
+ flattenedTransformOrder.map { t =>
+ val w = wrappers.foldLeft(t) { case (tx, wrapper) => wrapper(tx) }
+ wrapperToClass += (w -> t)
+ w
+ }.foldLeft((annotations, _currentState)) {
+ case ((a, state), t) =>
+ if (!t.prerequisites.toSet.subsetOf(state)) {
+ throw new DependencyManagerException(
+ s"""|Tried to execute '$t' for which run-time prerequisites were not satisfied:
+ | state: ${state.mkString("\n -", "\n -", "")}
+ | prerequisites: ${prerequisites.mkString("\n -", "\n -", "")}""".stripMargin
+ )
+ }
+ (t.transform(a), ((state + wrapperToClass(t)).map(dToO).filterNot(t.invalidates).map(oToD)))
+ }._1
}
/** This colormap uses Colorbrewer's 4-class OrRd color scheme */
@@ -282,13 +293,13 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
def toGraphviz(digraph: DiGraph[B], attributes: String = "", tab: String = " "): Option[String] = {
val edges =
- digraph
- .getEdgeMap
- .collect{ case (v, edges) if edges.nonEmpty => (v -> edges) }
- .map{ case (v, edges) =>
- s"""${transformName(v)} -> ${edges.map(e => transformName(e)).mkString("{ ", " ", " }")}""" }
+ digraph.getEdgeMap.collect { case (v, edges) if edges.nonEmpty => (v -> edges) }.map {
+ case (v, edges) =>
+ s"""${transformName(v)} -> ${edges.map(e => transformName(e)).mkString("{ ", " ", " }")}"""
+ }
- if (edges.isEmpty) { None } else {
+ if (edges.isEmpty) { None }
+ else {
Some(
s"""| { $attributes
|${edges.mkString(tab, "\n" + tab, "")}
@@ -298,16 +309,16 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
}
val connections =
- Seq( (prerequisiteGraph, "edge []"),
- (optionalPrerequisiteOfGraph, """edge [style=bold color="#4292c6"]"""),
- (invalidateGraph, """edge [minlen=2 style=dashed constraint=false color="#fb6a4a"]"""),
- (optionalPrerequisitesGraph, """edge [style=dotted color="#a1d99b"]""") )
- .flatMap{ case (a, b) => toGraphviz(a, b) }
+ Seq(
+ (prerequisiteGraph, "edge []"),
+ (optionalPrerequisiteOfGraph, """edge [style=bold color="#4292c6"]"""),
+ (invalidateGraph, """edge [minlen=2 style=dashed constraint=false color="#fb6a4a"]"""),
+ (optionalPrerequisitesGraph, """edge [style=dotted color="#a1d99b"]""")
+ ).flatMap { case (a, b) => toGraphviz(a, b) }
.mkString("\n")
val nodes =
- (prerequisiteGraph + optionalPrerequisiteOfGraph + invalidateGraph + otherPrerequisites)
- .getVertices
+ (prerequisiteGraph + optionalPrerequisiteOfGraph + invalidateGraph + otherPrerequisites).getVertices
.map(v => s"""${transformName(v)} [label="${v.name}"]""")
s"""|digraph DependencyManager {
@@ -322,9 +333,9 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
def transformOrderToGraphviz(colormap: Seq[String] = colormap): String = {
def rotate[A](a: Seq[A]): Seq[A] = a match {
- case Nil => Nil
+ case Nil => Nil
case car :: cdr => cdr :+ car
- case car => car
+ case car => car
}
val sorted = ArrayBuffer.empty[String]
@@ -340,7 +351,7 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
|$tab labeljust=l
|$tab node [fillcolor="${cm.head}"]""".stripMargin
- val body = pm.transformOrder.map{
+ val body = pm.transformOrder.map {
case a: DependencyManager[A, B] =>
val (str, d) = rec(a, rotate(cm), tab + " ", offset + 1)
offset = d
@@ -369,9 +380,10 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
* @param size the number of nodes at the current level of the tree
*/
def customPrintHandling(
- tab: String,
+ tab: String,
charSet: CharSet,
- size: Int): Option[PartialFunction[(B, Int), Seq[String]]] = None
+ size: Int
+ ): Option[PartialFunction[(B, Int), Seq[String]]] = None
/** Helper utility when recursing during pretty printing
* @param tab an indentation string to use for every line of output
@@ -386,9 +398,9 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
val defaultHandling: PartialFunction[(B, Int), Seq[String]] = {
case (a: DependencyManager[_, _], `last`) =>
Seq(s"$tab$l ${a.name}") ++ a.prettyPrintRec(s"""$tab${" " * c.size} """, charSet)
- case (a: DependencyManager[_, _], _) => Seq(s"$tab$n ${a.name}") ++ a.prettyPrintRec(s"$tab$c ", charSet)
- case (a, `last`) => Seq(s"$tab$l ${a.name}")
- case (a, _) => Seq(s"$tab$n ${a.name}")
+ case (a: DependencyManager[_, _], _) => Seq(s"$tab$n ${a.name}") ++ a.prettyPrintRec(s"$tab$c ", charSet)
+ case (a, `last`) => Seq(s"$tab$l ${a.name}")
+ case (a, _) => Seq(s"$tab$n ${a.name}")
}
val handling = customPrintHandling(tab, charSet, transformOrder.size) match {
@@ -396,8 +408,7 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
case None => defaultHandling
}
- transformOrder
- .zipWithIndex
+ transformOrder.zipWithIndex
.flatMap(handling)
}
@@ -406,8 +417,9 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
* @param charSet a collection of characters to use when printing
*/
def prettyPrint(
- tab: String = "",
- charSet: DependencyManagerUtils.CharSet = DependencyManagerUtils.PrettyCharSet): String = {
+ tab: String = "",
+ charSet: DependencyManagerUtils.CharSet = DependencyManagerUtils.PrettyCharSet
+ ): String = {
(Seq(s"$tab$name") ++ prettyPrintRec(tab, charSet)).mkString("\n")
@@ -422,9 +434,11 @@ trait DependencyManager[A, B <: TransformLike[A] with DependencyAPI[B]] extends
* @param targets the [[Phase]]s you want to run
*/
class PhaseManager(
- val targets: Seq[PhaseManager.PhaseDependency],
+ val targets: Seq[PhaseManager.PhaseDependency],
val currentState: Seq[PhaseManager.PhaseDependency] = Seq.empty,
- val knownObjects: Set[Phase] = Set.empty) extends DependencyManager[AnnotationSeq, Phase] with Phase {
+ val knownObjects: Set[Phase] = Set.empty)
+ extends DependencyManager[AnnotationSeq, Phase]
+ with Phase {
import PhaseManager.PhaseDependency
protected def copy(a: Seq[PhaseDependency], b: Seq[PhaseDependency], c: ISet[Phase]) = new PhaseManager(a, b, c)
@@ -444,6 +458,7 @@ object DependencyManagerUtils {
* @see [[ASCIICharSet]]
*/
trait CharSet {
+
/** Used when printing the last node */
val lastNode: String
@@ -456,15 +471,15 @@ object DependencyManagerUtils {
/** Uses prettier characters, but possibly not supported by all fonts */
object PrettyCharSet extends CharSet {
- val lastNode = "└──"
- val notLastNode = "├──"
+ val lastNode = "└──"
+ val notLastNode = "├──"
val continuation = "│ "
}
/** Basic ASCII output */
object ASCIICharSet extends CharSet {
- val lastNode = "\\--"
- val notLastNode = "|--"
+ val lastNode = "\\--"
+ val notLastNode = "|--"
val continuation = "| "
}
diff --git a/src/main/scala/firrtl/options/ExitCodes.scala b/src/main/scala/firrtl/options/ExitCodes.scala
index 0e91fdec..94e525de 100644
--- a/src/main/scala/firrtl/options/ExitCodes.scala
+++ b/src/main/scala/firrtl/options/ExitCodes.scala
@@ -6,7 +6,7 @@ package firrtl.options
sealed trait ExitCode { val number: Int }
/** [[ExitCode]] indicating success */
-object ExitSuccess extends ExitCode{ val number = 0 }
+object ExitSuccess extends ExitCode { val number = 0 }
/** An [[ExitCode]] indicative of failure. This must be non-zero and should not conflict with a reserved exit code. */
sealed trait ExitFailure extends ExitCode
diff --git a/src/main/scala/firrtl/options/OptionParser.scala b/src/main/scala/firrtl/options/OptionParser.scala
index 9360a961..e7ea68bf 100644
--- a/src/main/scala/firrtl/options/OptionParser.scala
+++ b/src/main/scala/firrtl/options/OptionParser.scala
@@ -9,7 +9,8 @@ import scopt.OptionParser
case object OptionsHelpException extends Exception("Usage help invoked")
/** OptionParser mixin that causes the OptionParser to not call exit (call `sys.exit`) if the `--help` option is
- * passed */
+ * passed
+ */
trait DoNotTerminateOnExit { this: OptionParser[_] =>
override def terminate(exitState: Either[String, Unit]): Unit = ()
}
@@ -33,16 +34,18 @@ trait DuplicateHandling extends OptionParser[AnnotationSeq] {
/** Message for found duplicate options */
def msg(x: String, y: String) = s"""Duplicate $x "$y" (did your custom Transform or OptionsManager add this?)"""
- val longDups = options.map(_.name).groupBy(identity).collect{ case (k, v) if v.size > 1 && k != "" => k }
- val shortDups = options.map(_.shortOpt).flatten.groupBy(identity).collect{ case (k, v) if v.size > 1 => k }
-
+ val longDups = options.map(_.name).groupBy(identity).collect { case (k, v) if v.size > 1 && k != "" => k }
+ val shortDups = options.map(_.shortOpt).flatten.groupBy(identity).collect { case (k, v) if v.size > 1 => k }
- if (longDups.nonEmpty) {
+ if (longDups.nonEmpty) {
throw new OptionsException(msg("long option", longDups.map("--" + _).mkString(",")), new IllegalArgumentException)
}
if (shortDups.nonEmpty) {
- throw new OptionsException(msg("short option", shortDups.map("-" + _).mkString(",")), new IllegalArgumentException)
+ throw new OptionsException(
+ msg("short option", shortDups.map("-" + _).mkString(",")),
+ new IllegalArgumentException
+ )
}
super.parse(args, init)
diff --git a/src/main/scala/firrtl/options/Phase.scala b/src/main/scala/firrtl/options/Phase.scala
index 2a68251d..6a3f4a8c 100644
--- a/src/main/scala/firrtl/options/Phase.scala
+++ b/src/main/scala/firrtl/options/Phase.scala
@@ -12,7 +12,7 @@ import scala.reflect
import scala.reflect.ClassTag
object Dependency {
- def apply[A <: DependencyAPI[_] : ClassTag]: Dependency[A] = {
+ def apply[A <: DependencyAPI[_]: ClassTag]: Dependency[A] = {
val clazz = reflect.classTag[A].runtimeClass
Dependency(Left(clazz.asInstanceOf[Class[A]]))
}
@@ -40,26 +40,30 @@ object Dependency {
case class Dependency[+A <: DependencyAPI[_]](id: Either[Class[_ <: A], A with Singleton]) {
def getObject(): A = id match {
- case Left(c) => safeConstruct(c)
+ case Left(c) => safeConstruct(c)
case Right(o) => o
}
def getSimpleName: String = id match {
- case Left(c) => c.getSimpleName
+ case Left(c) => c.getSimpleName
case Right(o) => o.getClass.getSimpleName
}
def getName: String = id match {
- case Left(c) => c.getName
+ case Left(c) => c.getName
case Right(o) => o.getClass.getName
}
/** Wrap an [[IllegalAccessException]] due to attempted object construction in a [[DependencyManagerException]] */
- private def safeConstruct[A](a: Class[_ <: A]): A = try { a.newInstance } catch {
- case e: IllegalAccessException => throw new DependencyManagerException(
- s"Failed to construct '$a'! (Did you try to construct an object?)", e)
- case e: InstantiationException => throw new DependencyManagerException(
- s"Failed to construct '$a'! (Did you try to construct an inner class or a class with parameters?)", e)
+ private def safeConstruct[A](a: Class[_ <: A]): A = try { a.newInstance }
+ catch {
+ case e: IllegalAccessException =>
+ throw new DependencyManagerException(s"Failed to construct '$a'! (Did you try to construct an object?)", e)
+ case e: InstantiationException =>
+ throw new DependencyManagerException(
+ s"Failed to construct '$a'! (Did you try to construct an inner class or a class with parameters?)",
+ e
+ )
}
}
@@ -124,7 +128,7 @@ trait DependencyAPI[A <: DependencyAPI[A]] { this: TransformLike[_] =>
/** All transform that must run before this transform
* $seqNote
*/
- def prerequisites: Seq[Dependency[A]] = Seq.empty
+ def prerequisites: Seq[Dependency[A]] = Seq.empty
private[options] lazy val _prerequisites: LinkedHashSet[Dependency[A]] = new LinkedHashSet() ++ prerequisites
/** All transforms that, if a prerequisite of *another* transform, will run before this transform.
@@ -184,8 +188,10 @@ trait DependencyAPI[A <: DependencyAPI[A]] { this: TransformLike[_] =>
/** A trait indicating that no invalidations occur, i.e., all previous transforms are preserved
* @tparam A some [[TransformLike]]
*/
-@deprecated("Use an explicit `override def invalidates` returning false. This will be removed in FIRRTL 1.5.",
- "FIRRTL 1.4")
+@deprecated(
+ "Use an explicit `override def invalidates` returning false. This will be removed in FIRRTL 1.5.",
+ "FIRRTL 1.4"
+)
trait PreservesAll[A <: DependencyAPI[A]] { this: DependencyAPI[A] =>
override final def invalidates(a: A): Boolean = false
diff --git a/src/main/scala/firrtl/options/Registration.scala b/src/main/scala/firrtl/options/Registration.scala
index c832ec7c..55772c79 100644
--- a/src/main/scala/firrtl/options/Registration.scala
+++ b/src/main/scala/firrtl/options/Registration.scala
@@ -14,26 +14,26 @@ import scopt.{OptionDef, OptionParser, Read}
* @param shortOption an optional single-dash option
* @param helpValueName a string to show as a placeholder argument in help text
*/
-final class ShellOption[A: Read] (
- val longOption: String,
+final class ShellOption[A: Read](
+ val longOption: String,
val toAnnotationSeq: A => AnnotationSeq,
- val helpText: String,
- val shortOption: Option[String] = None,
- val helpValueName: Option[String] = None
-) {
+ val helpText: String,
+ val shortOption: Option[String] = None,
+ val helpValueName: Option[String] = None) {
/** Add this specific shell (command line) option to an option parser
* @param p an option parser
*/
final def addOption(p: OptionParser[AnnotationSeq]): Unit = {
val f = Seq(
- (p: OptionDef[A, AnnotationSeq]) => p.action( (x, c) => toAnnotationSeq(x).reverse ++ c ),
+ (p: OptionDef[A, AnnotationSeq]) => p.action((x, c) => toAnnotationSeq(x).reverse ++ c),
(p: OptionDef[A, AnnotationSeq]) => p.text(helpText),
- (p: OptionDef[A, AnnotationSeq]) => p.unbounded()) ++
- shortOption.map( a => (p: OptionDef[A, AnnotationSeq]) => p.abbr(a) ) ++
- helpValueName.map( a => (p: OptionDef[A, AnnotationSeq]) => p.valueName(a) )
+ (p: OptionDef[A, AnnotationSeq]) => p.unbounded()
+ ) ++
+ shortOption.map(a => (p: OptionDef[A, AnnotationSeq]) => p.abbr(a)) ++
+ helpValueName.map(a => (p: OptionDef[A, AnnotationSeq]) => p.valueName(a))
- f.foldLeft(p.opt[A](longOption))( (a, b) => b(a) )
+ f.foldLeft(p.opt[A](longOption))((a, b) => b(a))
}
}
@@ -55,13 +55,15 @@ trait HasShellOptions {
/** A [[Transform]] that includes an option that should be exposed at the top level.
*
* @note To complete registration, include an entry in
- * src/main/resources/META-INF/services/firrtl.options.RegisteredTransform */
+ * src/main/resources/META-INF/services/firrtl.options.RegisteredTransform
+ */
trait RegisteredTransform extends HasShellOptions { this: Transform => }
/** A class that includes options that should be exposed as a group at the top level.
*
* @note To complete registration, include an entry in
- * src/main/resources/META-INF/services/firrtl.options.RegisteredLibrary */
+ * src/main/resources/META-INF/services/firrtl.options.RegisteredLibrary
+ */
trait RegisteredLibrary extends HasShellOptions {
/** The name of this library.
diff --git a/src/main/scala/firrtl/options/Shell.scala b/src/main/scala/firrtl/options/Shell.scala
index 88301d30..b0ead81f 100644
--- a/src/main/scala/firrtl/options/Shell.scala
+++ b/src/main/scala/firrtl/options/Shell.scala
@@ -4,7 +4,7 @@ package firrtl.options
import firrtl.AnnotationSeq
-import logger.{LogLevelAnnotation, ClassLogLevelAnnotation, LogFileAnnotation, LogClassNamesAnnotation}
+import logger.{ClassLogLevelAnnotation, LogClassNamesAnnotation, LogFileAnnotation, LogLevelAnnotation}
import scopt.OptionParser
@@ -62,28 +62,25 @@ class Shell(val applicationName: String) {
parser.note("Shell Options")
ProgramArgsAnnotation.addOptions(parser)
- Seq( TargetDirAnnotation,
- InputAnnotationFileAnnotation,
- OutputAnnotationFileAnnotation )
+ Seq(TargetDirAnnotation, InputAnnotationFileAnnotation, OutputAnnotationFileAnnotation)
.foreach(_.addOptions(parser))
- parser.opt[Unit]("show-registrations")
- .action{ (_, c) =>
+ parser
+ .opt[Unit]("show-registrations")
+ .action { (_, c) =>
val rtString = registeredTransforms.map(r => s"\n - ${r.getClass.getName}").mkString
val rlString = registeredLibraries.map(l => s"\n - ${l.getClass.getName}").mkString
println(s"""|The following FIRRTL transforms registered command line options:$rtString
|The following libraries registered command line options:$rlString""".stripMargin)
- c }
+ c
+ }
.unbounded()
.text("print discovered registered libraries and transforms")
parser.help("help").text("prints this usage text")
parser.note("Logging Options")
- Seq( LogLevelAnnotation,
- ClassLogLevelAnnotation,
- LogFileAnnotation,
- LogClassNamesAnnotation )
+ Seq(LogLevelAnnotation, ClassLogLevelAnnotation, LogFileAnnotation, LogClassNamesAnnotation)
.foreach(_.addOptions(parser))
}
diff --git a/src/main/scala/firrtl/options/Stage.scala b/src/main/scala/firrtl/options/Stage.scala
index aa4809dd..77c8133b 100644
--- a/src/main/scala/firrtl/options/Stage.scala
+++ b/src/main/scala/firrtl/options/Stage.scala
@@ -37,10 +37,12 @@ abstract class Stage extends Phase {
.foldLeft(annotations)((a, p) => p.transform(a))
Logger.makeScope(annotationsx) {
- Seq( new phases.AddDefaults,
- new phases.Checks,
- new Phase { def transform(a: AnnotationSeq) = run(a) },
- new phases.WriteOutputAnnotations )
+ Seq(
+ new phases.AddDefaults,
+ new phases.Checks,
+ new Phase { def transform(a: AnnotationSeq) = run(a) },
+ new phases.WriteOutputAnnotations
+ )
.map(phases.DeletedWrapper(_))
.foldLeft(annotationsx)((a, p) => p.transform(a))
}
@@ -61,6 +63,7 @@ abstract class Stage extends Phase {
* @param stage the stage to run
*/
class StageMain(val stage: Stage) {
+
/** The main function that serves as this stage's command line interface.
* @param args command line arguments
*/
diff --git a/src/main/scala/firrtl/options/StageAnnotations.scala b/src/main/scala/firrtl/options/StageAnnotations.scala
index 32f8ff59..84168975 100644
--- a/src/main/scala/firrtl/options/StageAnnotations.scala
+++ b/src/main/scala/firrtl/options/StageAnnotations.scala
@@ -89,7 +89,9 @@ object TargetDirAnnotation extends HasShellOptions {
toAnnotationSeq = (a: String) => Seq(TargetDirAnnotation(a)),
helpText = "Work directory (default: '.')",
shortOption = Some("td"),
- helpValueName = Some("<directory>") ) )
+ helpValueName = Some("<directory>")
+ )
+ )
}
@@ -101,10 +103,11 @@ case class ProgramArgsAnnotation(arg: String) extends NoTargetAnnotation with St
object ProgramArgsAnnotation {
- def addOptions(p: OptionParser[AnnotationSeq]): Unit = p.arg[String]("<arg>...")
+ def addOptions(p: OptionParser[AnnotationSeq]): Unit = p
+ .arg[String]("<arg>...")
.unbounded()
.optional()
- .action( (x, c) => ProgramArgsAnnotation(x) +: c )
+ .action((x, c) => ProgramArgsAnnotation(x) +: c)
.text("optional unbounded args")
}
@@ -123,7 +126,9 @@ object InputAnnotationFileAnnotation extends HasShellOptions {
toAnnotationSeq = (a: String) => Seq(InputAnnotationFileAnnotation(a)),
helpText = "An input annotation file",
shortOption = Some("faf"),
- helpValueName = Some("<file>") ) )
+ helpValueName = Some("<file>")
+ )
+ )
}
@@ -141,7 +146,9 @@ object OutputAnnotationFileAnnotation extends HasShellOptions {
toAnnotationSeq = (a: String) => Seq(OutputAnnotationFileAnnotation(a)),
helpText = "An output annotation file",
shortOption = Some("foaf"),
- helpValueName = Some("<file>") ) )
+ helpValueName = Some("<file>")
+ )
+ )
}
@@ -156,6 +163,8 @@ case object WriteDeletedAnnotation extends NoTargetAnnotation with StageOption w
new ShellOption[Unit](
longOption = "write-deleted",
toAnnotationSeq = (_: Unit) => Seq(WriteDeletedAnnotation),
- helpText = "Include deleted annotations in the output annotation file" ) )
+ helpText = "Include deleted annotations in the output annotation file"
+ )
+ )
}
diff --git a/src/main/scala/firrtl/options/StageOptions.scala b/src/main/scala/firrtl/options/StageOptions.scala
index f60a991c..6b9190a7 100644
--- a/src/main/scala/firrtl/options/StageOptions.scala
+++ b/src/main/scala/firrtl/options/StageOptions.scala
@@ -10,26 +10,28 @@ import java.io.File
* @param programArgs explicit program arguments
* @param outputAnnotationFileName an output annotation filename
*/
-class StageOptions private [firrtl] (
- val targetDir: String = TargetDirAnnotation().directory,
- val annotationFilesIn: Seq[String] = Seq.empty,
+class StageOptions private[firrtl] (
+ val targetDir: String = TargetDirAnnotation().directory,
+ val annotationFilesIn: Seq[String] = Seq.empty,
val annotationFileOut: Option[String] = None,
- val programArgs: Seq[String] = Seq.empty,
- val writeDeleted: Boolean = false ) {
+ val programArgs: Seq[String] = Seq.empty,
+ val writeDeleted: Boolean = false) {
- private [options] def copy(
- targetDir: String = targetDir,
- annotationFilesIn: Seq[String] = annotationFilesIn,
+ private[options] def copy(
+ targetDir: String = targetDir,
+ annotationFilesIn: Seq[String] = annotationFilesIn,
annotationFileOut: Option[String] = annotationFileOut,
- programArgs: Seq[String] = programArgs,
- writeDeleted: Boolean = writeDeleted ): StageOptions = {
+ programArgs: Seq[String] = programArgs,
+ writeDeleted: Boolean = writeDeleted
+ ): StageOptions = {
new StageOptions(
targetDir = targetDir,
annotationFilesIn = annotationFilesIn,
annotationFileOut = annotationFileOut,
programArgs = programArgs,
- writeDeleted = writeDeleted )
+ writeDeleted = writeDeleted
+ )
}
@@ -62,9 +64,9 @@ class StageOptions private [firrtl] (
}.toPath.normalize.toFile
file.getParentFile match {
- case null =>
+ case null =>
case parent if (!parent.exists) => parent.mkdirs()
- case _ =>
+ case _ =>
}
file.toString
diff --git a/src/main/scala/firrtl/options/StageUtils.scala b/src/main/scala/firrtl/options/StageUtils.scala
index 3983f653..2411da6e 100644
--- a/src/main/scala/firrtl/options/StageUtils.scala
+++ b/src/main/scala/firrtl/options/StageUtils.scala
@@ -2,16 +2,16 @@
package firrtl.options
-
/** Utilities related to working with a [[Stage]] */
object StageUtils {
+
/** Print a warning message (in yellow)
* @param message error message
*/
def dramaticWarning(message: String): Unit = {
- println(Console.YELLOW + "-"*78)
+ println(Console.YELLOW + "-" * 78)
println(s"Warning: $message")
- println("-"*78 + Console.RESET)
+ println("-" * 78 + Console.RESET)
}
/** Print an error message (in red)
@@ -19,9 +19,9 @@ object StageUtils {
* @note This does not stop the Driver.
*/
def dramaticError(message: String): Unit = {
- println(Console.RED + "-"*78)
+ println(Console.RED + "-" * 78)
println(s"Error: $message")
- println("-"*78 + Console.RESET)
+ println("-" * 78 + Console.RESET)
}
/** Generate a message suggesting that the user look at the usage text.
diff --git a/src/main/scala/firrtl/options/package.scala b/src/main/scala/firrtl/options/package.scala
index 8cf2875b..f87fb8a8 100644
--- a/src/main/scala/firrtl/options/package.scala
+++ b/src/main/scala/firrtl/options/package.scala
@@ -5,17 +5,16 @@ package firrtl
package object options {
implicit object StageOptionsView extends OptionsView[StageOptions] {
- def view(options: AnnotationSeq): StageOptions = options
- .collect { case a: StageOption => a }
+ def view(options: AnnotationSeq): StageOptions = options.collect { case a: StageOption => a }
.foldLeft(new StageOptions())((c, x) =>
x match {
case TargetDirAnnotation(a) => c.copy(targetDir = a)
/* Insert input files at the head of the Seq for speed and because order shouldn't matter */
- case InputAnnotationFileAnnotation(a) => c.copy(annotationFilesIn = a +: c.annotationFilesIn)
+ case InputAnnotationFileAnnotation(a) => c.copy(annotationFilesIn = a +: c.annotationFilesIn)
case OutputAnnotationFileAnnotation(a) => c.copy(annotationFileOut = Some(a))
/* Do NOT reorder program args. The order may matter. */
case ProgramArgsAnnotation(a) => c.copy(programArgs = c.programArgs :+ a)
- case WriteDeletedAnnotation => c.copy(writeDeleted = true)
+ case WriteDeletedAnnotation => c.copy(writeDeleted = true)
}
)
}
diff --git a/src/main/scala/firrtl/options/phases/AddDefaults.scala b/src/main/scala/firrtl/options/phases/AddDefaults.scala
index ab342b1e..0ef1832a 100644
--- a/src/main/scala/firrtl/options/phases/AddDefaults.scala
+++ b/src/main/scala/firrtl/options/phases/AddDefaults.scala
@@ -19,7 +19,7 @@ class AddDefaults extends Phase {
override def invalidates(a: Phase) = false
def transform(annotations: AnnotationSeq): AnnotationSeq = {
- val td = annotations.collectFirst{ case a: TargetDirAnnotation => a}.isEmpty
+ val td = annotations.collectFirst { case a: TargetDirAnnotation => a }.isEmpty
(if (td) Seq(TargetDirAnnotation()) else Seq()) ++
annotations
diff --git a/src/main/scala/firrtl/options/phases/Checks.scala b/src/main/scala/firrtl/options/phases/Checks.scala
index 9e671aa5..024c13a9 100644
--- a/src/main/scala/firrtl/options/phases/Checks.scala
+++ b/src/main/scala/firrtl/options/phases/Checks.scala
@@ -25,24 +25,27 @@ class Checks extends Phase {
val td, outA = collection.mutable.ListBuffer[Annotation]()
annotations.foreach {
- case a: TargetDirAnnotation => td += a
+ case a: TargetDirAnnotation => td += a
case a: OutputAnnotationFileAnnotation => outA += a
case _ =>
}
if (td.size != 1) {
- val d = td.map{ case TargetDirAnnotation(x) => x }
+ val d = td.map { case TargetDirAnnotation(x) => x }
throw new OptionsException(
s"""|Exactly one target directory must be specified, but found `${d.mkString(", ")}` specified via:
| - explicit target directory: -td, --target-dir, TargetDirAnnotation
- | - fallback default value""".stripMargin )}
+ | - fallback default value""".stripMargin
+ )
+ }
if (outA.size > 1) {
- val x = outA.map{ case OutputAnnotationFileAnnotation(x) => x }
+ val x = outA.map { case OutputAnnotationFileAnnotation(x) => x }
throw new OptionsException(
s"""|At most one output annotation file can be specified, but found '${x.mkString(", ")}' specified via:
- | - an option or annotation: -foaf, --output-annotation-file, OutputAnnotationFileAnnotation"""
- .stripMargin )}
+ | - an option or annotation: -foaf, --output-annotation-file, OutputAnnotationFileAnnotation""".stripMargin
+ )
+ }
annotations
}
diff --git a/src/main/scala/firrtl/options/phases/GetIncludes.scala b/src/main/scala/firrtl/options/phases/GetIncludes.scala
index b9320585..dd08e09b 100644
--- a/src/main/scala/firrtl/options/phases/GetIncludes.scala
+++ b/src/main/scala/firrtl/options/phases/GetIncludes.scala
@@ -10,7 +10,7 @@ import firrtl.FileUtils
import java.io.File
import scala.collection.mutable
-import scala.util.{Try, Failure}
+import scala.util.{Failure, Try}
/** Recursively expand all [[InputAnnotationFileAnnotation]]s in an [[AnnotationSeq]] */
class GetIncludes extends Phase {
@@ -37,8 +37,7 @@ class GetIncludes extends Phase {
* @param annos a sequence of annotations
* @return the original annotation sequence with any discovered annotations added
*/
- private def getIncludes(includeGuard: mutable.Set[String] = mutable.Set())
- (annos: AnnotationSeq): AnnotationSeq = {
+ private def getIncludes(includeGuard: mutable.Set[String] = mutable.Set())(annos: AnnotationSeq): AnnotationSeq = {
annos.flatMap {
case a @ InputAnnotationFileAnnotation(value) =>
if (includeGuard.contains(value)) {
diff --git a/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala b/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala
index 7ee385b1..53306c8a 100644
--- a/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala
+++ b/src/main/scala/firrtl/options/phases/WriteOutputAnnotations.scala
@@ -16,9 +16,7 @@ import scala.collection.mutable
class WriteOutputAnnotations extends Phase {
override def prerequisites =
- Seq( Dependency[GetIncludes],
- Dependency[AddDefaults],
- Dependency[Checks] )
+ Seq(Dependency[GetIncludes], Dependency[AddDefaults], Dependency[Checks])
override def optionalPrerequisiteOf = Seq.empty
@@ -29,8 +27,10 @@ class WriteOutputAnnotations extends Phase {
val sopts = Viewer[StageOptions].view(annotations)
val filesWritten = mutable.HashMap.empty[String, Annotation]
val serializable: AnnotationSeq = annotations.toSeq.flatMap {
- case _: Unserializable => None
- case a: DeletedAnnotation => if (sopts.writeDeleted) { Some(a) } else { None }
+ case _: Unserializable => None
+ case a: DeletedAnnotation =>
+ if (sopts.writeDeleted) { Some(a) }
+ else { None }
case a: CustomFileEmission =>
val filename = a.filename(annotations)
val canonical = filename.getCanonicalPath()
@@ -38,7 +38,7 @@ class WriteOutputAnnotations extends Phase {
filesWritten.get(canonical) match {
case None =>
val w = new BufferedWriter(new FileWriter(filename))
- a.getBytes.foreach( w.write(_) )
+ a.getBytes.foreach(w.write(_))
w.close()
filesWritten(canonical) = a
case Some(first) =>
diff --git a/src/main/scala/firrtl/passes/CInferMDir.scala b/src/main/scala/firrtl/passes/CInferMDir.scala
index b4819751..1fe8d57c 100644
--- a/src/main/scala/firrtl/passes/CInferMDir.scala
+++ b/src/main/scala/firrtl/passes/CInferMDir.scala
@@ -18,60 +18,61 @@ object CInferMDir extends Pass {
def infer_mdir_e(mports: MPortDirMap, dir: MPortDir)(e: Expression): Expression = e match {
case e: Reference =>
- mports get e.name match {
+ mports.get(e.name) match {
case None =>
- case Some(p) => mports(e.name) = (p, dir) match {
- case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MInfer, MWrite) => MWrite
- case (MInfer, MRead) => MRead
- case (MInfer, MReadWrite) => MReadWrite
- case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MWrite, MWrite) => MWrite
- case (MWrite, MRead) => MReadWrite
- case (MWrite, MReadWrite) => MReadWrite
- case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MRead, MWrite) => MReadWrite
- case (MRead, MRead) => MRead
- case (MRead, MReadWrite) => MReadWrite
- case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
- case (MReadWrite, MWrite) => MReadWrite
- case (MReadWrite, MRead) => MReadWrite
- case (MReadWrite, MReadWrite) => MReadWrite
- }
+ case Some(p) =>
+ mports(e.name) = (p, dir) match {
+ case (MInfer, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MInfer, MWrite) => MWrite
+ case (MInfer, MRead) => MRead
+ case (MInfer, MReadWrite) => MReadWrite
+ case (MWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MWrite, MWrite) => MWrite
+ case (MWrite, MRead) => MReadWrite
+ case (MWrite, MReadWrite) => MReadWrite
+ case (MRead, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MRead, MWrite) => MReadWrite
+ case (MRead, MRead) => MRead
+ case (MRead, MReadWrite) => MReadWrite
+ case (MReadWrite, MInfer) => throwInternalError(s"infer_mdir_e: shouldn't be here - $p, $dir")
+ case (MReadWrite, MWrite) => MReadWrite
+ case (MReadWrite, MRead) => MReadWrite
+ case (MReadWrite, MReadWrite) => MReadWrite
+ }
}
e
case e: SubAccess =>
infer_mdir_e(mports, dir)(e.expr)
infer_mdir_e(mports, MRead)(e.index) // index can't be a write port
e
- case e => e map infer_mdir_e(mports, dir)
+ case e => e.map(infer_mdir_e(mports, dir))
}
def infer_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
case sx: CDefMPort =>
- mports(sx.name) = sx.direction
- sx map infer_mdir_e(mports, MRead)
+ mports(sx.name) = sx.direction
+ sx.map(infer_mdir_e(mports, MRead))
case sx: Connect =>
- infer_mdir_e(mports, MRead)(sx.expr)
- infer_mdir_e(mports, MWrite)(sx.loc)
- sx
+ infer_mdir_e(mports, MRead)(sx.expr)
+ infer_mdir_e(mports, MWrite)(sx.loc)
+ sx
case sx: PartialConnect =>
- infer_mdir_e(mports, MRead)(sx.expr)
- infer_mdir_e(mports, MWrite)(sx.loc)
- sx
- case sx => sx map infer_mdir_s(mports) map infer_mdir_e(mports, MRead)
+ infer_mdir_e(mports, MRead)(sx.expr)
+ infer_mdir_e(mports, MWrite)(sx.loc)
+ sx
+ case sx => sx.map(infer_mdir_s(mports)).map(infer_mdir_e(mports, MRead))
}
def set_mdir_s(mports: MPortDirMap)(s: Statement): Statement = s match {
- case sx: CDefMPort => sx copy (direction = mports(sx.name))
- case sx => sx map set_mdir_s(mports)
+ case sx: CDefMPort => sx.copy(direction = mports(sx.name))
+ case sx => sx.map(set_mdir_s(mports))
}
def infer_mdir(m: DefModule): DefModule = {
val mports = new MPortDirMap
- m map infer_mdir_s(mports) map set_mdir_s(mports)
+ m.map(infer_mdir_s(mports)).map(set_mdir_s(mports))
}
def run(c: Circuit): Circuit =
- c copy (modules = c.modules map infer_mdir)
+ c.copy(modules = c.modules.map(infer_mdir))
}
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala
index 9903f445..97d614c1 100644
--- a/src/main/scala/firrtl/passes/CheckChirrtl.scala
+++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala
@@ -12,9 +12,7 @@ object CheckChirrtl extends Pass with CheckHighFormLike {
override def prerequisites = Dependency[CheckScalaVersion] :: Nil
override val optionalPrerequisiteOf = firrtl.stage.Forms.ChirrtlForm ++
- Seq( Dependency(CInferTypes),
- Dependency(CInferMDir),
- Dependency(RemoveCHIRRTL) )
+ Seq(Dependency(CInferTypes), Dependency(CInferMDir), Dependency(RemoveCHIRRTL))
override def invalidates(a: Transform) = false
diff --git a/src/main/scala/firrtl/passes/CheckFlows.scala b/src/main/scala/firrtl/passes/CheckFlows.scala
index 3a9cc212..bc455a20 100644
--- a/src/main/scala/firrtl/passes/CheckFlows.scala
+++ b/src/main/scala/firrtl/passes/CheckFlows.scala
@@ -13,79 +13,87 @@ object CheckFlows extends Pass {
override def prerequisites = Dependency(passes.ResolveFlows) +: firrtl.stage.Forms.WorkingIR
override def optionalPrerequisiteOf =
- Seq( Dependency[passes.InferBinaryPoints],
- Dependency[passes.TrimIntervals],
- Dependency[passes.InferWidths],
- Dependency[transforms.InferResets] )
+ Seq(
+ Dependency[passes.InferBinaryPoints],
+ Dependency[passes.TrimIntervals],
+ Dependency[passes.InferWidths],
+ Dependency[transforms.InferResets]
+ )
override def invalidates(a: Transform) = false
type FlowMap = collection.mutable.HashMap[String, Flow]
implicit def toStr(g: Flow): String = g match {
- case SourceFlow => "source"
- case SinkFlow => "sink"
+ case SourceFlow => "source"
+ case SinkFlow => "sink"
case UnknownFlow => "unknown"
- case DuplexFlow => "duplex"
+ case DuplexFlow => "duplex"
}
- class WrongFlow(info:Info, mname: String, expr: String, wrong: Flow, right: Flow) extends PassException(
- s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right.")
+ class WrongFlow(info: Info, mname: String, expr: String, wrong: Flow, right: Flow)
+ extends PassException(
+ s"$info: [module $mname] Expression $expr is used as a $wrong but can only be used as a $right."
+ )
- def run (c:Circuit): Circuit = {
+ def run(c: Circuit): Circuit = {
val errors = new Errors()
def get_flow(e: Expression, flows: FlowMap): Flow = e match {
- case (e: WRef) => flows(e.name)
+ case (e: WRef) => flows(e.name)
case (e: WSubIndex) => get_flow(e.expr, flows)
case (e: WSubAccess) => get_flow(e.expr, flows)
- case (e: WSubField) => e.expr.tpe match {case t: BundleType =>
- val f = (t.fields find (_.name == e.name)).get
- times(get_flow(e.expr, flows), f.flip)
- }
+ case (e: WSubField) =>
+ e.expr.tpe match {
+ case t: BundleType =>
+ val f = (t.fields.find(_.name == e.name)).get
+ times(get_flow(e.expr, flows), f.flip)
+ }
case _ => SourceFlow
}
def flip_q(t: Type): Boolean = {
def flip_rec(t: Type, f: Orientation): Boolean = t match {
- case tx:BundleType => tx.fields exists (
- field => flip_rec(field.tpe, times(f, field.flip))
- )
+ case tx: BundleType => tx.fields.exists(field => flip_rec(field.tpe, times(f, field.flip)))
case tx: VectorType => flip_rec(tx.tpe, f)
case tx => f == Flip
}
flip_rec(t, Default)
}
- def check_flow(info:Info, mname: String, flows: FlowMap, desired: Flow)(e:Expression): Unit = {
- val flow = get_flow(e,flows)
+ def check_flow(info: Info, mname: String, flows: FlowMap, desired: Flow)(e: Expression): Unit = {
+ val flow = get_flow(e, flows)
(flow, desired) match {
case (SourceFlow, SinkFlow) =>
errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
- case (SinkFlow, SourceFlow) => kind(e) match {
- case PortKind | InstanceKind if !flip_q(e.tpe) => // OK!
- case _ =>
- errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
- }
+ case (SinkFlow, SourceFlow) =>
+ kind(e) match {
+ case PortKind | InstanceKind if !flip_q(e.tpe) => // OK!
+ case _ =>
+ errors.append(new WrongFlow(info, mname, e.serialize, desired, flow))
+ }
case _ =>
}
- }
+ }
- def check_flows_e (info:Info, mname: String, flows: FlowMap)(e:Expression): Unit = {
+ def check_flows_e(info: Info, mname: String, flows: FlowMap)(e: Expression): Unit = {
e match {
- case e: Mux => e foreach check_flow(info, mname, flows, SourceFlow)
- case e: DoPrim => e.args foreach check_flow(info, mname, flows, SourceFlow)
+ case e: Mux => e.foreach(check_flow(info, mname, flows, SourceFlow))
+ case e: DoPrim => e.args.foreach(check_flow(info, mname, flows, SourceFlow))
case _ =>
}
- e foreach check_flows_e(info, mname, flows)
+ e.foreach(check_flows_e(info, mname, flows))
}
def check_flows_s(minfo: Info, mname: String, flows: FlowMap)(s: Statement): Unit = {
- val info = get_info(s) match { case NoInfo => minfo case x => x }
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
s match {
- case (s: DefWire) => flows(s.name) = DuplexFlow
+ case (s: DefWire) => flows(s.name) = DuplexFlow
case (s: DefRegister) => flows(s.name) = DuplexFlow
- case (s: DefMemory) => flows(s.name) = SourceFlow
+ case (s: DefMemory) => flows(s.name) = SourceFlow
case (s: WDefInstance) => flows(s.name) = SourceFlow
case (s: DefNode) =>
check_flow(info, mname, flows, SourceFlow)(s.value)
@@ -94,7 +102,7 @@ object CheckFlows extends Pass {
check_flow(info, mname, flows, SinkFlow)(s.loc)
check_flow(info, mname, flows, SourceFlow)(s.expr)
case (s: Print) =>
- s.args foreach check_flow(info, mname, flows, SourceFlow)
+ s.args.foreach(check_flow(info, mname, flows, SourceFlow))
check_flow(info, mname, flows, SourceFlow)(s.en)
check_flow(info, mname, flows, SourceFlow)(s.clk)
case (s: PartialConnect) =>
@@ -111,14 +119,14 @@ object CheckFlows extends Pass {
check_flow(info, mname, flows, SourceFlow)(s.en)
case _ =>
}
- s foreach check_flows_e(info, mname, flows)
- s foreach check_flows_s(minfo, mname, flows)
+ s.foreach(check_flows_e(info, mname, flows))
+ s.foreach(check_flows_s(minfo, mname, flows))
}
for (m <- c.modules) {
val flows = new FlowMap
- flows ++= (m.ports map (p => p.name -> to_flow(p.direction)))
- m foreach check_flows_s(m.info, m.name, flows)
+ flows ++= (m.ports.map(p => p.name -> to_flow(p.direction)))
+ m.foreach(check_flows_s(m.info, m.name, flows))
}
errors.trigger()
c
diff --git a/src/main/scala/firrtl/passes/CheckHighForm.scala b/src/main/scala/firrtl/passes/CheckHighForm.scala
index 2f706d35..559c9060 100644
--- a/src/main/scala/firrtl/passes/CheckHighForm.scala
+++ b/src/main/scala/firrtl/passes/CheckHighForm.scala
@@ -27,66 +27,71 @@ trait CheckHighFormLike { this: Pass =>
scopes.find(_.contains(port.mem)).getOrElse(scopes.head) += port.name
}
def legalDecl(name: String): Boolean = !moduleNS.contains(name)
- def legalRef(name: String): Boolean = scopes.exists(_.contains(name))
+ def legalRef(name: String): Boolean = scopes.exists(_.contains(name))
def childScope(): ScopeView = new ScopeView(moduleNS, new NameSet +: scopes)
}
// Custom Exceptions
- class NotUniqueException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Reference $name does not have a unique name.")
- class InvalidLOCException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort.")
- class NegUIntException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] UIntLiteral cannot be negative.")
- class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Reference $name is not declared.")
- class PoisonWithFlipException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Poison $name cannot be a bundle type with flips.")
- class MemWithFlipException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Memory $name cannot be a bundle type with flips.")
- class IllegalMemLatencyException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency.")
- class RegWithFlipException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Register $name cannot be a bundle type with flips.")
- class InvalidAccessException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Invalid access to non-reference.")
- class ModuleNameNotUniqueException(info: Info, mname: String) extends PassException(
- s"$info: Repeat definition of module $mname")
- class DefnameConflictException(info: Info, mname: String, defname: String) extends PassException(
- s"$info: defname $defname of extmodule $mname conflicts with an existing module")
- class DefnameDifferentPortsException(info: Info, mname: String, defname: String) extends PassException(
- s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname""")
- class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException(
- s"$info: Module $name is not defined.")
- class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int) extends PassException(
- s"$info: [module $mname] Primop $op requires $n expression arguments.")
- class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int) extends PassException(
- s"$info: [module $mname] Primop $op requires $n integer arguments.")
- class NegWidthException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Width cannot be negative.")
- class NegVecSizeException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Vector type size cannot be negative.")
- class NegMemSizeException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Memory size cannot be negative or zero.")
- class BadPrintfException(info: Info, mname: String, x: Char) extends PassException(
- s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"")
- class BadPrintfTrailingException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Bad printf format: trailing " + "\"%\"")
- class BadPrintfIncorrectNumException(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Bad printf format: incorrect number of arguments")
- class InstanceLoop(info: Info, mname: String, loop: String) extends PassException(
- s"$info: [module $mname] Has instance loop $loop")
- class NoTopModuleException(info: Info, name: String) extends PassException(
- s"$info: A single module must be named $name.")
- class NegArgException(info: Info, mname: String, op: String, value: BigInt) extends PassException(
- s"$info: [module $mname] Primop $op argument $value < 0.")
- class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt) extends PassException(
- s"$info: [module $mname] Primop $op lsb $lsb > $msb.")
- class ResetInputException(info: Info, mname: String, expr: Expression) extends PassException(
- s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}")
- class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression) extends PassException(
- s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}")
-
+ class NotUniqueException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Reference $name does not have a unique name.")
+ class InvalidLOCException(info: Info, mname: String)
+ extends PassException(
+ s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort."
+ )
+ class NegUIntException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] UIntLiteral cannot be negative.")
+ class UndeclaredReferenceException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Reference $name is not declared.")
+ class PoisonWithFlipException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Poison $name cannot be a bundle type with flips.")
+ class MemWithFlipException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Memory $name cannot be a bundle type with flips.")
+ class IllegalMemLatencyException(info: Info, mname: String, name: String)
+ extends PassException(
+ s"$info: [module $mname] Memory $name must have non-negative read latency and positive write latency."
+ )
+ class RegWithFlipException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Register $name cannot be a bundle type with flips.")
+ class InvalidAccessException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Invalid access to non-reference.")
+ class ModuleNameNotUniqueException(info: Info, mname: String)
+ extends PassException(s"$info: Repeat definition of module $mname")
+ class DefnameConflictException(info: Info, mname: String, defname: String)
+ extends PassException(s"$info: defname $defname of extmodule $mname conflicts with an existing module")
+ class DefnameDifferentPortsException(info: Info, mname: String, defname: String)
+ extends PassException(
+ s"""$info: ports of extmodule $mname with defname $defname are different for an extmodule with the same defname"""
+ )
+ class ModuleNotDefinedException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: Module $name is not defined.")
+ class IncorrectNumArgsException(info: Info, mname: String, op: String, n: Int)
+ extends PassException(s"$info: [module $mname] Primop $op requires $n expression arguments.")
+ class IncorrectNumConstsException(info: Info, mname: String, op: String, n: Int)
+ extends PassException(s"$info: [module $mname] Primop $op requires $n integer arguments.")
+ class NegWidthException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Width cannot be negative.")
+ class NegVecSizeException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Vector type size cannot be negative.")
+ class NegMemSizeException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Memory size cannot be negative or zero.")
+ class BadPrintfException(info: Info, mname: String, x: Char)
+ extends PassException(s"$info: [module $mname] Bad printf format: " + "\"%" + x + "\"")
+ class BadPrintfTrailingException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Bad printf format: trailing " + "\"%\"")
+ class BadPrintfIncorrectNumException(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Bad printf format: incorrect number of arguments")
+ class InstanceLoop(info: Info, mname: String, loop: String)
+ extends PassException(s"$info: [module $mname] Has instance loop $loop")
+ class NoTopModuleException(info: Info, name: String)
+ extends PassException(s"$info: A single module must be named $name.")
+ class NegArgException(info: Info, mname: String, op: String, value: BigInt)
+ extends PassException(s"$info: [module $mname] Primop $op argument $value < 0.")
+ class LsbLargerThanMsbException(info: Info, mname: String, op: String, lsb: BigInt, msb: BigInt)
+ extends PassException(s"$info: [module $mname] Primop $op lsb $lsb > $msb.")
+ class ResetInputException(info: Info, mname: String, expr: Expression)
+ extends PassException(s"$info: [module $mname] Abstract Reset not allowed as top-level input: ${expr.serialize}")
+ class ResetExtModuleOutputException(info: Info, mname: String, expr: Expression)
+ extends PassException(s"$info: [module $mname] Abstract Reset not allowed as ExtModule output: ${expr.serialize}")
// Is Chirrtl allowed for this check? If not, return an error
def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException]
@@ -94,12 +99,12 @@ trait CheckHighFormLike { this: Pass =>
def run(c: Circuit): Circuit = {
val errors = new Errors()
val moduleGraph = new ModuleGraph
- val moduleNames = (c.modules map (_.name)).toSet
+ val moduleNames = (c.modules.map(_.name)).toSet
val intModuleNames = c.modules.view.collect({ case m: Module => m.name }).toSet
- c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach {
- m => errors.append(new ModuleNameNotUniqueException(m.info, m.name))
+ c.modules.groupBy(_.name).filter(_._2.length > 1).flatMap(_._2).foreach { m =>
+ errors.append(new ModuleNameNotUniqueException(m.info, m.name))
}
/** Strip all widths from types */
@@ -110,16 +115,18 @@ trait CheckHighFormLike { this: Pass =>
val extmoduleCollidingPorts = c.modules.collect {
case a: ExtModule => a
- }.groupBy(a => (a.defname, a.params.nonEmpty)).map {
- /* There are no parameters, so all ports must match exactly. */
- case (k@ (_, false), a) =>
- k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo))).toSet
- /* If there are parameters, then only port names must match because parameters could parameterize widths.
- * This means that this check cannot produce false positives, but can have false negatives.
- */
- case (k@ (_, true), a) =>
- k -> a.map(_.copy(info=NoInfo)).map(_.ports.map(_.copy(info=NoInfo).mapType(stripWidth))).toSet
- }.filter(_._2.size > 1)
+ }.groupBy(a => (a.defname, a.params.nonEmpty))
+ .map {
+ /* There are no parameters, so all ports must match exactly. */
+ case (k @ (_, false), a) =>
+ k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo))).toSet
+ /* If there are parameters, then only port names must match because parameters could parameterize widths.
+ * This means that this check cannot produce false positives, but can have false negatives.
+ */
+ case (k @ (_, true), a) =>
+ k -> a.map(_.copy(info = NoInfo)).map(_.ports.map(_.copy(info = NoInfo).mapType(stripWidth))).toSet
+ }
+ .filter(_._2.size > 1)
c.modules.collect {
case a: ExtModule =>
@@ -129,7 +136,8 @@ trait CheckHighFormLike { this: Pass =>
case _ =>
}
a match {
- case ExtModule(info, name, _, defname, params) if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) =>
+ case ExtModule(info, name, _, defname, params)
+ if extmoduleCollidingPorts.contains((defname, params.nonEmpty)) =>
errors.append(new DefnameDifferentPortsException(info, name, defname))
case _ =>
}
@@ -147,14 +155,14 @@ trait CheckHighFormLike { this: Pass =>
}
def nonNegativeConsts(): Unit = {
- e.consts.filter(_ < 0).foreach {
- negC => errors.append(new NegArgException(info, mname, e.op.toString, negC))
+ e.consts.filter(_ < 0).foreach { negC =>
+ errors.append(new NegArgException(info, mname, e.op.toString, negC))
}
}
e.op match {
- case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq |
- Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw | Clip | Wrap | Squeeze =>
+ case Add | Sub | Mul | Div | Rem | Lt | Leq | Gt | Geq | Eq | Neq | Dshl | Dshr | And | Or | Xor | Cat | Dshlw |
+ Clip | Wrap | Squeeze =>
correctNum(Option(2), 0)
case AsUInt | AsSInt | AsClock | AsAsyncReset | Cvt | Neq | Not =>
correctNum(Option(1), 0)
@@ -175,7 +183,7 @@ trait CheckHighFormLike { this: Pass =>
case AsInterval =>
correctNum(Option(1), 3)
case Andr | Orr | Xorr | Neg =>
- correctNum(None,0)
+ correctNum(None, 0)
}
}
@@ -208,12 +216,12 @@ trait CheckHighFormLike { this: Pass =>
}
def checkHighFormT(info: Info, mname: => String)(t: Type): Unit = {
- t foreach checkHighFormT(info, mname)
+ t.foreach(checkHighFormT(info, mname))
t match {
case tx: VectorType if tx.size < 0 =>
errors.append(new NegVecSizeException(info, mname))
case _: IntervalType =>
- case _ => t foreach checkHighFormW(info, mname)
+ case _ => t.foreach(checkHighFormW(info, mname))
}
}
@@ -235,12 +243,12 @@ trait CheckHighFormLike { this: Pass =>
errors.append(new NegUIntException(info, mname))
case ex: DoPrim => checkHighFormPrimop(info, mname, ex)
case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf =>
- case ex: SubAccess => validSubexp(info, mname)(ex.expr)
+ case ex: SubAccess => validSubexp(info, mname)(ex.expr)
case ex: WSubAccess => validSubexp(info, mname)(ex.expr)
- case ex => ex foreach validSubexp(info, mname)
+ case ex => ex.foreach(validSubexp(info, mname))
}
- e foreach checkHighFormW(info, mname + "/" + e.serialize)
- e foreach checkHighFormE(info, mname, names)
+ e.foreach(checkHighFormW(info, mname + "/" + e.serialize))
+ e.foreach(checkHighFormE(info, mname, names))
}
def checkName(info: Info, mname: String, names: ScopeView)(name: String): Unit = {
@@ -253,14 +261,17 @@ trait CheckHighFormLike { this: Pass =>
if (!moduleNames(child))
errors.append(new ModuleNotDefinedException(info, parent, child))
// Check to see if a recursive module instantiation has occured
- val childToParent = moduleGraph add (parent, child)
+ val childToParent = moduleGraph.add(parent, child)
if (childToParent.nonEmpty)
- errors.append(new InstanceLoop(info, parent, childToParent mkString "->"))
+ errors.append(new InstanceLoop(info, parent, childToParent.mkString("->")))
}
def checkHighFormS(minfo: Info, mname: String, names: ScopeView)(s: Statement): Unit = {
- val info = get_info(s) match {case NoInfo => minfo case x => x}
- s foreach checkName(info, mname, names)
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
+ s.foreach(checkName(info, mname, names))
s match {
case DefRegister(info, name, tpe, _, reset, init) =>
if (hasFlip(tpe))
@@ -272,24 +283,24 @@ trait CheckHighFormLike { this: Pass =>
errors.append(new MemWithFlipException(info, mname, sx.name))
if (sx.depth <= 0)
errors.append(new NegMemSizeException(info, mname))
- case sx: DefInstance => checkInstance(info, mname, sx.module)
- case sx: WDefInstance => checkInstance(info, mname, sx.module)
- case sx: Connect => checkValidLoc(info, mname, sx.loc)
- case sx: PartialConnect => checkValidLoc(info, mname, sx.loc)
- case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
- case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
+ case sx: DefInstance => checkInstance(info, mname, sx.module)
+ case sx: WDefInstance => checkInstance(info, mname, sx.module)
+ case sx: Connect => checkValidLoc(info, mname, sx.loc)
+ case sx: PartialConnect => checkValidLoc(info, mname, sx.loc)
+ case sx: Print => checkFstring(info, mname, sx.string, sx.args.length)
+ case _: CDefMemory => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
case mport: CDefMPort =>
errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) }
names.expandMPortVisibility(mport)
case sx => // Do Nothing
}
- s foreach checkHighFormT(info, mname)
- s foreach checkHighFormE(info, mname, names)
+ s.foreach(checkHighFormT(info, mname))
+ s.foreach(checkHighFormE(info, mname, names))
s match {
- case Conditionally(_,_, conseq, alt) =>
+ case Conditionally(_, _, conseq, alt) =>
checkHighFormS(minfo, mname, names.childScope())(conseq)
checkHighFormS(minfo, mname, names.childScope())(alt)
- case _ => s foreach checkHighFormS(minfo, mname, names)
+ case _ => s.foreach(checkHighFormS(minfo, mname, names))
}
}
@@ -313,10 +324,10 @@ trait CheckHighFormLike { this: Pass =>
def checkHighFormM(m: DefModule): Unit = {
val names = ScopeView()
- m foreach checkHighFormP(m.name, names)
- m foreach checkHighFormS(m.info, m.name, names)
+ m.foreach(checkHighFormP(m.name, names))
+ m.foreach(checkHighFormS(m.info, m.name, names))
m match {
- case _: Module =>
+ case _: Module =>
case ext: ExtModule =>
for ((port, expr) <- findBadResetTypePorts(ext, Output)) {
errors.append(new ResetExtModuleOutputException(port.info, ext.name, expr))
@@ -324,7 +335,7 @@ trait CheckHighFormLike { this: Pass =>
}
}
- c.modules foreach checkHighFormM
+ c.modules.foreach(checkHighFormM)
c.modules.filter(_.name == c.main) match {
case Seq(topMod) =>
for ((port, expr) <- findBadResetTypePorts(topMod, Input)) {
@@ -342,21 +353,23 @@ object CheckHighForm extends Pass with CheckHighFormLike {
override def prerequisites = firrtl.stage.Forms.WorkingIR
override def optionalPrerequisiteOf =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferWidths],
- Dependency[transforms.InferResets] )
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferWidths],
+ Dependency[transforms.InferResets]
+ )
override def invalidates(a: Transform) = false
- class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.")
+ class IllegalChirrtlMemException(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.")
def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = {
val memName = s match {
case cm: CDefMemory => cm.name
- case cp: CDefMPort => cp.mem
+ case cp: CDefMPort => cp.mem
}
Some(new IllegalChirrtlMemException(info, mname, memName))
}
diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala
index 4a5577f9..96057831 100644
--- a/src/main/scala/firrtl/passes/CheckInitialization.scala
+++ b/src/main/scala/firrtl/passes/CheckInitialization.scala
@@ -22,10 +22,11 @@ object CheckInitialization extends Pass {
private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression])
- class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException(
- s"$info : [module $mname] Reference $name is not fully initialized.\n" +
- trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n")
- )
+ class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement])
+ extends PassException(
+ s"$info : [module $mname] Reference $name is not fully initialized.\n" +
+ trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n")
+ )
private def getTrace(expr: WrappedExpression, voidExprs: Map[WrappedExpression, VoidExpr]): Seq[Statement] = {
@tailrec
@@ -81,7 +82,7 @@ object CheckInitialization extends Pass {
case node: DefNode => // Ignore nodes
case decl: IsDeclaration =>
val trace = getTrace(expr, voidExprs.toMap)
- errors append new RefNotInitializedException(decl.info, m.name, decl.name, trace)
+ errors.append(new RefNotInitializedException(decl.info, m.name, decl.name, trace))
}
}
}
diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala
index c94928a1..956c1134 100644
--- a/src/main/scala/firrtl/passes/CheckTypes.scala
+++ b/src/main/scala/firrtl/passes/CheckTypes.scala
@@ -16,92 +16,105 @@ object CheckTypes extends Pass {
override def prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR
override def optionalPrerequisiteOf =
- Seq( Dependency(passes.ResolveFlows),
- Dependency(passes.CheckFlows),
- Dependency[passes.InferWidths],
- Dependency(passes.CheckWidths) )
+ Seq(
+ Dependency(passes.ResolveFlows),
+ Dependency(passes.CheckFlows),
+ Dependency[passes.InferWidths],
+ Dependency(passes.CheckWidths)
+ )
override def invalidates(a: Transform) = false
// Custom Exceptions
- class SubfieldNotInBundle(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname ] Subfield $name is not in bundle.")
- class SubfieldOnNonBundle(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Subfield $name is accessed on a non-bundle.")
- class IndexTooLarge(info: Info, mname: String, value: Int) extends PassException(
- s"$info: [module $mname] Index with value $value is too large.")
- class IndexOnNonVector(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Index illegal on non-vector type.")
- class AccessIndexNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Access index must be a UInt type.")
- class IndexNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Index is not of UIntType.")
- class EnableNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Enable is not of UIntType.")
+ class SubfieldNotInBundle(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname ] Subfield $name is not in bundle.")
+ class SubfieldOnNonBundle(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Subfield $name is accessed on a non-bundle.")
+ class IndexTooLarge(info: Info, mname: String, value: Int)
+ extends PassException(s"$info: [module $mname] Index with value $value is too large.")
+ class IndexOnNonVector(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Index illegal on non-vector type.")
+ class AccessIndexNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Access index must be a UInt type.")
+ class IndexNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Index is not of UIntType.")
+ class EnableNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Enable is not of UIntType.")
class InvalidConnect(info: Info, mname: String, con: String, lhs: Expression, rhs: Expression)
extends PassException({
- val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}"
- val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}"
- s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe"
- })
- class InvalidRegInit(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Type of init must match type of DefRegister.")
- class PrintfArgNotGround(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.")
- class ReqClk(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Requires a clock typed signal.")
- class RegReqClk(info: Info, mname: String, name: String) extends PassException(
- s"$info: [module $mname] Register $name requires a clock typed signal.")
- class EnNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Enable must be a UIntType typed signal.")
- class PredNotUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Predicate not a UIntType.")
- class OpNotGround(info: Info, mname: String, op: String) extends PassException(
- s"$info: [module $mname] Primop $op cannot operate on non-ground types.")
- class OpNotUInt(info: Info, mname: String, op: String, e: String) extends PassException(
- s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.")
- class OpNotAllUInt(info: Info, mname: String, op: String) extends PassException(
- s"$info: [module $mname] Primop $op requires all arguments to be UInt type.")
- class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException(
- s"$info: [module $mname] Primop $op requires all operands to have the same type.")
- class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.")
- class OpNotCorrectType(info:Info, mname: String, op: String, tpes: Seq[String]) extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.")
- class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException(
- s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.")
- class NodePassiveType(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Node must be a passive type.")
- class MuxSameType(info: Info, mname: String, t1: String, t2: String) extends PassException(
- s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.")
- class MuxPassiveTypes(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Must mux between passive types.")
- class MuxCondUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] A mux condition must be of type UInt.")
- class MuxClock(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Firrtl does not support muxing clocks.")
- class ValidIfPassiveTypes(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] Must validif a passive type.")
- class ValidIfCondUInt(info: Info, mname: String) extends PassException(
- s"$info: [module $mname] A validif condition must be of type UInt.")
- class IllegalAnalogDeclaration(info: Info, mname: String, decName: String) extends PassException(
- s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName.")
- class IllegalAttachExp(info: Info, mname: String, expName: String) extends PassException(
- s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName.")
- class IllegalResetType(info: Info, mname: String, exp: String) extends PassException(
- s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp.")
- class IllegalUnknownType(info: Info, mname: String, exp: String) extends PassException(
- s"$info: [module $mname] Uninferred type: $exp."
- )
+ val ltpe = s" ${lhs.serialize}: ${lhs.tpe.serialize}"
+ val rtpe = s" ${rhs.serialize}: ${rhs.tpe.serialize}"
+ s"$info: [module $mname] Type mismatch in '$con'.\n$ltpe\n$rtpe"
+ })
+ class InvalidRegInit(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Type of init must match type of DefRegister.")
+ class PrintfArgNotGround(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Printf arguments must be either UIntType or SIntType.")
+ class ReqClk(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Requires a clock typed signal.")
+ class RegReqClk(info: Info, mname: String, name: String)
+ extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.")
+ class EnNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.")
+ class PredNotUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Predicate not a UIntType.")
+ class OpNotGround(info: Info, mname: String, op: String)
+ extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.")
+ class OpNotUInt(info: Info, mname: String, op: String, e: String)
+ extends PassException(s"$info: [module $mname] Primop $op requires argument $e to be a UInt type.")
+ class OpNotAllUInt(info: Info, mname: String, op: String)
+ extends PassException(s"$info: [module $mname] Primop $op requires all arguments to be UInt type.")
+ class OpNotAllSameType(info: Info, mname: String, op: String)
+ extends PassException(s"$info: [module $mname] Primop $op requires all operands to have the same type.")
+ class OpNoMixFix(info: Info, mname: String, op: String)
+ extends PassException(
+ s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type."
+ )
+ class OpNotCorrectType(info: Info, mname: String, op: String, tpes: Seq[String])
+ extends PassException(s"${info}: [module ${mname}] Primop ${op} does not have correct arg types: $tpes.")
+ class OpNotAnalog(info: Info, mname: String, exp: String)
+ extends PassException(s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.")
+ class NodePassiveType(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Node must be a passive type.")
+ class MuxSameType(info: Info, mname: String, t1: String, t2: String)
+ extends PassException(s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.")
+ class MuxPassiveTypes(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Must mux between passive types.")
+ class MuxCondUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.")
+ class MuxClock(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.")
+ class ValidIfPassiveTypes(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] Must validif a passive type.")
+ class ValidIfCondUInt(info: Info, mname: String)
+ extends PassException(s"$info: [module $mname] A validif condition must be of type UInt.")
+ class IllegalAnalogDeclaration(info: Info, mname: String, decName: String)
+ extends PassException(
+ s"$info: [module $mname] Cannot declare a reg, node, or memory with an Analog type: $decName."
+ )
+ class IllegalAttachExp(info: Info, mname: String, expName: String)
+ extends PassException(
+ s"$info: [module $mname] Attach expression must be an port, wire, or port of instance: $expName."
+ )
+ class IllegalResetType(info: Info, mname: String, exp: String)
+ extends PassException(
+ s"$info: [module $mname] Register resets must have type Reset, AsyncReset, or UInt<1>: $exp."
+ )
+ class IllegalUnknownType(info: Info, mname: String, exp: String)
+ extends PassException(
+ s"$info: [module $mname] Uninferred type: $exp."
+ )
def fits(bigger: Constraint, smaller: Constraint): Boolean = (bigger, smaller) match {
case (IsKnown(v1), IsKnown(v2)) if v1 < v2 => false
- case _ => true
+ case _ => true
}
def legalResetType(tpe: Type): Boolean = tpe match {
case UIntType(IntWidth(w)) if w == 1 => true
- case AsyncResetType => true
- case ResetType => true
- case UIntType(UnknownWidth) =>
+ case AsyncResetType => true
+ case ResetType => true
+ case UIntType(UnknownWidth) =>
// cannot catch here, though width may ultimately be wrong
true
case _ => false
@@ -118,13 +131,13 @@ object CheckTypes extends Pass {
fits(i2.lower, i1.lower) && fits(i1.upper, i2.upper) && fits(i1.point, i2.point)
case (_: AnalogType, _: AnalogType) => true
case (AsyncResetType, AsyncResetType) => flip1 == flip2
- case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2
- case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2
+ case (ResetType, tpe) => legalResetType(tpe) && flip1 == flip2
+ case (tpe, ResetType) => legalResetType(tpe) && flip1 == flip2
case (t1: BundleType, t2: BundleType) =>
- val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())(
- (map, f1) => map + (f1.name ->( (f1.tpe, f1.flip) )))
- t2.fields forall (f2 =>
- t1_fields get f2.name match {
+ val t1_fields =
+ (t1.fields.foldLeft(Map[String, (Type, Orientation)]()))((map, f1) => map + (f1.name -> ((f1.tpe, f1.flip))))
+ t2.fields.forall(f2 =>
+ t1_fields.get(f2.name) match {
case None => true
case Some((f1_tpe, f1_flip)) =>
bulk_equals(f1_tpe, f2.tpe, times(flip1, f1_flip), times(flip2, f2.flip))
@@ -155,79 +168,155 @@ object CheckTypes extends Pass {
def ut: UIntType = UIntType(UnknownWidth)
def st: SIntType = SIntType(UnknownWidth)
- def run (c:Circuit) : Circuit = {
+ def run(c: Circuit): Circuit = {
val errors = new Errors()
def passive(t: Type): Boolean = t match {
- case _: UIntType |_: SIntType => true
+ case _: UIntType | _: SIntType => true
case tx: VectorType => passive(tx.tpe)
- case tx: BundleType => tx.fields forall (x => x.flip == Default && passive(x.tpe))
+ case tx: BundleType => tx.fields.forall(x => x.flip == Default && passive(x.tpe))
case tx => true
}
def check_types_primop(info: Info, mname: String, e: DoPrim): Unit = {
- def checkAllTypes(exprs: Seq[Expression], okUInt: Boolean, okSInt: Boolean, okClock: Boolean, okFix: Boolean, okAsync: Boolean, okInterval: Boolean): Unit = {
+ def checkAllTypes(
+ exprs: Seq[Expression],
+ okUInt: Boolean,
+ okSInt: Boolean,
+ okClock: Boolean,
+ okFix: Boolean,
+ okAsync: Boolean,
+ okInterval: Boolean
+ ): Unit = {
exprs.foldLeft((false, false, false, false, false, false)) {
- case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) => expr.tpe match {
- case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval)
- case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval)
- case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval)
- case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval)
- case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval)
- case i:IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true)
- case UnknownType =>
- errors.append(new IllegalUnknownType(info, mname, e.serialize))
- (isUInt, isSInt, isClock, isFix, isAsync, isInterval)
- case other => throwInternalError(s"Illegal Type: ${other.serialize}")
- }
+ case ((isUInt, isSInt, isClock, isFix, isAsync, isInterval), expr) =>
+ expr.tpe match {
+ case u: UIntType => (true, isSInt, isClock, isFix, isAsync, isInterval)
+ case s: SIntType => (isUInt, true, isClock, isFix, isAsync, isInterval)
+ case ClockType => (isUInt, isSInt, true, isFix, isAsync, isInterval)
+ case f: FixedType => (isUInt, isSInt, isClock, true, isAsync, isInterval)
+ case AsyncResetType => (isUInt, isSInt, isClock, isFix, true, isInterval)
+ case i: IntervalType => (isUInt, isSInt, isClock, isFix, isAsync, true)
+ case UnknownType =>
+ errors.append(new IllegalUnknownType(info, mname, e.serialize))
+ (isUInt, isSInt, isClock, isFix, isAsync, isInterval)
+ case other => throwInternalError(s"Illegal Type: ${other.serialize}")
+ }
} match {
// (UInt, SInt, Clock, Fixed, Async, Interval)
- case (isAll, false, false, false, false, false) if isAll == okUInt =>
- case (false, isAll, false, false, false, false) if isAll == okSInt =>
- case (false, false, isAll, false, false, false) if isAll == okClock =>
- case (false, false, false, isAll, false, false) if isAll == okFix =>
- case (false, false, false, false, isAll, false) if isAll == okAsync =>
+ case (isAll, false, false, false, false, false) if isAll == okUInt =>
+ case (false, isAll, false, false, false, false) if isAll == okSInt =>
+ case (false, false, isAll, false, false, false) if isAll == okClock =>
+ case (false, false, false, isAll, false, false) if isAll == okFix =>
+ case (false, false, false, false, isAll, false) if isAll == okAsync =>
case (false, false, false, false, false, isAll) if isAll == okInterval =>
- case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize)))
+ case x => errors.append(new OpNotCorrectType(info, mname, e.op.serialize, exprs.map(_.tpe.serialize)))
}
}
e.op match {
case AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset | AsInterval =>
- // All types are ok
+ // All types are ok
case Dshl | Dshr =>
- checkAllTypes(Seq(e.args.head), okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
- checkAllTypes(Seq(e.args(1)), okUInt=true, okSInt=false, okClock=false, okFix=false, okAsync=false, okInterval=false)
+ checkAllTypes(
+ Seq(e.args.head),
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
+ checkAllTypes(
+ Seq(e.args(1)),
+ okUInt = true,
+ okSInt = false,
+ okClock = false,
+ okFix = false,
+ okAsync = false,
+ okInterval = false
+ )
case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
case Pad | Bits | Head | Tail =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=false)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = false
+ )
case Shl | Shr | Cat =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
case IncP | DecP | SetP =>
- checkAllTypes(e.args, okUInt=false, okSInt=false, okClock=false, okFix=true, okAsync=false, okInterval=true)
+ checkAllTypes(
+ e.args,
+ okUInt = false,
+ okSInt = false,
+ okClock = false,
+ okFix = true,
+ okAsync = false,
+ okInterval = true
+ )
case Wrap | Clip | Squeeze =>
- checkAllTypes(e.args, okUInt = false, okSInt = false, okClock = false, okFix = false, okAsync=false, okInterval = true)
+ checkAllTypes(
+ e.args,
+ okUInt = false,
+ okSInt = false,
+ okClock = false,
+ okFix = false,
+ okAsync = false,
+ okInterval = true
+ )
case _ =>
- checkAllTypes(e.args, okUInt=true, okSInt=true, okClock=false, okFix=false, okAsync=false, okInterval=false)
+ checkAllTypes(
+ e.args,
+ okUInt = true,
+ okSInt = true,
+ okClock = false,
+ okFix = false,
+ okAsync = false,
+ okInterval = false
+ )
}
}
- def check_types_e(info:Info, mname: String)(e: Expression): Unit = {
+ def check_types_e(info: Info, mname: String)(e: Expression): Unit = {
e match {
- case (e: WSubField) => e.expr.tpe match {
- case (t: BundleType) => t.fields find (_.name == e.name) match {
- case Some(_) =>
- case None => errors.append(new SubfieldNotInBundle(info, mname, e.name))
+ case (e: WSubField) =>
+ e.expr.tpe match {
+ case (t: BundleType) =>
+ t.fields.find(_.name == e.name) match {
+ case Some(_) =>
+ case None => errors.append(new SubfieldNotInBundle(info, mname, e.name))
+ }
+ case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name))
+ }
+ case (e: WSubIndex) =>
+ e.expr.tpe match {
+ case (t: VectorType) if e.value < t.size =>
+ case (t: VectorType) =>
+ errors.append(new IndexTooLarge(info, mname, e.value))
+ case _ =>
+ errors.append(new IndexOnNonVector(info, mname))
}
- case _ => errors.append(new SubfieldOnNonBundle(info, mname, e.name))
- }
- case (e: WSubIndex) => e.expr.tpe match {
- case (t: VectorType) if e.value < t.size =>
- case (t: VectorType) =>
- errors.append(new IndexTooLarge(info, mname, e.value))
- case _ =>
- errors.append(new IndexOnNonVector(info, mname))
- }
case (e: WSubAccess) =>
e.expr.tpe match {
case _: VectorType =>
@@ -256,11 +345,14 @@ object CheckTypes extends Pass {
}
case _ =>
}
- e foreach check_types_e(info, mname)
+ e.foreach(check_types_e(info, mname))
}
def check_types_s(minfo: Info, mname: String)(s: Statement): Unit = {
- val info = get_info(s) match { case NoInfo => minfo case x => x }
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
s match {
case sx: Connect if !validConnect(sx) =>
val conMsg = sx.copy(info = NoInfo).serialize
@@ -270,7 +362,7 @@ object CheckTypes extends Pass {
errors.append(new InvalidConnect(info, mname, conMsg, sx.loc, sx.expr))
case sx: DefRegister =>
sx.tpe match {
- case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
+ case AnalogType(_) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
case t if wt(sx.tpe) != wt(sx.init.tpe) => errors.append(new InvalidRegInit(info, mname))
case t if !validConnect(sx.tpe, sx.init.tpe) =>
val conMsg = sx.copy(info = NoInfo).serialize
@@ -285,11 +377,12 @@ object CheckTypes extends Pass {
}
case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) =>
errors.append(new PredNotUInt(info, mname))
- case sx: DefNode => sx.value.tpe match {
- case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
- case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname))
- case t =>
- }
+ case sx: DefNode =>
+ sx.value.tpe match {
+ case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
+ case t if !passive(sx.value.tpe) => errors.append(new NodePassiveType(info, mname))
+ case t =>
+ }
case sx: Attach =>
for (e <- sx.exprs) {
e.tpe match {
@@ -298,14 +391,14 @@ object CheckTypes extends Pass {
}
kind(e) match {
case (InstanceKind | PortKind | WireKind) =>
- case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize))
+ case _ => errors.append(new IllegalAttachExp(info, mname, e.serialize))
}
}
case sx: Stop =>
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
case sx: Print =>
- if (sx.args exists (x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st)))
+ if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st)))
errors.append(new PrintfArgNotGround(info, mname))
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
@@ -313,17 +406,18 @@ object CheckTypes extends Pass {
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname))
if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
- case sx: DefMemory => sx.dataType match {
- case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
- case t =>
- }
+ case sx: DefMemory =>
+ sx.dataType match {
+ case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
+ case t =>
+ }
case _ =>
}
- s foreach check_types_e(info, mname)
- s foreach check_types_s(info, mname)
+ s.foreach(check_types_e(info, mname))
+ s.foreach(check_types_s(info, mname))
}
- c.modules foreach (m => m foreach check_types_s(m.info, m.name))
+ c.modules.foreach(m => m.foreach(check_types_s(m.info, m.name)))
errors.trigger()
c
}
diff --git a/src/main/scala/firrtl/passes/CheckWidths.scala b/src/main/scala/firrtl/passes/CheckWidths.scala
index a7729ef8..f7fefa87 100644
--- a/src/main/scala/firrtl/passes/CheckWidths.scala
+++ b/src/main/scala/firrtl/passes/CheckWidths.scala
@@ -22,43 +22,49 @@ object CheckWidths extends Pass {
/** The maximum allowed width for any circuit element */
val MaxWidth = 1000000
val DshlMaxWidth = getUIntWidth(MaxWidth)
- class UninferredWidth (info: Info, target: String) extends PassException(
- s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?)
- |$target""".stripMargin)
- class UninferredBound (info: Info, target: String, bound: String) extends PassException(
- s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?)
- |$target""".stripMargin)
- class InvalidRange (info: Info, target: String, i: IntervalType) extends PassException(
- s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?)
- |$target""".stripMargin)
- class WidthTooSmall(info: Info, mname: String, b: BigInt) extends PassException(
- s"$info : [target $mname] Width too small for constant $b.")
- class WidthTooBig(info: Info, mname: String, b: BigInt) extends PassException(
- s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits")
- class DshlTooBig(info: Info, mname: String) extends PassException(
- s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits.")
- class MultiBitAsClock(info: Info, mname: String) extends PassException(
- s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.")
- class MultiBitAsAsyncReset(info: Info, mname: String) extends PassException(
- s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.")
- class NegWidthException(info:Info, mname: String) extends PassException(
- s"$info: [target $mname] Width cannot be negative or zero.")
- class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String) extends PassException(
- s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp.")
- class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
- s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.")
- class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt) extends PassException(
- s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.")
- class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String) extends PassException(
- s"$info: [target $mname] Attach source $source and expression $eName must have identical widths.")
+ class UninferredWidth(info: Info, target: String)
+ extends PassException(s"""|$info : Uninferred width for target below.serialize}. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class UninferredBound(info: Info, target: String, bound: String)
+ extends PassException(s"""|$info : Uninferred $bound bound for target. (Did you forget to assign to it?)
+ |$target""".stripMargin)
+ class InvalidRange(info: Info, target: String, i: IntervalType)
+ extends PassException(s"""|$info : Invalid range ${i.serialize} for target below. (Are the bounds valid?)
+ |$target""".stripMargin)
+ class WidthTooSmall(info: Info, mname: String, b: BigInt)
+ extends PassException(s"$info : [target $mname] Width too small for constant $b.")
+ class WidthTooBig(info: Info, mname: String, b: BigInt)
+ extends PassException(s"$info : [target $mname] Width $b greater than max allowed width of $MaxWidth bits")
+ class DshlTooBig(info: Info, mname: String)
+ extends PassException(
+ s"$info : [target $mname] Width of dshl shift amount must be less than $DshlMaxWidth bits."
+ )
+ class MultiBitAsClock(info: Info, mname: String)
+ extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to a Clock.")
+ class MultiBitAsAsyncReset(info: Info, mname: String)
+ extends PassException(s"$info : [target $mname] Cannot cast a multi-bit signal to an AsyncReset.")
+ class NegWidthException(info: Info, mname: String)
+ extends PassException(s"$info: [target $mname] Width cannot be negative or zero.")
+ class BitsWidthException(info: Info, mname: String, hi: BigInt, width: BigInt, exp: String)
+ extends PassException(
+ s"$info: [target $mname] High bit $hi in bits operator is larger than input width $width in $exp."
+ )
+ class HeadWidthException(info: Info, mname: String, n: BigInt, width: BigInt)
+ extends PassException(s"$info: [target $mname] Parameter $n in head operator is larger than input width $width.")
+ class TailWidthException(info: Info, mname: String, n: BigInt, width: BigInt)
+ extends PassException(s"$info: [target $mname] Parameter $n in tail operator is larger than input width $width.")
+ class AttachWidthsNotEqual(info: Info, mname: String, eName: String, source: String)
+ extends PassException(
+ s"$info: [target $mname] Attach source $source and expression $eName must have identical widths."
+ )
class DisjointSqueeze(info: Info, mname: String, squeeze: DoPrim)
- extends PassException({
- val toSqz = squeeze.args.head.serialize
- val toSqzTpe = squeeze.args.head.tpe.serialize
- val sqzTo = squeeze.args(1).serialize
- val sqzToTpe = squeeze.args(1).tpe.serialize
- s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe"
- })
+ extends PassException({
+ val toSqz = squeeze.args.head.serialize
+ val toSqzTpe = squeeze.args.head.tpe.serialize
+ val sqzTo = squeeze.args(1).serialize
+ val sqzToTpe = squeeze.args(1).tpe.serialize
+ s"$info: [module $mname] Disjoint squz currently unsupported: $toSqz:$toSqzTpe cannot be squeezed with $sqzTo's type $sqzToTpe"
+ })
def run(c: Circuit): Circuit = {
val errors = new Errors()
@@ -77,35 +83,35 @@ object CheckWidths extends Pass {
def hasWidth(tpe: Type): Boolean = tpe match {
case GroundType(IntWidth(w)) => true
- case GroundType(_) => false
- case _ => throwInternalError(s"hasWidth - $tpe")
+ case GroundType(_) => false
+ case _ => throwInternalError(s"hasWidth - $tpe")
}
def check_width_t(info: Info, target: Target)(t: Type): Unit = {
t match {
case tt: BundleType => tt.fields.foreach(check_width_f(info, target))
//Supports when l = u (if closed)
- case i@IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i
- case i:IntervalType if i.range == Some(Nil) =>
+ case i @ IntervalType(Closed(l), Closed(u), IntWidth(_)) if l <= u => i
+ case i: IntervalType if i.range == Some(Nil) =>
errors.append(new InvalidRange(info, target.prettyPrint(" "), i))
i
- case i@IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u =>
+ case i @ IntervalType(KnownBound(l), KnownBound(u), IntWidth(p)) if l >= u =>
errors.append(new InvalidRange(info, target.prettyPrint(" "), i))
i
- case i@IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i
- case i@IntervalType(_: IsKnown, _, _) =>
+ case i @ IntervalType(KnownBound(_), KnownBound(_), IntWidth(_)) => i
+ case i @ IntervalType(_: IsKnown, _, _) =>
errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper"))
i
- case i@IntervalType(_, _: IsKnown, _) =>
+ case i @ IntervalType(_, _: IsKnown, _) =>
errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower"))
i
- case i@IntervalType(_, _, _) =>
+ case i @ IntervalType(_, _, _) =>
errors.append(new UninferredBound(info, target.prettyPrint(" "), "lower"))
errors.append(new UninferredBound(info, target.prettyPrint(" "), "upper"))
i
- case tt => tt foreach check_width_t(info, target)
+ case tt => tt.foreach(check_width_t(info, target))
}
- t foreach check_width_w(info, target, t)
+ t.foreach(check_width_w(info, target, t))
}
def check_width_f(info: Info, target: Target)(f: Field): Unit =
@@ -120,7 +126,8 @@ object CheckWidths extends Pass {
errors.append(new WidthTooSmall(info, target.serialize, v))
case e @ DoPrim(op, Seq(a, b), _, tpe) =>
(op, a.tpe, b.tpe) match {
- case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _)) if (ua < lb) || (ub < la) =>
+ case (Squeeze, IntervalType(Closed(la), Closed(ua), _), IntervalType(Closed(lb), Closed(ub), _))
+ if (ua < lb) || (ub < la) =>
errors.append(new DisjointSqueeze(info, target.serialize, e))
case (Dshl, at, bt) if (hasWidth(at) && bitWidth(bt) >= DshlMaxWidth) =>
errors.append(new DshlTooBig(info, target.serialize))
@@ -159,7 +166,6 @@ object CheckWidths extends Pass {
}
}
-
def check_width_e_dfs(info: Info, target: Target, expr: Expression): Unit = {
val stack = collection.mutable.ArrayStack(expr)
def push(e: Expression): Unit = stack.push(e)
@@ -171,25 +177,31 @@ object CheckWidths extends Pass {
}
def check_width_s(minfo: Info, target: ModuleTarget)(s: Statement): Unit = {
- val info = get_info(s) match { case NoInfo => minfo case x => x }
- val subRef = s match { case sx: HasName => target.ref(sx.name) case _ => target }
- s foreach check_width_e(info, target, 4)
- s foreach check_width_s(info, target)
- s foreach check_width_t(info, subRef)
+ val info = get_info(s) match {
+ case NoInfo => minfo
+ case x => x
+ }
+ val subRef = s match {
+ case sx: HasName => target.ref(sx.name)
+ case _ => target
+ }
+ s.foreach(check_width_e(info, target, 4))
+ s.foreach(check_width_s(info, target))
+ s.foreach(check_width_t(info, subRef))
s match {
case Attach(infox, exprs) =>
- exprs.tail.foreach ( e =>
+ exprs.tail.foreach(e =>
if (bitWidth(e.tpe) != bitWidth(exprs.head.tpe))
errors.append(new AttachWidthsNotEqual(infox, target.serialize, e.serialize, exprs.head.serialize))
)
case sx: DefRegister =>
sx.reset.tpe match {
case UIntType(IntWidth(w)) if w == 1 =>
- case AsyncResetType =>
- case ResetType =>
- case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
+ case AsyncResetType =>
+ case ResetType =>
+ case _ => errors.append(new CheckTypes.IllegalResetType(info, target.serialize, sx.name))
}
- if(!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) {
+ if (!CheckTypes.validConnect(sx.tpe, sx.init.tpe)) {
val conMsg = sx.copy(info = NoInfo).serialize
errors.append(new CheckTypes.InvalidConnect(info, target.module, conMsg, WRef(sx), sx.init))
}
@@ -197,14 +209,15 @@ object CheckWidths extends Pass {
}
}
- def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit = check_width_t(p.info, target.ref(p.name))(p.tpe)
+ def check_width_p(minfo: Info, target: ModuleTarget)(p: Port): Unit =
+ check_width_t(p.info, target.ref(p.name))(p.tpe)
def check_width_m(circuit: CircuitTarget)(m: DefModule): Unit = {
- m foreach check_width_p(m.info, circuit.module(m.name))
- m foreach check_width_s(m.info, circuit.module(m.name))
+ m.foreach(check_width_p(m.info, circuit.module(m.name)))
+ m.foreach(check_width_s(m.info, circuit.module(m.name)))
}
- c.modules foreach check_width_m(CircuitTarget(c.main))
+ c.modules.foreach(check_width_m(CircuitTarget(c.main)))
errors.trigger()
c
}
diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
index 544f90a6..55a9c53a 100644
--- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
+++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala
@@ -10,15 +10,16 @@ import firrtl.options.Dependency
object CommonSubexpressionElimination extends Pass {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats] )
+ Seq(
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[firrtl.transforms.CombineCats]
+ )
override def optionalPrerequisiteOf =
- Seq( Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = false
@@ -27,24 +28,26 @@ object CommonSubexpressionElimination extends Pass {
val nodes = collection.mutable.HashMap[String, Expression]()
def eliminateNodeRef(e: Expression): Expression = e match {
- case WRef(name, tpe, kind, flow) => nodes get name match {
- case Some(expression) => expressions get expression match {
- case Some(cseName) if cseName != name =>
- WRef(cseName, tpe, kind, flow)
+ case WRef(name, tpe, kind, flow) =>
+ nodes.get(name) match {
+ case Some(expression) =>
+ expressions.get(expression) match {
+ case Some(cseName) if cseName != name =>
+ WRef(cseName, tpe, kind, flow)
+ case _ => e
+ }
case _ => e
}
- case _ => e
- }
- case _ => e map eliminateNodeRef
+ case _ => e.map(eliminateNodeRef)
}
def eliminateNodeRefs(s: Statement): Statement = {
- s map eliminateNodeRef match {
+ s.map(eliminateNodeRef) match {
case x: DefNode =>
nodes(x.name) = x.value
expressions.getOrElseUpdate(x.value, x.name)
x
- case other => other map eliminateNodeRefs
+ case other => other.map(eliminateNodeRefs)
}
}
@@ -54,7 +57,7 @@ object CommonSubexpressionElimination extends Pass {
def run(c: Circuit): Circuit = {
val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => Module(m.info, m.name, m.ports, cse(m.body))
+ case m: Module => Module(m.info, m.name, m.ports, cse(m.body))
}
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
index 4a426209..baf7d4d5 100644
--- a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
+++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
@@ -7,7 +7,7 @@ import firrtl.PrimOps._
import firrtl.ir._
import firrtl._
import firrtl.Mappers._
-import firrtl.Utils.{sub_type, module_type, field_type, max, throwInternalError}
+import firrtl.Utils.{field_type, max, module_type, sub_type, throwInternalError}
import firrtl.options.Dependency
/** Replaces FixedType with SIntType, and correctly aligns all binary points
@@ -15,71 +15,74 @@ import firrtl.options.Dependency
object ConvertFixedToSInt extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency[ExpandWhensAndCheck],
- Dependency[RemoveIntervals] ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses),
+ Dependency[ExpandWhensAndCheck],
+ Dependency[RemoveIntervals]
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform) = false
def alignArg(e: Expression, point: BigInt): Expression = e.tpe match {
case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p)
- if((point - p) > 0) {
+ if ((point - p) > 0) {
DoPrim(Shl, Seq(e), Seq(point - p), UnknownType)
} else if (point - p < 0) {
DoPrim(Shr, Seq(e), Seq(p - point), UnknownType)
} else e
case FixedType(w, p) => throwInternalError(s"alignArg: shouldn't be here - $e")
- case _ => e
+ case _ => e
}
def calcPoint(es: Seq[Expression]): BigInt =
es.map(_.tpe match {
case FixedType(IntWidth(w), IntWidth(p)) => p
- case _ => BigInt(0)
+ case _ => BigInt(0)
}).reduce(max(_, _))
def toSIntType(t: Type): Type = t match {
case FixedType(IntWidth(w), IntWidth(p)) => SIntType(IntWidth(w))
- case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t")
- case _ => t map toSIntType
+ case FixedType(w, p) => throwInternalError(s"toSIntType: shouldn't be here - $t")
+ case _ => t.map(toSIntType)
}
def run(c: Circuit): Circuit = {
- val moduleTypes = mutable.HashMap[String,Type]()
- def onModule(m:DefModule) : DefModule = {
- val types = mutable.HashMap[String,Type]()
- def updateExpType(e:Expression): Expression = e match {
- case DoPrim(Mul, args, consts, tpe) => e map updateExpType
- case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) map updateExpType
- case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) map updateExpType
- case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) map updateExpType
- case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) map updateExpType
+ val moduleTypes = mutable.HashMap[String, Type]()
+ def onModule(m: DefModule): DefModule = {
+ val types = mutable.HashMap[String, Type]()
+ def updateExpType(e: Expression): Expression = e match {
+ case DoPrim(Mul, args, consts, tpe) => e.map(updateExpType)
+ case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe).map(updateExpType)
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe).map(updateExpType)
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe).map(updateExpType)
+ case DoPrim(SetP, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p).map(updateExpType)
case DoPrim(op, args, consts, tpe) =>
val point = calcPoint(args)
val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType)
- newExp map updateExpType match {
+ newExp.map(updateExpType) match {
case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe)
- case e => e
+ case e => e
}
case Mux(cond, tval, fval, tpe) =>
val point = calcPoint(Seq(tval, fval))
val newExp = Mux(cond, alignArg(tval, point), alignArg(fval, point), UnknownType)
- newExp map updateExpType
+ newExp.map(updateExpType)
case e: UIntLiteral => e
case e: SIntLiteral => e
- case _ => e map updateExpType match {
- case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe)
- case WRef(name, tpe, k, g) => WRef(name, types(name), k, g)
- case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g)
- case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g)
- case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g)
- }
+ case _ =>
+ e.map(updateExpType) match {
+ case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe)
+ case WRef(name, tpe, k, g) => WRef(name, types(name), k, g)
+ case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g)
+ case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g)
+ case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g)
+ }
}
def updateStmtType(s: Statement): Statement = s match {
case DefRegister(info, name, tpe, clock, reset, init) =>
val newType = toSIntType(tpe)
types(name) = newType
- DefRegister(info, name, newType, clock, reset, init) map updateExpType
+ DefRegister(info, name, newType, clock, reset, init).map(updateExpType)
case DefWire(info, name, tpe) =>
val newType = toSIntType(tpe)
types(name) = newType
@@ -101,37 +104,34 @@ object ConvertFixedToSInt extends Pass {
case Connect(info, loc, exp) =>
val point = calcPoint(Seq(loc))
val newExp = alignArg(exp, point)
- Connect(info, loc, newExp) map updateExpType
+ Connect(info, loc, newExp).map(updateExpType)
case PartialConnect(info, loc, exp) =>
val point = calcPoint(Seq(loc))
val newExp = alignArg(exp, point)
- PartialConnect(info, loc, newExp) map updateExpType
+ PartialConnect(info, loc, newExp).map(updateExpType)
// check Connect case, need to shl
- case s => (s map updateStmtType) map updateExpType
+ case s => (s.map(updateStmtType)).map(updateExpType)
}
m.ports.foreach(p => types(p.name) = p.tpe)
m match {
- case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body))
- case m:ExtModule => m
+ case Module(info, name, ports, body) => Module(info, name, ports, updateStmtType(body))
+ case m: ExtModule => m
}
}
- val newModules = for(m <- c.modules) yield {
- val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe)))
+ val newModules = for (m <- c.modules) yield {
+ val newPorts = m.ports.map(p => Port(p.info, p.name, p.direction, toSIntType(p.tpe)))
m match {
- case Module(info, name, ports, body) => Module(info,name,newPorts,body)
- case ext: ExtModule => ext.copy(ports = newPorts)
+ case Module(info, name, ports, body) => Module(info, name, newPorts, body)
+ case ext: ExtModule => ext.copy(ports = newPorts)
}
}
newModules.foreach(m => moduleTypes(m.name) = module_type(m))
/* @todo This should be moved outside */
- (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main ))
+ (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main))
}
}
-
-
-
// vim: set ts=4 sw=4 et:
diff --git a/src/main/scala/firrtl/passes/ExpandConnects.scala b/src/main/scala/firrtl/passes/ExpandConnects.scala
index d28e6399..4f849c5a 100644
--- a/src/main/scala/firrtl/passes/ExpandConnects.scala
+++ b/src/main/scala/firrtl/passes/ExpandConnects.scala
@@ -9,8 +9,7 @@ import firrtl.Mappers._
object ExpandConnects extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped
+ Seq(Dependency(PullMuxes), Dependency(ReplaceAccesses)) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform) = a match {
case ResolveFlows => true
@@ -19,62 +18,65 @@ object ExpandConnects extends Pass {
def run(c: Circuit): Circuit = {
def expand_connects(m: Module): Module = {
- val flows = collection.mutable.LinkedHashMap[String,Flow]()
+ val flows = collection.mutable.LinkedHashMap[String, Flow]()
def expand_s(s: Statement): Statement = {
- def set_flow(e: Expression): Expression = e map set_flow match {
+ def set_flow(e: Expression): Expression = e.map(set_flow) match {
case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name))
case ex: WSubField =>
val f = get_field(ex.expr.tpe, ex.name)
val flowx = times(flow(ex.expr), f.flip)
WSubField(ex.expr, ex.name, ex.tpe, flowx)
- case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
+ case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr))
case ex => ex
}
s match {
- case sx: DefWire => flows(sx.name) = DuplexFlow; sx
- case sx: DefRegister => flows(sx.name) = DuplexFlow; sx
+ case sx: DefWire => flows(sx.name) = DuplexFlow; sx
+ case sx: DefRegister => flows(sx.name) = DuplexFlow; sx
case sx: WDefInstance => flows(sx.name) = SourceFlow; sx
- case sx: DefMemory => flows(sx.name) = SourceFlow; sx
+ case sx: DefMemory => flows(sx.name) = SourceFlow; sx
case sx: DefNode => flows(sx.name) = SourceFlow; sx
case sx: IsInvalid =>
- val invalids = create_exps(sx.expr).flatMap { case expx =>
- flow(set_flow(expx)) match {
+ val invalids = create_exps(sx.expr).flatMap {
+ case expx =>
+ flow(set_flow(expx)) match {
case DuplexFlow => Some(IsInvalid(sx.info, expx))
- case SinkFlow => Some(IsInvalid(sx.info, expx))
- case _ => None
- }
+ case SinkFlow => Some(IsInvalid(sx.info, expx))
+ case _ => None
+ }
}
invalids.size match {
- case 0 => EmptyStmt
- case 1 => invalids.head
- case _ => Block(invalids)
+ case 0 => EmptyStmt
+ case 1 => invalids.head
+ case _ => Block(invalids)
}
case sx: Connect =>
val locs = create_exps(sx.loc)
val exps = create_exps(sx.expr)
- Block(locs.zip(exps).map { case (locx, expx) =>
- to_flip(flow(locx)) match {
+ Block(locs.zip(exps).map {
+ case (locx, expx) =>
+ to_flip(flow(locx)) match {
case Default => Connect(sx.info, locx, expx)
- case Flip => Connect(sx.info, expx, locx)
- }
+ case Flip => Connect(sx.info, expx, locx)
+ }
})
case sx: PartialConnect =>
val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default)
val locs = create_exps(sx.loc)
val exps = create_exps(sx.expr)
- val stmts = ls map { case (x, y) =>
- locs(x).tpe match {
- case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
- case _ =>
- to_flip(flow(locs(x))) match {
- case Default => Connect(sx.info, locs(x), exps(y))
- case Flip => Connect(sx.info, exps(y), locs(x))
- }
- }
+ val stmts = ls.map {
+ case (x, y) =>
+ locs(x).tpe match {
+ case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
+ case _ =>
+ to_flip(flow(locs(x))) match {
+ case Default => Connect(sx.info, locs(x), exps(y))
+ case Flip => Connect(sx.info, exps(y), locs(x))
+ }
+ }
}
Block(stmts)
- case sx => sx map expand_s
+ case sx => sx.map(expand_s)
}
}
@@ -83,8 +85,8 @@ object ExpandConnects extends Pass {
}
val modulesx = c.modules.map {
- case (m: ExtModule) => m
- case (m: Module) => expand_connects(m)
+ case (m: ExtModule) => m
+ case (m: Module) => expand_connects(m)
}
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index ab7f02db..14d5d3ef 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -28,21 +28,23 @@ import collection.mutable
object ExpandWhens extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Resolved
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses)
+ ) ++ firrtl.stage.Forms.Resolved
override def invalidates(a: Transform): Boolean = a match {
case CheckInitialization | ResolveKinds | InferTypes => true
- case _ => false
+ case _ => false
}
/** Returns circuit with when and last connection semantics resolved */
def run(c: Circuit): Circuit = {
- val modulesx = c.modules map {
+ val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => onModule(m)
+ case m: Module => onModule(m)
}
Circuit(c.info, modulesx, c.main)
}
@@ -74,13 +76,12 @@ object ExpandWhens extends Pass {
// Does an expression contain WVoid inserted in this pass?
def containsVoid(e: Expression): Boolean = e match {
- case WVoid => true
+ case WVoid => true
case ValidIf(_, value, _) => memoizedVoid(value)
- case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
- case _ => false
+ case Mux(_, tv, fv, _) => memoizedVoid(tv) || memoizedVoid(fv)
+ case _ => false
}
-
// Memoizes the node that holds a particular expression, if any
val nodes = new NodeLookup
@@ -95,18 +96,15 @@ object ExpandWhens extends Pass {
* @param p predicate so far, used to update simulation constructs
* @param s statement to expand
*/
- def expandWhens(netlist: Netlist,
- defaults: Defaults,
- p: Expression)
- (s: Statement): Statement = s match {
+ def expandWhens(netlist: Netlist, defaults: Defaults, p: Expression)(s: Statement): Statement = s match {
// For each non-register declaration, update netlist with value WVoid for each sink reference
// Return self, unchanged
case stmt @ (_: DefNode | EmptyStmt) => stmt
case w: DefWire =>
- netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow) map (ref => we(ref) -> WVoid))
+ netlist ++= (getSinkRefs(w.name, w.tpe, DuplexFlow).map(ref => we(ref) -> WVoid))
w
case w: DefMemory =>
- netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow) map (ref => we(ref) -> WVoid))
+ netlist ++= (getSinkRefs(w.name, MemPortUtils.memType(w), SourceFlow).map(ref => we(ref) -> WVoid))
w
case w: WDefInstance =>
netlist ++= (getSinkRefs(w.name, w.tpe, SourceFlow).map(ref => we(ref) -> WVoid))
@@ -151,82 +149,88 @@ object ExpandWhens extends Pass {
// Process combined maps because we only want to create 1 mux for each node
// present in the conseq and/or alt
- val memos = (conseqNetlist ++ altNetlist) map { case (lvalue, _) =>
- // Defaults in netlist get priority over those in defaults
- val default = netlist get lvalue match {
- case Some(v) => Some(v)
- case None => getDefault(lvalue, defaults)
- }
- // info0 and info1 correspond to Mux infos, use info0 only if ValidIf
- val (res, info0, info1) = default match {
- case Some(defaultValue) =>
- val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue))
- val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue))
- (trueValue, falseValue) match {
- case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo)
- case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo)
- case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo)
- case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo)
- }
- case None =>
- // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
- (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo)
- }
+ val memos = (conseqNetlist ++ altNetlist).map {
+ case (lvalue, _) =>
+ // Defaults in netlist get priority over those in defaults
+ val default = netlist.get(lvalue) match {
+ case Some(v) => Some(v)
+ case None => getDefault(lvalue, defaults)
+ }
+ // info0 and info1 correspond to Mux infos, use info0 only if ValidIf
+ val (res, info0, info1) = default match {
+ case Some(defaultValue) =>
+ val (tinfo, trueValue) = unwrap(conseqNetlist.getOrElse(lvalue, defaultValue))
+ val (finfo, falseValue) = unwrap(altNetlist.getOrElse(lvalue, defaultValue))
+ (trueValue, falseValue) match {
+ case (WInvalid, WInvalid) => (WInvalid, NoInfo, NoInfo)
+ case (WInvalid, fv) => (ValidIf(NOT(sx.pred), fv, fv.tpe), finfo, NoInfo)
+ case (tv, WInvalid) => (ValidIf(sx.pred, tv, tv.tpe), tinfo, NoInfo)
+ case (tv, fv) => (Mux(sx.pred, tv, fv, mux_type_and_widths(tv, fv)), tinfo, finfo)
+ }
+ case None =>
+ // Since not in netlist, lvalue must be declared in EXACTLY one of conseq or alt
+ (conseqNetlist.getOrElse(lvalue, altNetlist(lvalue)), NoInfo, NoInfo)
+ }
- res match {
- // Don't create a node to hold mux trees with void values
- // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches)
- case e if containsVoid(e) =>
- netlist(lvalue) = e
- memoizedVoid += e // remember that this was void
- EmptyStmt
- case _: ValidIf | _: Mux | _: DoPrim => nodes get res match {
- case Some(name) =>
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ res match {
+ // Don't create a node to hold mux trees with void values
+ // "Idiomatic" emission of these muxes isn't a concern because they represent bad code (latches)
+ case e if containsVoid(e) =>
+ netlist(lvalue) = e
+ memoizedVoid += e // remember that this was void
+ EmptyStmt
+ case _: ValidIf | _: Mux | _: DoPrim =>
+ nodes.get(res) match {
+ case Some(name) =>
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ EmptyStmt
+ case None =>
+ val name = namespace.newTemp
+ nodes(res) = name
+ netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
+ // Use MultiInfo constructor to preserve NoInfos
+ val info = new MultiInfo(List(sx.info, info0, info1))
+ DefNode(info, name, res)
+ }
+ case _ =>
+ netlist(lvalue) = res
EmptyStmt
- case None =>
- val name = namespace.newTemp
- nodes(res) = name
- netlist(lvalue) = WRef(name, res.tpe, NodeKind, SourceFlow)
- // Use MultiInfo constructor to preserve NoInfos
- val info = new MultiInfo(List(sx.info, info0, info1))
- DefNode(info, name, res)
}
- case _ =>
- netlist(lvalue) = res
- EmptyStmt
- }
}
Block(Seq(conseqStmt, altStmt) ++ memos)
- case block: Block => block map expandWhens(netlist, defaults, p)
+ case block: Block => block.map(expandWhens(netlist, defaults, p))
case _ => throwInternalError()
}
val netlist = new Netlist
// Add ports to netlist
- netlist ++= (m.ports flatMap { case Port(_, name, dir, tpe) =>
- getSinkRefs(name, tpe, to_flow(dir)) map (ref => we(ref) -> WVoid)
+ netlist ++= (m.ports.flatMap {
+ case Port(_, name, dir, tpe) =>
+ getSinkRefs(name, tpe, to_flow(dir)).map(ref => we(ref) -> WVoid)
})
// Do traversal and construct mutable datastructures
val bodyx = expandWhens(netlist, Seq(netlist), one)(m.body)
val attachedAnalogs = attaches.flatMap(_.exprs.map(we)).toSet
- val newBody = Block(Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
- combineAttaches(attaches.toSeq) ++ simlist)
+ val newBody = Block(
+ Seq(squashEmpty(bodyx)) ++ expandNetlist(netlist, attachedAnalogs) ++
+ combineAttaches(attaches.toSeq) ++ simlist
+ )
Module(m.info, m.name, m.ports, newBody)
}
-
/** Returns all references to all sink leaf subcomponents of a reference */
private def getSinkRefs(n: String, t: Type, g: Flow): Seq[Expression] = {
val exps = create_exps(WRef(n, t, ExpKind, g))
- exps.flatMap { case exp =>
- exp.tpe match {
- case AnalogType(w) => None
- case _ => flow(exp) match {
- case (DuplexFlow | SinkFlow) => Some(exp)
- case _ => None
+ exps.flatMap {
+ case exp =>
+ exp.tpe match {
+ case AnalogType(w) => None
+ case _ =>
+ flow(exp) match {
+ case (DuplexFlow | SinkFlow) => Some(exp)
+ case _ => None
+ }
}
- }
}
}
@@ -238,7 +242,7 @@ object ExpandWhens extends Pass {
def handleInvalid(k: WrappedExpression, info: Info): Statement =
if (attached.contains(k)) EmptyStmt else IsInvalid(info, k.e1)
netlist.map {
- case (k, WInvalid) => handleInvalid(k, NoInfo)
+ case (k, WInvalid) => handleInvalid(k, NoInfo)
case (k, InfoExpr(info, WInvalid)) => handleInvalid(k, info)
case (k, v) =>
val (info, expr) = unwrap(v)
@@ -261,7 +265,7 @@ object ExpandWhens extends Pass {
case Seq() => // None of these expressions is present in the attachMap
AttachAcc(exprs, attachMap.size)
case accs => // At least one expression present in the attachMap
- val sorted = accs sortBy (_.idx)
+ val sorted = accs.sortBy(_.idx)
AttachAcc((sorted.map(_.exprs) :+ exprs).flatten.distinct, sorted.head.idx)
}
attachMap ++= acc.exprs.map(_ -> acc)
@@ -274,10 +278,11 @@ object ExpandWhens extends Pass {
private def getDefault(lvalue: WrappedExpression, defaults: Defaults): Option[Expression] = {
defaults match {
case Nil => None
- case head :: tail => head get lvalue match {
- case Some(p) => Some(p)
- case None => getDefault(lvalue, tail)
- }
+ case head :: tail =>
+ head.get(lvalue) match {
+ case Some(p) => Some(p)
+ case None => getDefault(lvalue, tail)
+ }
}
}
@@ -290,10 +295,12 @@ object ExpandWhens extends Pass {
class ExpandWhensAndCheck extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true
@@ -301,6 +308,6 @@ class ExpandWhensAndCheck extends Transform with DependencyAPIMigration {
}
override def execute(a: CircuitState): CircuitState =
- Seq(ExpandWhens, CheckInitialization).foldLeft(a){ case (acc, tx) => tx.transform(acc) }
+ Seq(ExpandWhens, CheckInitialization).foldLeft(a) { case (acc, tx) => tx.transform(acc) }
}
diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
index a16205a7..f393d8a5 100644
--- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala
+++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
@@ -13,9 +13,7 @@ import firrtl.options.Dependency
class InferBinaryPoints extends Pass {
override def prerequisites =
- Seq( Dependency(ResolveKinds),
- Dependency(InferTypes),
- Dependency(ResolveFlows) )
+ Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows))
override def optionalPrerequisiteOf = Seq.empty
@@ -23,12 +21,12 @@ class InferBinaryPoints extends Pass {
private val constraintSolver = new ConstraintSolver()
- private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1,t2) match {
- case (UIntType(w1), UIntType(w2)) =>
- case (SIntType(w1), SIntType(w2)) =>
- case (ClockType, ClockType) =>
- case (ResetType, _) =>
- case (_, ResetType) =>
+ private def addTypeConstraints(r1: ReferenceTarget, r2: ReferenceTarget)(t1: Type, t2: Type): Unit = (t1, t2) match {
+ case (UIntType(w1), UIntType(w2)) =>
+ case (SIntType(w1), SIntType(w2)) =>
+ case (ClockType, ClockType) =>
+ case (ResetType, _) =>
+ case (_, ResetType) =>
case (AsyncResetType, AsyncResetType) =>
case (FixedType(w1, p1), FixedType(w2, p2)) =>
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
@@ -36,78 +34,86 @@ class InferBinaryPoints extends Pass {
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
case (AnalogType(w1), AnalogType(w2)) =>
case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields) foreach { case (f1, f2) =>
- (f1.flip, f2.flip) match {
- case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
- case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
- case _ => sys.error("Shouldn't be here")
- }
+ (t1.fields.zip(t2.fields)).foreach {
+ case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
}
case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case other => throwInternalError(s"Illegal compiler state: cannot constraint different types - $other")
}
- private def addDecConstraints(t: Type): Type = t map addDecConstraints
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s map addDecConstraints match {
+ private def addDecConstraints(t: Type): Type = t.map(addDecConstraints)
+ private def addStmtConstraints(mt: ModuleTarget)(s: Statement): Statement = s.map(addDecConstraints) match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
val exps = create_exps(c.expr)
- (locs zip exps) foreach { case (loc, exp) =>
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ (locs.zip(exps)).foreach {
+ case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
- ls foreach { case (x, y) =>
- val loc = locs(x)
- val exp = exps(y)
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ ls.foreach {
+ case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
pc
case r: DefRegister =>
- addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
+ addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
r
- case x => x map addStmtConstraints(mt)
+ case x => x.map(addStmtConstraints(mt))
}
private def fixWidth(w: Width): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => w
- case _ => sys.error("Shouldn't be here")
+ case None => w
+ case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type): Type = t.map(fixType).map(fixWidth) match {
case IntervalType(l, u, p) =>
val px = constraintSolver.get(p) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => p
- case _ => sys.error("Shouldn't be here")
+ case None => p
+ case _ => sys.error("Shouldn't be here")
}
IntervalType(l, u, px)
case FixedType(w, p) =>
val px = constraintSolver.get(p) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => p
- case _ => sys.error("Shouldn't be here")
+ case None => p
+ case _ => sys.error("Shouldn't be here")
}
FixedType(w, px)
case x => x
}
- private def fixStmt(s: Statement): Statement = s map fixStmt map fixType
- private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
- def run (c: Circuit): Circuit = {
+ private def fixStmt(s: Statement): Statement = s.map(fixStmt).map(fixType)
+ private def fixPort(p: Port): Port = Port(p.info, p.name, p.direction, fixType(p.tpe))
+ def run(c: Circuit): Circuit = {
val ct = CircuitTarget(c.main)
- c.modules foreach (m => m map addStmtConstraints(ct.module(m.name)))
- c.modules foreach (_.ports foreach {p => addDecConstraints(p.tpe)})
+ c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name))))
+ c.modules.foreach(_.ports.foreach { p => addDecConstraints(p.tpe) })
constraintSolver.solve()
- InferTypes.run(c.copy(modules = c.modules map (_
- map fixPort
- map fixStmt)))
+ InferTypes.run(
+ c.copy(modules =
+ c.modules.map(
+ _.map(fixPort)
+ .map(fixStmt)
+ )
+ )
+ )
}
}
diff --git a/src/main/scala/firrtl/passes/InferTypes.scala b/src/main/scala/firrtl/passes/InferTypes.scala
index 6cc9f2b9..4d14e7ff 100644
--- a/src/main/scala/firrtl/passes/InferTypes.scala
+++ b/src/main/scala/firrtl/passes/InferTypes.scala
@@ -23,16 +23,16 @@ object InferTypes extends Pass {
def remove_unknowns_b(b: Bound): Bound = b match {
case UnknownBound => VarBound(namespace.newName("b"))
- case k => k
+ case k => k
}
def remove_unknowns_w(w: Width): Width = w match {
case UnknownWidth => VarWidth(namespace.newName("w"))
- case wx => wx
+ case wx => wx
}
def remove_unknowns(t: Type): Type = {
- t map remove_unknowns map remove_unknowns_w match {
+ t.map(remove_unknowns).map(remove_unknowns_w) match {
case IntervalType(l, u, p) =>
IntervalType(remove_unknowns_b(l), remove_unknowns_b(u), p)
case x => x
@@ -41,18 +41,18 @@ object InferTypes extends Pass {
// we first need to remove the unknown widths and bounds from all ports,
// as their type will determine the module types
- val portsKnown = c.modules.map(_.map{ p: Port => p.copy(tpe = remove_unknowns(p.tpe)) })
+ val portsKnown = c.modules.map(_.map { p: Port => p.copy(tpe = remove_unknowns(p.tpe)) })
val mtypes = portsKnown.map(m => m.name -> module_type(m)).toMap
def infer_types_e(types: TypeLookup)(e: Expression): Expression =
- e map infer_types_e(types) match {
- case e: WRef => e copy (tpe = types(e.name))
- case e: WSubField => e copy (tpe = field_type(e.expr.tpe, e.name))
- case e: WSubIndex => e copy (tpe = sub_type(e.expr.tpe))
- case e: WSubAccess => e copy (tpe = sub_type(e.expr.tpe))
- case e: DoPrim => PrimOps.set_primop_type(e)
- case e: Mux => e copy (tpe = mux_type_and_widths(e.tval, e.fval))
- case e: ValidIf => e copy (tpe = e.value.tpe)
+ e.map(infer_types_e(types)) match {
+ case e: WRef => e.copy(tpe = types(e.name))
+ case e: WSubField => e.copy(tpe = field_type(e.expr.tpe, e.name))
+ case e: WSubIndex => e.copy(tpe = sub_type(e.expr.tpe))
+ case e: WSubAccess => e.copy(tpe = sub_type(e.expr.tpe))
+ case e: DoPrim => PrimOps.set_primop_type(e)
+ case e: Mux => e.copy(tpe = mux_type_and_widths(e.tval, e.fval))
+ case e: ValidIf => e.copy(tpe = e.value.tpe)
case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
@@ -60,37 +60,37 @@ object InferTypes extends Pass {
case sx: WDefInstance =>
val t = mtypes(sx.module)
types(sx.name) = t
- sx copy (tpe = t)
+ sx.copy(tpe = t)
case sx: DefWire =>
val t = remove_unknowns(sx.tpe)
types(sx.name) = t
- sx copy (tpe = t)
+ sx.copy(tpe = t)
case sx: DefNode =>
- val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode]
+ val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode]
val t = remove_unknowns(sxx.value.tpe)
types(sx.name) = t
sxx
case sx: DefRegister =>
val t = remove_unknowns(sx.tpe)
types(sx.name) = t
- sx copy (tpe = t) map infer_types_e(types)
+ sx.copy(tpe = t).map(infer_types_e(types))
case sx: DefMemory =>
// we need to remove the unknowns from the data type so that all ports get the same VarWidth
val knownDataType = sx.copy(dataType = remove_unknowns(sx.dataType))
types(sx.name) = MemPortUtils.memType(knownDataType)
knownDataType
- case sx => sx map infer_types_s(types) map infer_types_e(types)
+ case sx => sx.map(infer_types_s(types)).map(infer_types_e(types))
}
def infer_types_p(types: TypeLookup)(p: Port): Port = {
val t = remove_unknowns(p.tpe)
types(p.name) = t
- p copy (tpe = t)
+ p.copy(tpe = t)
}
def infer_types(m: DefModule): DefModule = {
val types = new TypeLookup
- m map infer_types_p(types) map infer_types_s(types)
+ m.map(infer_types_p(types)).map(infer_types_s(types))
}
c.copy(modules = portsKnown.map(infer_types))
@@ -108,45 +108,45 @@ object CInferTypes extends Pass {
private type TypeLookup = collection.mutable.HashMap[String, Type]
def run(c: Circuit): Circuit = {
- val mtypes = (c.modules map (m => m.name -> module_type(m))).toMap
-
- def infer_types_e(types: TypeLookup)(e: Expression) : Expression =
- e map infer_types_e(types) match {
- case (e: Reference) => e copy (tpe = types.getOrElse(e.name, UnknownType))
- case (e: SubField) => e copy (tpe = field_type(e.expr.tpe, e.name))
- case (e: SubIndex) => e copy (tpe = sub_type(e.expr.tpe))
- case (e: SubAccess) => e copy (tpe = sub_type(e.expr.tpe))
- case (e: DoPrim) => PrimOps.set_primop_type(e)
- case (e: Mux) => e copy (tpe = mux_type(e.tval, e.fval))
- case (e: ValidIf) => e copy (tpe = e.value.tpe)
- case e @ (_: UIntLiteral | _: SIntLiteral) => e
+ val mtypes = (c.modules.map(m => m.name -> module_type(m))).toMap
+
+ def infer_types_e(types: TypeLookup)(e: Expression): Expression =
+ e.map(infer_types_e(types)) match {
+ case (e: Reference) => e.copy(tpe = types.getOrElse(e.name, UnknownType))
+ case (e: SubField) => e.copy(tpe = field_type(e.expr.tpe, e.name))
+ case (e: SubIndex) => e.copy(tpe = sub_type(e.expr.tpe))
+ case (e: SubAccess) => e.copy(tpe = sub_type(e.expr.tpe))
+ case (e: DoPrim) => PrimOps.set_primop_type(e)
+ case (e: Mux) => e.copy(tpe = mux_type(e.tval, e.fval))
+ case (e: ValidIf) => e.copy(tpe = e.value.tpe)
+ case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
def infer_types_s(types: TypeLookup)(s: Statement): Statement = s match {
case sx: DefRegister =>
types(sx.name) = sx.tpe
- sx map infer_types_e(types)
+ sx.map(infer_types_e(types))
case sx: DefWire =>
types(sx.name) = sx.tpe
sx
case sx: DefNode =>
- val sxx = (sx map infer_types_e(types)).asInstanceOf[DefNode]
+ val sxx = (sx.map(infer_types_e(types))).asInstanceOf[DefNode]
types(sxx.name) = sxx.value.tpe
sxx
case sx: DefMemory =>
types(sx.name) = MemPortUtils.memType(sx)
sx
case sx: CDefMPort =>
- val t = types getOrElse(sx.mem, UnknownType)
+ val t = types.getOrElse(sx.mem, UnknownType)
types(sx.name) = t
- sx copy (tpe = t)
+ sx.copy(tpe = t)
case sx: CDefMemory =>
types(sx.name) = sx.tpe
sx
case sx: DefInstance =>
types(sx.name) = mtypes(sx.module)
sx
- case sx => sx map infer_types_s(types) map infer_types_e(types)
+ case sx => sx.map(infer_types_s(types)).map(infer_types_e(types))
}
def infer_types_p(types: TypeLookup)(p: Port): Port = {
@@ -156,9 +156,9 @@ object CInferTypes extends Pass {
def infer_types(m: DefModule): DefModule = {
val types = new TypeLookup
- m map infer_types_p(types) map infer_types_s(types)
+ m.map(infer_types_p(types)).map(infer_types_s(types))
}
- c copy (modules = c.modules map infer_types)
+ c.copy(modules = c.modules.map(infer_types))
}
}
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 3720523b..eae9690f 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -14,7 +14,7 @@ import firrtl.options.Dependency
object InferWidths {
def apply(): InferWidths = new InferWidths()
- def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
+ def run(c: Circuit): Circuit = new InferWidths().run(c)(new ConstraintSolver)
def execute(state: CircuitState): CircuitState = new InferWidths().execute(state)
}
@@ -22,12 +22,14 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
def update(renameMap: RenameMap): Seq[WidthGeqConstraintAnnotation] = {
val newLoc :: newExp :: Nil = Seq(loc, exp).map { target =>
renameMap.get(target) match {
- case None => Some(target)
- case Some(Seq()) => None
+ case None => Some(target)
+ case Some(Seq()) => None
case Some(Seq(one)) => Some(one)
case Some(many) =>
- throw new Exception(s"Target below is an AggregateType, which " +
- "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ throw new Exception(
+ s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()
+ )
}
}
@@ -60,28 +62,31 @@ case class WidthGeqConstraintAnnotation(loc: ReferenceTarget, exp: ReferenceTarg
*
* Uses firrtl.constraint package to infer widths
*/
-class InferWidths extends Transform
- with ResolvedAnnotationPaths
- with DependencyAPIMigration {
+class InferWidths extends Transform with ResolvedAnnotationPaths with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferBinaryPoints],
- Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferBinaryPoints],
+ Dependency[passes.TrimIntervals]
+ ) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform) = false
val annotationClasses = Seq(classOf[WidthGeqConstraintAnnotation])
- private def addTypeConstraints
- (r1: ReferenceTarget, r2: ReferenceTarget)
- (t1: Type, t2: Type)
- (implicit constraintSolver: ConstraintSolver)
- : Unit = (t1,t2) match {
+ private def addTypeConstraints(
+ r1: ReferenceTarget,
+ r2: ReferenceTarget
+ )(t1: Type,
+ t2: Type
+ )(
+ implicit constraintSolver: ConstraintSolver
+ ): Unit = (t1, t2) match {
case (UIntType(w1), UIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
case (SIntType(w1), SIntType(w2)) => constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
- case (ClockType, ClockType) =>
+ case (ClockType, ClockType) =>
case (FixedType(w1, p1), FixedType(w2, p2)) =>
constraintSolver.addGeq(p1, p2, r1.prettyPrint(""), r2.prettyPrint(""))
constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
@@ -93,101 +98,119 @@ class InferWidths extends Transform
constraintSolver.addGeq(w1, w2, r1.prettyPrint(""), r2.prettyPrint(""))
constraintSolver.addGeq(w2, w1, r1.prettyPrint(""), r2.prettyPrint(""))
case (t1: BundleType, t2: BundleType) =>
- (t1.fields zip t2.fields) foreach { case (f1, f2) =>
- (f1.flip, f2.flip) match {
- case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
- case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
- case _ => sys.error("Shouldn't be here")
- }
+ (t1.fields.zip(t2.fields)).foreach {
+ case (f1, f2) =>
+ (f1.flip, f2.flip) match {
+ case (Default, Default) => addTypeConstraints(r1.field(f1.name), r2.field(f2.name))(f1.tpe, f2.tpe)
+ case (Flip, Flip) => addTypeConstraints(r2.field(f2.name), r1.field(f1.name))(f2.tpe, f1.tpe)
+ case _ => sys.error("Shouldn't be here")
+ }
}
case (t1: VectorType, t2: VectorType) => addTypeConstraints(r1.index(0), r2.index(0))(t1.tpe, t2.tpe)
case (AsyncResetType, AsyncResetType) => Nil
- case (ResetType, _) => Nil
- case (_, ResetType) => Nil
+ case (ResetType, _) => Nil
+ case (_, ResetType) => Nil
}
- private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver)
- : Expression = e map addExpConstraints match {
- case m@Mux(p, tVal, fVal, t) =>
- constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
- m
- case other => other
- }
+ private def addExpConstraints(e: Expression)(implicit constraintSolver: ConstraintSolver): Expression =
+ e.map(addExpConstraints) match {
+ case m @ Mux(p, tVal, fVal, t) =>
+ constraintSolver.addGeq(getWidth(p), Closed(1), "mux predicate", "1.W")
+ m
+ case other => other
+ }
- private def addStmtConstraints(mt: ModuleTarget)(s: Statement)(implicit constraintSolver: ConstraintSolver)
- : Statement = s map addExpConstraints match {
+ private def addStmtConstraints(
+ mt: ModuleTarget
+ )(s: Statement
+ )(
+ implicit constraintSolver: ConstraintSolver
+ ): Statement = s.map(addExpConstraints) match {
case c: Connect =>
val n = get_size(c.loc.tpe)
val locs = create_exps(c.loc)
val exps = create_exps(c.expr)
- (locs zip exps).foreach { case (loc, exp) =>
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
- }
+ (locs.zip(exps)).foreach {
+ case (loc, exp) =>
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
+ }
c
case pc: PartialConnect =>
val ls = get_valid_points(pc.loc.tpe, pc.expr.tpe, Default, Default)
val locs = create_exps(pc.loc)
val exps = create_exps(pc.expr)
- ls foreach { case (x, y) =>
- val loc = locs(x)
- val exp = exps(y)
- to_flip(flow(loc)) match {
- case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
- case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
- }
+ ls.foreach {
+ case (x, y) =>
+ val loc = locs(x)
+ val exp = exps(y)
+ to_flip(flow(loc)) match {
+ case Default => addTypeConstraints(Target.asTarget(mt)(loc), Target.asTarget(mt)(exp))(loc.tpe, exp.tpe)
+ case Flip => addTypeConstraints(Target.asTarget(mt)(exp), Target.asTarget(mt)(loc))(exp.tpe, loc.tpe)
+ }
}
pc
case r: DefRegister =>
- if (r.reset.tpe != AsyncResetType ) {
+ if (r.reset.tpe != AsyncResetType) {
addTypeConstraints(Target.asTarget(mt)(r.reset), mt.ref("1"))(r.reset.tpe, UIntType(IntWidth(1)))
}
addTypeConstraints(mt.ref(r.name), Target.asTarget(mt)(r.init))(r.tpe, r.init.tpe)
r
- case a@Attach(_, exprs) =>
- val widths = exprs map (e => (e, getWidth(e.tpe)))
+ case a @ Attach(_, exprs) =>
+ val widths = exprs.map(e => (e, getWidth(e.tpe)))
val maxWidth = IsMax(widths.map(x => width2constraint(x._2)))
- widths.foreach { case (e, w) =>
- constraintSolver.addGeq(w, CalcWidth(maxWidth), Target.asTarget(mt)(e).prettyPrint(""), mt.ref(a.serialize).prettyPrint(""))
+ widths.foreach {
+ case (e, w) =>
+ constraintSolver.addGeq(
+ w,
+ CalcWidth(maxWidth),
+ Target.asTarget(mt)(e).prettyPrint(""),
+ mt.ref(a.serialize).prettyPrint("")
+ )
}
a
case c: Conditionally =>
addTypeConstraints(Target.asTarget(mt)(c.pred), mt.ref("1.W"))(c.pred.tpe, UIntType(IntWidth(1)))
- c map addStmtConstraints(mt)
- case x => x map addStmtConstraints(mt)
+ c.map(addStmtConstraints(mt))
+ case x => x.map(addStmtConstraints(mt))
}
private def fixWidth(w: Width)(implicit constraintSolver: ConstraintSolver): Width = constraintSolver.get(w) match {
case Some(Closed(x)) if trim(x).isWhole => IntWidth(x.toBigInt)
- case None => w
- case _ => sys.error("Shouldn't be here")
+ case None => w
+ case _ => sys.error("Shouldn't be here")
}
- private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t map fixType map fixWidth match {
+ private def fixType(t: Type)(implicit constraintSolver: ConstraintSolver): Type = t.map(fixType).map(fixWidth) match {
case IntervalType(l, u, p) =>
val (lx, ux) = (constraintSolver.get(l), constraintSolver.get(u)) match {
case (Some(x: Bound), Some(y: Bound)) => (x, y)
case (None, None) => (l, u)
- case x => sys.error(s"Shouldn't be here: $x")
-
+ case x => sys.error(s"Shouldn't be here: $x")
}
IntervalType(lx, ux, fixWidth(p))
case FixedType(w, p) => FixedType(w, fixWidth(p))
- case x => x
+ case x => x
}
- private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement = s map fixStmt map fixType
+ private def fixStmt(s: Statement)(implicit constraintSolver: ConstraintSolver): Statement =
+ s.map(fixStmt).map(fixType)
private def fixPort(p: Port)(implicit constraintSolver: ConstraintSolver): Port = {
Port(p.info, p.name, p.direction, fixType(p.tpe))
}
- def run (c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
+ def run(c: Circuit)(implicit constraintSolver: ConstraintSolver): Circuit = {
val ct = CircuitTarget(c.main)
- c.modules foreach ( m => m map addStmtConstraints(ct.module(m.name)))
+ c.modules.foreach(m => m.map(addStmtConstraints(ct.module(m.name))))
constraintSolver.solve()
- val ret = InferTypes.run(c.copy(modules = c.modules map (_
- map fixPort
- map fixStmt)))
+ val ret = InferTypes.run(
+ c.copy(modules =
+ c.modules.map(
+ _.map(fixPort)
+ .map(fixStmt)
+ )
+ )
+ )
constraintSolver.clear()
ret
}
@@ -200,15 +223,16 @@ class InferWidths extends Transform
def getDeclTypes(modName: String)(stmt: Statement): Unit = {
val pairOpt = stmt match {
- case w: DefWire => Some(w.name -> w.tpe)
- case r: DefRegister => Some(r.name -> r.tpe)
- case n: DefNode => Some(n.name -> n.value.tpe)
+ case w: DefWire => Some(w.name -> w.tpe)
+ case r: DefRegister => Some(r.name -> r.tpe)
+ case n: DefNode => Some(n.name -> n.value.tpe)
case i: WDefInstance => Some(i.name -> i.tpe)
- case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
+ case m: DefMemory => Some(m.name -> MemPortUtils.memType(m))
case other => None
}
- pairOpt.foreach { case (ref, tpe) =>
- typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
+ pairOpt.foreach {
+ case (ref, tpe) =>
+ typeMap += (ReferenceTarget(circuitName, modName, Nil, ref, Nil) -> tpe)
}
stmt.foreachStmt(getDeclTypes(modName))
}
@@ -223,14 +247,20 @@ class InferWidths extends Transform
}
state.annotations.foreach {
- case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
- val locType :: expType :: Nil = Seq(anno.loc, anno.exp) map { target =>
- val baseType = typeMap.getOrElse(target.copy(component = Seq.empty),
- throw new Exception(s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()))
+ case anno: WidthGeqConstraintAnnotation if anno.loc.isLocal && anno.exp.isLocal =>
+ val locType :: expType :: Nil = Seq(anno.loc, anno.exp).map { target =>
+ val baseType = typeMap.getOrElse(
+ target.copy(component = Seq.empty),
+ throw new Exception(
+ s"Target below from WidthGeqConstraintAnnotation was not found\n" + target.prettyPrint()
+ )
+ )
val leafType = target.componentType(baseType)
if (leafType.isInstanceOf[AggregateType]) {
- throw new Exception(s"Target below is an AggregateType, which " +
- "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint())
+ throw new Exception(
+ s"Target below is an AggregateType, which " +
+ "is not supported by WidthGeqConstraintAnnotation\n" + target.prettyPrint()
+ )
}
leafType
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index ad963b19..316878fb 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -32,89 +32,100 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
override def invalidates(a: Transform): Boolean = a == ResolveKinds
- private [firrtl] val inlineDelim: String = "_"
+ private[firrtl] val inlineDelim: String = "_"
val options = Seq(
new ShellOption[Seq[String]](
longOption = "inline",
- toAnnotationSeq = (a: Seq[String]) => a.map { value =>
- value.split('.') match {
- case Array(circuit) =>
- InlineAnnotation(CircuitName(circuit))
- case Array(circuit, module) =>
- InlineAnnotation(ModuleName(module, CircuitName(circuit)))
- case Array(circuit, module, inst) =>
- InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))))
- }
- } :+ RunFirrtlTransformAnnotation(new InlineInstances),
+ toAnnotationSeq = (a: Seq[String]) =>
+ a.map { value =>
+ value.split('.') match {
+ case Array(circuit) =>
+ InlineAnnotation(CircuitName(circuit))
+ case Array(circuit, module) =>
+ InlineAnnotation(ModuleName(module, CircuitName(circuit)))
+ case Array(circuit, module, inst) =>
+ InlineAnnotation(ComponentName(inst, ModuleName(module, CircuitName(circuit))))
+ }
+ } :+ RunFirrtlTransformAnnotation(new InlineInstances),
helpText = "Inline selected modules",
shortOption = Some("fil"),
- helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]") ) )
-
- private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
- anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) {
- case ((modNames, instNames), ann) => ann match {
- case InlineAnnotation(CircuitName(c)) =>
- (circuit.modules.collect {
- case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
- }.toSet, instNames)
- case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
- case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
- case _ => (modNames, instNames)
- }
- }
-
- def execute(state: CircuitState): CircuitState = {
- // TODO Add error check for more than one annotation for inlining
- val (modNames, instNames) = collectAnns(state.circuit, state.annotations)
- if (modNames.nonEmpty || instNames.nonEmpty) {
- run(state.circuit, modNames, instNames, state.annotations)
- } else {
- state
- }
- }
-
- // Checks the following properties:
- // 1) All annotated modules exist
- // 2) All annotated modules are InModules (can be inlined)
- // 3) All annotated instances exist, and their modules can be inline
- def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = {
- val errors = mutable.ArrayBuffer[PassException]()
- val moduleMap = InstanceKeyGraph(c).moduleMap
- def checkExists(name: String): Unit =
- if (!moduleMap.contains(name))
- errors += new PassException(s"Annotated module does not exist: $name")
- def checkExternal(name: String): Unit = moduleMap(name) match {
- case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name")
- case _ =>
- }
- def checkInstance(cn: ComponentName): Unit = {
- var containsCN = false
- def onStmt(name: String)(s: Statement): Statement = {
- s match {
- case WDefInstance(_, inst_name, module_name, tpe) =>
- if (name == inst_name) {
- containsCN = true
- checkExternal(module_name)
- }
- case _ =>
+ helpValueName = Some("<circuit>[.<module>[.<instance>]][,...]")
+ )
+ )
+
+ private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
+ anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) {
+ case ((modNames, instNames), ann) =>
+ ann match {
+ case InlineAnnotation(CircuitName(c)) =>
+ (
+ circuit.modules.collect {
+ case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
+ }.toSet,
+ instNames
+ )
+ case InlineAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
+ case InlineAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
+ case _ => (modNames, instNames)
+ }
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ // TODO Add error check for more than one annotation for inlining
+ val (modNames, instNames) = collectAnns(state.circuit, state.annotations)
+ if (modNames.nonEmpty || instNames.nonEmpty) {
+ run(state.circuit, modNames, instNames, state.annotations)
+ } else {
+ state
+ }
+ }
+
+ // Checks the following properties:
+ // 1) All annotated modules exist
+ // 2) All annotated modules are InModules (can be inlined)
+ // 3) All annotated instances exist, and their modules can be inline
+ def check(c: Circuit, moduleNames: Set[ModuleName], instanceNames: Set[ComponentName]): Unit = {
+ val errors = mutable.ArrayBuffer[PassException]()
+ val moduleMap = InstanceKeyGraph(c).moduleMap
+ def checkExists(name: String): Unit =
+ if (!moduleMap.contains(name))
+ errors += new PassException(s"Annotated module does not exist: $name")
+ def checkExternal(name: String): Unit = moduleMap(name) match {
+ case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: $name")
+ case _ =>
+ }
+ def checkInstance(cn: ComponentName): Unit = {
+ var containsCN = false
+ def onStmt(name: String)(s: Statement): Statement = {
+ s match {
+ case WDefInstance(_, inst_name, module_name, tpe) =>
+ if (name == inst_name) {
+ containsCN = true
+ checkExternal(module_name)
}
- s map onStmt(name)
- }
- onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body)
- if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}")
+ case _ =>
+ }
+ s.map(onStmt(name))
}
+ onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body)
+ if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}")
+ }
- moduleNames.foreach{mn => checkExists(mn.name)}
- if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
- moduleNames.foreach{mn => checkExternal(mn.name)}
- if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
- instanceNames.foreach{cn => checkInstance(cn)}
- if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
- }
-
+ moduleNames.foreach { mn => checkExists(mn.name) }
+ if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
+ moduleNames.foreach { mn => checkExternal(mn.name) }
+ if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
+ instanceNames.foreach { cn => checkInstance(cn) }
+ if (errors.nonEmpty) throw new PassExceptions(errors.toSeq)
+ }
- def run(c: Circuit, modsToInline: Set[ModuleName], instsToInline: Set[ComponentName], annos: AnnotationSeq): CircuitState = {
+ def run(
+ c: Circuit,
+ modsToInline: Set[ModuleName],
+ instsToInline: Set[ComponentName],
+ annos: AnnotationSeq
+ ): CircuitState = {
def getInstancesOf(c: Circuit, modules: Set[String]): Set[(OfModule, Instance)] =
c.modules.foldLeft(Set[(OfModule, Instance)]()) { (set, d) =>
d match {
@@ -125,7 +136,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
case WDefInstance(info, instName, moduleName, instTpe) if modules.contains(moduleName) =>
instances += (OfModule(m.name) -> Instance(instName))
s
- case sx => sx map findInstances
+ case sx => sx.map(findInstances)
}
findInstances(m.body)
instances.toSet ++ set
@@ -135,7 +146,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
// Check annotations and circuit match up
check(c, modsToInline, instsToInline)
val flatModules = modsToInline.map(m => m.name)
- val flatInstances: Set[(OfModule, Instance)] = instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules)
+ val flatInstances: Set[(OfModule, Instance)] =
+ instsToInline.map(i => OfModule(i.module.name) -> Instance(i.name)) ++ getInstancesOf(c, flatModules)
val iGraph = InstanceKeyGraph(c)
val namespaceMap = collection.mutable.Map[String, Namespace]()
// Map of Module name to Map of instance name to Module name
@@ -144,11 +156,13 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
/** Add a prefix to all declarations updating a [[Namespace]] and appending to a [[RenameMap]] */
def appendNamePrefix(
currentModule: IsModule,
- nextModule: IsModule,
- prefix: String,
- ns: Namespace,
- renames: mutable.HashMap[String, String],
- renameMap: RenameMap)(s: Statement): Statement = {
+ nextModule: IsModule,
+ prefix: String,
+ ns: Namespace,
+ renames: mutable.HashMap[String, String],
+ renameMap: RenameMap
+ )(s: Statement
+ ): Statement = {
def onName(ofModuleOpt: Option[String])(name: String) = {
if (prefix.nonEmpty && !ns.tryName(prefix + name)) {
throw new Exception(s"Inlining failed. Inlined name '${prefix + name}' already exists")
@@ -164,25 +178,29 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
}
s match {
- case s: WDefInstance => s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
- case other => s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
+ case s: WDefInstance =>
+ s.map(onName(Some(s.module))).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
+ case other =>
+ s.map(onName(None)).map(appendNamePrefix(currentModule, nextModule, prefix, ns, renames, renameMap))
}
}
/** Modify all references */
def appendRefPrefix(
currentModule: IsModule,
- renames: mutable.HashMap[String, String])(s: Statement): Statement = {
- def onExpr(e: Expression): Expression = e match {
- case wr@ WRef(name, _, _, _) =>
- renames.get(name) match {
- case Some(prefixedName) => wr.copy(name = prefixedName)
- case None => wr
- }
- case ex => ex.map(onExpr)
- }
- s.map(onExpr).map(appendRefPrefix(currentModule, renames))
+ renames: mutable.HashMap[String, String]
+ )(s: Statement
+ ): Statement = {
+ def onExpr(e: Expression): Expression = e match {
+ case wr @ WRef(name, _, _, _) =>
+ renames.get(name) match {
+ case Some(prefixedName) => wr.copy(name = prefixedName)
+ case None => wr
+ }
+ case ex => ex.map(onExpr)
}
+ s.map(onExpr).map(appendRefPrefix(currentModule, renames))
+ }
val cache = mutable.HashMap.empty[ModuleTarget, Statement]
@@ -194,16 +212,19 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
val (renamesMap, renamesSeq) = {
val mutableDiGraph = new MutableDiGraph[(OfModule, Instance)]
// compute instance graph
- instMaps.foreach { case (grandParentOfMod, parents) =>
- parents.foreach { case (parentInst, parentOfMod) =>
- val from = grandParentOfMod -> parentInst
- mutableDiGraph.addVertex(from)
- instMaps(parentOfMod).foreach { case (childInst, _) =>
- val to = parentOfMod -> childInst
- mutableDiGraph.addVertex(to)
- mutableDiGraph.addEdge(from, to)
+ instMaps.foreach {
+ case (grandParentOfMod, parents) =>
+ parents.foreach {
+ case (parentInst, parentOfMod) =>
+ val from = grandParentOfMod -> parentInst
+ mutableDiGraph.addVertex(from)
+ instMaps(parentOfMod).foreach {
+ case (childInst, _) =>
+ val to = parentOfMod -> childInst
+ mutableDiGraph.addVertex(to)
+ mutableDiGraph.addEdge(from, to)
+ }
}
- }
}
val diGraph = DiGraph(mutableDiGraph)
@@ -226,10 +247,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
}
def fixupRefs(
- instMap: collection.Map[Instance, OfModule],
- currentModule: IsModule)(e: Expression): Expression = {
+ instMap: collection.Map[Instance, OfModule],
+ currentModule: IsModule
+ )(e: Expression
+ ): Expression = {
e match {
- case wsf@ WSubField(wr@ WRef(ref, _, InstanceKind, _), field, tpe, gen) =>
+ case wsf @ WSubField(wr @ WRef(ref, _, InstanceKind, _), field, tpe, gen) =>
val inst = currentModule.instOf(ref, instMap(Instance(ref)).value)
val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance))
val port = inst.ref(field)
@@ -242,12 +265,12 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
}
case None => wsf
}
- case wr@ WRef(name, _, InstanceKind, _) =>
+ case wr @ WRef(name, _, InstanceKind, _) =>
val inst = currentModule.instOf(name, instMap(Instance(name)).value)
val renamesOpt = renamesMap.get(OfModule(currentModule.module) -> Instance(inst.instance))
val comp = currentModule.ref(name)
renamesOpt.flatMap(_.get(comp)).getOrElse(Seq(comp)) match {
- case Seq(car: ReferenceTarget) => wr.copy(name=car.ref)
+ case Seq(car: ReferenceTarget) => wr.copy(name = car.ref)
}
case ex => ex.map(fixupRefs(instMap, currentModule))
}
@@ -258,7 +281,8 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
val ns = namespaceMap.getOrElseUpdate(currentModuleName, Namespace(iGraph.moduleMap(currentModuleName)))
val instMap = instMaps(OfModule(currentModuleName))
s match {
- case wDef@ WDefInstance(_, instName, modName, _) if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) =>
+ case wDef @ WDefInstance(_, instName, modName, _)
+ if flatInstances.contains(OfModule(currentModuleName) -> Instance(instName)) =>
val renames = renamesMap(OfModule(currentModuleName) -> Instance(instName))
val toInline = iGraph.moduleMap(modName) match {
case m: ExtModule => throw new PassException(s"Cannot inline external module ${m.name}")
@@ -269,7 +293,7 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
val bodyx = {
val module = currentModule.copy(module = modName)
- cache.getOrElseUpdate(module, Block(ports :+ toInline.body) map onStmt(module))
+ cache.getOrElseUpdate(module, Block(ports :+ toInline.body).map(onStmt(module)))
}
val names = "" +: Uniquify
@@ -294,14 +318,14 @@ class InlineInstances extends Transform with DependencyAPIMigration with Registe
renamedBody
case sx =>
sx
- .map(fixupRefs(instMap, currentModule))
- .map(onStmt(currentModule))
+ .map(fixupRefs(instMap, currentModule))
+ .map(onStmt(currentModule))
}
}
val flatCircuit = c.copy(modules = c.modules.flatMap {
case m if flatModules.contains(m.name) => None
- case m =>
+ case m =>
Some(m.map(onStmt(ModuleName(m.name, CircuitName(c.main)))))
})
diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala
index 8b7b733a..5d59e075 100644
--- a/src/main/scala/firrtl/passes/Legalize.scala
+++ b/src/main/scala/firrtl/passes/Legalize.scala
@@ -1,11 +1,11 @@
package firrtl.passes
import firrtl.PrimOps._
-import firrtl.Utils.{BoolType, error, zero}
+import firrtl.Utils.{error, zero, BoolType}
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.transforms.ConstantPropagation
-import firrtl.{Transform, bitWidth}
+import firrtl.{bitWidth, Transform}
import firrtl.Mappers._
// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
@@ -62,30 +62,31 @@ object Legalize extends Pass {
} else {
val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w)))
val expr = t match {
- case UIntType(_) => bits
- case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
+ case UIntType(_) => bits
+ case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t)
}
Connect(c.info, c.loc, expr)
}
}
- def run (c: Circuit): Circuit = {
- def legalizeE(expr: Expression): Expression = expr map legalizeE match {
- case prim: DoPrim => prim.op match {
- case Shr => legalizeShiftRight(prim)
- case Pad => legalizePad(prim)
- case Bits | Head | Tail => legalizeBitExtract(prim)
- case _ => prim
- }
+ def run(c: Circuit): Circuit = {
+ def legalizeE(expr: Expression): Expression = expr.map(legalizeE) match {
+ case prim: DoPrim =>
+ prim.op match {
+ case Shr => legalizeShiftRight(prim)
+ case Pad => legalizePad(prim)
+ case Bits | Head | Tail => legalizeBitExtract(prim)
+ case _ => prim
+ }
case e => e // respect pre-order traversal
}
- def legalizeS (s: Statement): Statement = {
+ def legalizeS(s: Statement): Statement = {
val legalizedStmt = s match {
case c: Connect => legalizeConnect(c)
case _ => s
}
- legalizedStmt map legalizeS map legalizeE
+ legalizedStmt.map(legalizeS).map(legalizeE)
}
- c copy (modules = c.modules map (_ map legalizeS))
+ c.copy(modules = c.modules.map(_.map(legalizeS)))
}
}
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index ace4f3e8..ad608cec 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -3,8 +3,26 @@
package firrtl.passes
import firrtl.analyses.{InstanceKeyGraph, SymbolTable}
-import firrtl.annotations.{CircuitTarget, MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, ReferenceTarget}
-import firrtl.{CircuitForm, CircuitState, DependencyAPIMigration, InstanceKind, Kind, MemKind, PortKind, RenameMap, Transform, UnknownForm, Utils}
+import firrtl.annotations.{
+ CircuitTarget,
+ MemoryInitAnnotation,
+ MemoryRandomInitAnnotation,
+ ModuleTarget,
+ ReferenceTarget
+}
+import firrtl.{
+ CircuitForm,
+ CircuitState,
+ DependencyAPIMigration,
+ InstanceKind,
+ Kind,
+ MemKind,
+ PortKind,
+ RenameMap,
+ Transform,
+ UnknownForm,
+ Utils
+}
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.stage.TransformManager.TransformDependency
@@ -20,18 +38,19 @@ import scala.collection.mutable
object LowerTypes extends Transform with DependencyAPIMigration {
override def prerequisites: Seq[TransformDependency] = Seq(
Dependency(RemoveAccesses), // we require all SubAccess nodes to have been removed
- Dependency(CheckTypes), // we require all types to be correct
- Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType)
- Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded
+ Dependency(CheckTypes), // we require all types to be correct
+ Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType)
+ Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded
)
- override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty
+ override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty
override def invalidates(a: Transform): Boolean = a match {
case ResolveFlows => true // we generate UnknownFlow for now (could be fixed)
- case _ => false
+ case _ => false
}
/** Delimiter used in lowering names */
val delim = "_"
+
/** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name
* @param e [[firrtl.ir.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]]
* @return Lowered name of e
@@ -39,8 +58,8 @@ object LowerTypes extends Transform with DependencyAPIMigration {
*/
def loweredName(e: Expression): String = e match {
case e: Reference => e.name
- case e: SubField => s"${loweredName(e.expr)}$delim${e.name}"
- case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}"
+ case e: SubField => s"${loweredName(e.expr)}$delim${e.name}"
+ case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}"
}
def loweredName(s: Seq[String]): String = s.mkString(delim)
@@ -48,7 +67,7 @@ object LowerTypes extends Transform with DependencyAPIMigration {
// When memories are lowered to ground type, we have to fix the init annotation or error on it.
val (memInitAnnos, otherAnnos) = state.annotations.partition {
case _: MemoryRandomInitAnnotation => false
- case _: MemoryInitAnnotation => true
+ case _: MemoryInitAnnotation => true
case _ => false
}
val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule)
@@ -61,14 +80,18 @@ object LowerTypes extends Transform with DependencyAPIMigration {
val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3)
// chain module renames in topological order
- val moduleRenames = resultAndRenames.map{ case(m,r, _) => m.name -> r }.toMap
+ val moduleRenames = resultAndRenames.map { case (m, r, _) => m.name -> r }.toMap
val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator
- val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a,b) => a.andThen(b))
+ val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a, b) => a.andThen(b))
state.copy(circuit = result, renames = Some(renames), annotations = newAnnos)
}
- private def onModule(c: CircuitTarget, m: DefModule, memoryInit: Seq[MemoryInitAnnotation]): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = {
+ private def onModule(
+ c: CircuitTarget,
+ m: DefModule,
+ memoryInit: Seq[MemoryInitAnnotation]
+ ): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = {
val renameMap = RenameMap()
val ref = c.module(m.name)
@@ -86,26 +109,36 @@ object LowerTypes extends Transform with DependencyAPIMigration {
}
// We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names.
- private def lowerPorts(ref: ModuleTarget, m: DefModule, renameMap: RenameMap):
- (DefModule, Seq[(String, Seq[Reference])]) = {
+ private def lowerPorts(
+ ref: ModuleTarget,
+ m: DefModule,
+ renameMap: RenameMap
+ ): (DefModule, Seq[(String, Seq[Reference])]) = {
val namespace = mutable.HashSet[String]() ++ m.ports.map(_.name)
val loweredPortsAndRefs = m.ports.flatMap { p =>
- val fieldsAndRefs = DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set())
- fieldsAndRefs.map { case (f, ref) =>
- (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind)))
+ val fieldsAndRefs =
+ DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set())
+ fieldsAndRefs.map {
+ case (f, ref) =>
+ (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind)))
}
}
val newM = m match {
- case e : ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1))
- case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1))
+ case e: ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1))
+ case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1))
}
(newM, loweredPortsAndRefs.map(_._2))
}
- private def onStatement(s: Statement)(implicit symbols: LoweringTable, memInit: Seq[MemoryInitAnnotation]): Statement = s match {
+ private def onStatement(
+ s: Statement
+ )(
+ implicit symbols: LoweringTable,
+ memInit: Seq[MemoryInitAnnotation]
+ ): Statement = s match {
// declarations
- case d : DefWire =>
- Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name=name, tpe=tpe) })
+ case d: DefWire =>
+ Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name = name, tpe = tpe) })
case d @ DefRegister(info, _, _, clock, reset, _) =>
// clock and reset are always of ground type
val loweredClock = onExpression(clock)
@@ -113,41 +146,41 @@ object LowerTypes extends Transform with DependencyAPIMigration {
// It is important to first lower the declaration, because the reset can refer to the register itself!
val loweredRegs = symbols.lower(d.name, d.tpe, firrtl.RegKind)
val inits = Utils.create_exps(d.init).map(onExpression)
- Block(
- loweredRegs.zip(inits).map { case ((name, tpe, _), init) =>
+ Block(loweredRegs.zip(inits).map {
+ case ((name, tpe, _), init) =>
DefRegister(info, name, tpe, loweredClock, loweredReset, init)
})
- case d : DefNode =>
+ case d: DefNode =>
val values = Utils.create_exps(d.value).map(onExpression)
- Block(
- symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map{ case((name, tpe, _), value) =>
+ Block(symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map {
+ case ((name, tpe, _), value) =>
assert(tpe == value.tpe)
DefNode(d.info, name, value)
})
- case d : DefMemory =>
+ case d: DefMemory =>
// TODO: as an optimization, we could just skip ground type memories here.
// This would require that we don't error in getReferences() but instead return the old reference.
val mems = symbols.lower(d)
- if(mems.length > 1 && memInit.exists(_.target.ref == d.name)) {
+ if (mems.length > 1 && memInit.exists(_.target.ref == d.name)) {
val mod = memInit.find(_.target.ref == d.name).get.target.encapsulatingModule
val msg = s"[module $mod] Cannot initialize memory ${d.name} of non ground type ${d.dataType.serialize}"
throw new RuntimeException(msg)
}
Block(mems)
- case d : DefInstance => symbols.lower(d)
+ case d: DefInstance => symbols.lower(d)
// connections
case Connect(info, loc, expr) =>
- if(!expr.tpe.isInstanceOf[GroundType]) {
+ if (!expr.tpe.isInstanceOf[GroundType]) {
throw new RuntimeException(s"LowerTypes expects Connects to have been expanded! ${expr.tpe.serialize}")
}
val rhs = onExpression(expr)
// We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated.
val lhs = symbols.getReferences(loc.asInstanceOf[RefLikeExpression])
Block(lhs.map(loc => Connect(info, loc, rhs)))
- case p : PartialConnect =>
+ case p: PartialConnect =>
throw new RuntimeException(s"LowerTypes expects PartialConnects to be resolved! $p")
case IsInvalid(info, expr) =>
- if(!expr.tpe.isInstanceOf[GroundType]) {
+ if (!expr.tpe.isInstanceOf[GroundType]) {
throw new RuntimeException(s"LowerTypes expects IsInvalids to have been expanded! ${expr.tpe.serialize}")
}
// We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated.
@@ -172,15 +205,18 @@ object LowerTypes extends Transform with DependencyAPIMigration {
// Holds the first level of the module-level namespace.
// (i.e. everything that can be addressed directly by a Reference node)
private class LoweringSymbolTable extends SymbolTable {
- def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name)
+ def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name)
def declareInstance(name: String, module: String): Unit = symbols.append(name)
private val symbols = mutable.ArrayBuffer[String]()
def getSymbolNames: Iterable[String] = symbols
}
// Lowers types and keeps track of references to lowered types.
-private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: ModuleTarget,
- portNameToExprs: Seq[(String, Seq[Reference])]) {
+private class LoweringTable(
+ table: LoweringSymbolTable,
+ renameMap: RenameMap,
+ m: ModuleTarget,
+ portNameToExprs: Seq[(String, Seq[Reference])]) {
private val portNames: Set[String] = portNameToExprs.map(_._2.head.name).toSet
private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames
// Serialized old access string to new ground type reference.
@@ -196,10 +232,11 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m:
nameToExprs ++= refs.map { case (name, r) => name -> List(r) }
newInst
}
+
/** used to lower nodes, registers and wires */
def lower(name: String, tpe: Type, kind: Kind, flip: Orientation = Default): Seq[(String, Type, Orientation)] = {
val fieldsAndRefs = DestructTypes.destruct(m, Field(name, flip, tpe), namespace, renameMap, portNames)
- nameToExprs ++= fieldsAndRefs.map{ case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) }
+ nameToExprs ++= fieldsAndRefs.map { case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) }
fieldsAndRefs.map { case (f, _) => (f.name, f.tpe, f.flip) }
}
def lower(p: Port): Seq[Port] = {
@@ -211,10 +248,10 @@ private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m:
// We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left.
private def serialize(expr: RefLikeExpression): String = expr match {
- case Reference(name, _, _, _) => name
- case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name
+ case Reference(name, _, _, _) => name
+ case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name
case SubIndex(expr, index, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "[" + index.toString + "]"
- case a : SubAccess =>
+ case a: SubAccess =>
throw new RuntimeException(s"LowerTypes expects all SubAccesses to have been expanded! ${a.serialize}")
}
}
@@ -230,13 +267,18 @@ private object DestructTypes {
* - generates a list of all old reference name that now refer to the particular ground type field
* - updates namespace with all possibly conflicting names
*/
- def destruct(m: ModuleTarget, ref: Field, namespace: Namespace, renameMap: RenameMap, reserved: Set[String]):
- Seq[(Field, String)] = {
+ def destruct(
+ m: ModuleTarget,
+ ref: Field,
+ namespace: Namespace,
+ renameMap: RenameMap,
+ reserved: Set[String]
+ ): Seq[(Field, String)] = {
// field renames (uniquify) are computed bottom up
val (rename, _) = uniquify(ref, namespace, reserved)
// early exit for ground types that do not need renaming
- if(ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) {
+ if (ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) {
return List((ref, ref.name))
}
@@ -253,8 +295,13 @@ private object DestructTypes {
* Note that the list of fields is only of the child fields, and needs a SubField node
* instead of a flat Reference when turning them into access expressions.
*/
- def destructInstance(m: ModuleTarget, instance: DefInstance, namespace: Namespace, renameMap: RenameMap,
- reserved: Set[String]): (DefInstance, Seq[(String, SubField)]) = {
+ def destructInstance(
+ m: ModuleTarget,
+ instance: DefInstance,
+ namespace: Namespace,
+ renameMap: RenameMap,
+ reserved: Set[String]
+ ): (DefInstance, Seq[(String, SubField)]) = {
val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved)
val newName = rename.map(_.name).getOrElse(instance.name)
@@ -266,14 +313,14 @@ private object DestructTypes {
}
// rename all references to the instance if necessary
- if(newName != instance.name) {
+ if (newName != instance.name) {
renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module))
}
// The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered.
val newInstance = instance.copy(name = newName, tpe = BundleType(children.map(_._1)))
val instanceRef = Reference(newName, newInstance.tpe, InstanceKind)
- val refs = children.map{ case(c,r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) }
+ val refs = children.map { case (c, r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) }
(newInstance, refs)
}
@@ -285,8 +332,13 @@ private object DestructTypes {
* e.g. ("mem_a.r.clk", "mem.r.clk") and ("mem_b.r.clk", "mem.r.clk")
* Thus it is appropriate to groupBy old reference string instead of just inserting into a hash table.
*/
- def destructMemory(m: ModuleTarget, mem: DefMemory, namespace: Namespace, renameMap: RenameMap,
- reserved: Set[String]): (Seq[DefMemory], Seq[(String, SubField)]) = {
+ def destructMemory(
+ m: ModuleTarget,
+ mem: DefMemory,
+ namespace: Namespace,
+ renameMap: RenameMap,
+ reserved: Set[String]
+ ): (Seq[DefMemory], Seq[(String, SubField)]) = {
// Uniquify the lowered memory names: When memories get split up into ground types, the access order is changes.
// E.g. `mem.r.data.x` becomes `mem_x.r.data`.
// This is why we need to create the new bundle structure before we can resolve any name clashes.
@@ -301,48 +353,50 @@ private object DestructTypes {
// the "old dummy field" is used as a template for the new memory port types
val oldDummyField = Field("dummy", Default, MemPortUtils.memType(mem.copy(dataType = BoolType)))
- val newMemAndSubFields = res.map { case (field, refs) =>
- val newMem = mem.copy(name = field.name, dataType = field.tpe)
- val newMemRef = m.ref(field.name)
- val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType
- if(memWasRenamed) { renameMap.record(oldMemRef, newMemRef) }
-
- val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind)
- val refSuffixes = refs.map(_.component).filterNot(_.isEmpty)
-
- val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port =>
- val oldPortRef = oldMemRef.field(port.name)
- val newPortRef = newMemRef.field(port.name)
-
- val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe
- val newPortAccess = SubField(newMemReference, port.name, newPortType)
-
- port.tpe.asInstanceOf[BundleType].fields.map { portField =>
- val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata"
- val isMaskField = portField.name == "mask" || portField.name == "wmask"
- val isDataOrMaskField = isDataField || isMaskField
- val oldFieldRefs = if(memWasRenamed && isDataOrMaskField) {
- // there might have been multiple different fields which now alias to the same lowered field.
- val oldPortFieldBaseRef = oldPortRef.field(portField.name)
- refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s))
- } else {
- List(oldPortRef.field(portField.name))
+ val newMemAndSubFields = res.map {
+ case (field, refs) =>
+ val newMem = mem.copy(name = field.name, dataType = field.tpe)
+ val newMemRef = m.ref(field.name)
+ val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType
+ if (memWasRenamed) { renameMap.record(oldMemRef, newMemRef) }
+
+ val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind)
+ val refSuffixes = refs.map(_.component).filterNot(_.isEmpty)
+
+ val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port =>
+ val oldPortRef = oldMemRef.field(port.name)
+ val newPortRef = newMemRef.field(port.name)
+
+ val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe
+ val newPortAccess = SubField(newMemReference, port.name, newPortType)
+
+ port.tpe.asInstanceOf[BundleType].fields.map { portField =>
+ val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata"
+ val isMaskField = portField.name == "mask" || portField.name == "wmask"
+ val isDataOrMaskField = isDataField || isMaskField
+ val oldFieldRefs = if (memWasRenamed && isDataOrMaskField) {
+ // there might have been multiple different fields which now alias to the same lowered field.
+ val oldPortFieldBaseRef = oldPortRef.field(portField.name)
+ refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s))
+ } else {
+ List(oldPortRef.field(portField.name))
+ }
+
+ val newPortType = if (isDataField) { newMem.dataType }
+ else { portField.tpe }
+ val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType)
+
+ // record renames only for the data field which is the only port field of non-ground type
+ val newPortFieldRef = newPortRef.field(portField.name)
+ if (memWasRenamed && isDataOrMaskField) {
+ oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) }
+ }
+
+ val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs)
+ (oldFieldStringRef, newPortFieldAccess)
}
-
- val newPortType = if(isDataField) { newMem.dataType } else { portField.tpe }
- val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType)
-
- // record renames only for the data field which is the only port field of non-ground type
- val newPortFieldRef = newPortRef.field(portField.name)
- if(memWasRenamed && isDataOrMaskField) {
- oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) }
- }
-
- val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs)
- (oldFieldStringRef, newPortFieldAccess)
}
- }
- (newMem, subFields)
+ (newMem, subFields)
}
(newMemAndSubFields.map(_._1), newMemAndSubFields.flatMap(_._2))
@@ -356,22 +410,30 @@ private object DestructTypes {
Field(mem.name, Default, BundleType(fields))
}
- private def recordRenames(fieldToRefs: Seq[(Field, Seq[ReferenceTarget])], renameMap: RenameMap, parent: ParentRef):
- Unit = {
+ private def recordRenames(
+ fieldToRefs: Seq[(Field, Seq[ReferenceTarget])],
+ renameMap: RenameMap,
+ parent: ParentRef
+ ): Unit = {
// TODO: if we group by ReferenceTarget, we could reduce the number of calls to `record`. Is it worth it?
- fieldToRefs.foreach { case(field, refs) =>
- val fieldRef = parent.ref(field.name)
- refs.foreach{ r => renameMap.record(r, fieldRef) }
+ fieldToRefs.foreach {
+ case (field, refs) =>
+ val fieldRef = parent.ref(field.name)
+ refs.foreach { r => renameMap.record(r, fieldRef) }
}
}
private def extractGroundTypeRefString(refs: Seq[ReferenceTarget]): String = {
- if (refs.isEmpty) { "" } else {
+ if (refs.isEmpty) { "" }
+ else {
// Since we depend on ExpandConnects any reference we encounter will be of ground type
// and thus the one with the longest access path.
- refs.reduceLeft((x, y) => if (x.component.length > y.component.length) x else y)
+ refs
+ .reduceLeft((x, y) => if (x.component.length > y.component.length) x else y)
// convert references to strings relative to the module
- .serialize.dropWhile(_ != '>').tail
+ .serialize
+ .dropWhile(_ != '>')
+ .tail
}
}
@@ -385,14 +447,19 @@ private object DestructTypes {
* @return a sequence of ground type fields with new names and, for each field,
* a sequence of old references that should to be renamed to point to the particular field
*/
- private def destruct(prefix: String, oldParent: ParentRef, oldField: Field,
- isVecField: Boolean, rename: Option[RenameNode]): Seq[(Field, Seq[ReferenceTarget])] = {
+ private def destruct(
+ prefix: String,
+ oldParent: ParentRef,
+ oldField: Field,
+ isVecField: Boolean,
+ rename: Option[RenameNode]
+ ): Seq[(Field, Seq[ReferenceTarget])] = {
val newName = rename.map(_.name).getOrElse(oldField.name)
val oldRef = oldParent.ref(oldField.name, isVecField)
oldField.tpe match {
- case _ : GroundType => List((oldField.copy(name = prefix + newName), List(oldRef)))
- case _ : BundleType | _ : VectorType =>
+ case _: GroundType => List((oldField.copy(name = prefix + newName), List(oldRef)))
+ case _: BundleType | _: VectorType =>
val newPrefix = prefix + newName + LowerTypes.delim
val isVecField = oldField.tpe.isInstanceOf[VectorType]
val fields = getFields(oldField.tpe)
@@ -401,7 +468,7 @@ private object DestructTypes {
destruct(newPrefix, RefParentRef(oldRef), f, isVecField, rename.flatMap(_.children.get(f.name)))
}
// the bundle/vec reference refers to all children
- children.map{ case(c, r) => (c, r :+ oldRef) }
+ children.map { case (c, r) => (c, r :+ oldRef) }
}
}
@@ -409,7 +476,8 @@ private object DestructTypes {
/** Implements the core functionality of the old Uniquify pass: rename bundle fields and top-level references
* where necessary in order to avoid name clashes when lowering aggregate type with the `_` delimiter.
- * We don't actually do the rename here but just calculate a rename tree. */
+ * We don't actually do the rename here but just calculate a rename tree.
+ */
private def uniquify(ref: Field, namespace: Namespace, reserved: Set[String]): (Option[RenameNode], Seq[String]) = {
// ensure that there are no name clashes with the list of reserved (port) names
val newRefName = findValidPrefix(ref.name, reserved.contains)
@@ -426,23 +494,23 @@ private object DestructTypes {
// We added f.name in previous map, delete if we change it
val renamed = prefix != ref.name
if (renamed) {
- if(!reserved.contains(ref.name)) namespace -= ref.name
+ if (!reserved.contains(ref.name)) namespace -= ref.name
namespace += prefix
}
val suffixes = renamedFieldNames.map(f => prefix + LowerTypes.delim + f)
val anyChildRenamed = renamedFields.exists(_._1.isDefined)
- val rename = if(renamed || anyChildRenamed){
- val children = renamedFields.map(_._1).zip(fields).collect{ case (Some(r), f) => f.name -> r }.toMap
+ val rename = if (renamed || anyChildRenamed) {
+ val children = renamedFields.map(_._1).zip(fields).collect { case (Some(r), f) => f.name -> r }.toMap
Some(RenameNode(prefix, children))
} else { None }
(rename, suffixes :+ prefix)
- case v : VectorType=>
+ case v: VectorType =>
// if Vecs are to be lowered, we can just treat them like a bundle
uniquify(ref.copy(tpe = vecToBundle(v)), namespace, reserved)
- case _ : GroundType =>
- if(newRefName == ref.name) {
+ case _: GroundType =>
+ if (newRefName == ref.name) {
(None, List(ref.name))
} else {
(Some(RenameNode(newRefName, Map())), List(newRefName))
@@ -452,22 +520,23 @@ private object DestructTypes {
}
/** Appends delim to prefix until no collisions of prefix + elts in names We don't add an _ in the collision check
- * because elts could be Seq("") In this case, we're just really checking if prefix itself collides */
+ * because elts could be Seq("") In this case, we're just really checking if prefix itself collides
+ */
@tailrec
private def findValidPrefix(prefix: String, inNamespace: String => Boolean, elts: Seq[String] = List("")): String = {
elts.find(elt => inNamespace(prefix + elt)) match {
case Some(_) => findValidPrefix(prefix + "_", inNamespace, elts)
- case None => prefix
+ case None => prefix
}
}
private def getFields(tpe: Type): Seq[Field] = tpe match {
case BundleType(fields) => fields
- case v : VectorType => vecToBundle(v).fields
+ case v: VectorType => vecToBundle(v).fields
}
private def vecToBundle(v: VectorType): BundleType = {
- BundleType(( 0 until v.size).map(i => Field(i.toString, Default, v.tpe)))
+ BundleType((0 until v.size).map(i => Field(i.toString, Default, v.tpe)))
}
/** Used to abstract over module and reference parents.
@@ -480,6 +549,7 @@ private object DestructTypes {
}
private case class RefParentRef(r: ReferenceTarget) extends ParentRef {
override def ref(name: String, asVecField: Boolean): ReferenceTarget =
- if(asVecField) { r.index(name.toInt) } else { r.field(name) }
+ if (asVecField) { r.index(name.toInt) }
+ else { r.field(name) }
}
}
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index ca5c2544..79560605 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -15,23 +15,21 @@ object PadWidths extends Pass {
override def prerequisites =
((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
+ ++ firrtl.stage.Forms.LowForm
+ - Dependency(firrtl.passes.Legalize)
+ + Dependency(firrtl.passes.RemoveValidIf)).toSeq
override def optionalPrerequisites = Seq(Dependency[firrtl.transforms.ConstantPropagation])
override def optionalPrerequisiteOf =
- Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency(firrtl.passes.memlib.VerilogMemDelays), Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
case _: firrtl.transforms.ConstantPropagation | Legalize => true
case _ => false
}
- private def width(t: Type): Int = bitWidth(t).toInt
+ private def width(t: Type): Int = bitWidth(t).toInt
private def width(e: Expression): Int = width(e.tpe)
// Returns an expression with the correct integer width
private def fixup(i: Int)(e: Expression) = {
@@ -54,31 +52,31 @@ object PadWidths extends Pass {
}
// Recursive, updates expression so children exp's have correct widths
- private def onExp(e: Expression): Expression = e map onExp match {
+ private def onExp(e: Expression): Expression = e.map(onExp) match {
case Mux(cond, tval, fval, tpe) =>
Mux(cond, fixup(width(tpe))(tval), fixup(width(tpe))(fval), tpe)
- case ex: ValidIf => ex copy (value = fixup(width(ex.tpe))(ex.value))
- case ex: DoPrim => ex.op match {
- case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor |
- Add | Sub | Mul | Div | Rem | Shr =>
- // sensitive ops
- ex map fixup((ex.args map width foldLeft 0)(math.max))
- case Dshl =>
- // special case as args aren't all same width
- ex copy (op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
- case _ => ex
- }
+ case ex: ValidIf => ex.copy(value = fixup(width(ex.tpe))(ex.value))
+ case ex: DoPrim =>
+ ex.op match {
+ case Lt | Leq | Gt | Geq | Eq | Neq | Not | And | Or | Xor | Add | Sub | Mul | Div | Rem | Shr =>
+ // sensitive ops
+ ex.map(fixup((ex.args.map(width).foldLeft(0))(math.max)))
+ case Dshl =>
+ // special case as args aren't all same width
+ ex.copy(op = Dshlw, args = Seq(fixup(width(ex.tpe))(ex.args.head), ex.args(1)))
+ case _ => ex
+ }
case ex => ex
}
// Recursive. Fixes assignments and register initialization widths
- private def onStmt(s: Statement): Statement = s map onExp match {
+ private def onStmt(s: Statement): Statement = s.map(onExp) match {
case sx: Connect =>
- sx copy (expr = fixup(width(sx.loc))(sx.expr))
+ sx.copy(expr = fixup(width(sx.loc))(sx.expr))
case sx: DefRegister =>
- sx copy (init = fixup(width(sx.tpe))(sx.init))
- case sx => sx map onStmt
+ sx.copy(init = fixup(width(sx.tpe))(sx.init))
+ case sx => sx.map(onStmt)
}
- def run(c: Circuit): Circuit = c copy (modules = c.modules map (_ map onStmt))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(_.map(onStmt)))
}
diff --git a/src/main/scala/firrtl/passes/Pass.scala b/src/main/scala/firrtl/passes/Pass.scala
index 036bd06a..b5eac4ed 100644
--- a/src/main/scala/firrtl/passes/Pass.scala
+++ b/src/main/scala/firrtl/passes/Pass.scala
@@ -8,7 +8,7 @@ import firrtl.{CircuitState, FirrtlUserException, Transform}
* Has an [[UnknownForm]], because larger [[Transform]] should specify form
*/
trait Pass extends Transform with DependencyAPIMigration {
- def run(c: Circuit): Circuit
+ def run(c: Circuit): Circuit
def execute(state: CircuitState): CircuitState = state.copy(circuit = run(state.circuit))
}
diff --git a/src/main/scala/firrtl/passes/PullMuxes.scala b/src/main/scala/firrtl/passes/PullMuxes.scala
index b805b5fc..27543d63 100644
--- a/src/main/scala/firrtl/passes/PullMuxes.scala
+++ b/src/main/scala/firrtl/passes/PullMuxes.scala
@@ -11,38 +11,50 @@ object PullMuxes extends Pass {
override def invalidates(a: Transform) = false
def run(c: Circuit): Circuit = {
- def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match {
- case ex: WSubField => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubField(exx.tval, ex.name, ex.tpe, ex.flow),
- WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubIndex => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow),
- WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubAccess => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow),
- WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex => ex
- }
- def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e
- val modulesx = c.modules.map {
- case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body))
- case (m:ExtModule) => m
- }
- Circuit(c.info, modulesx, c.main)
- }
+ def pull_muxes_e(e: Expression): Expression = e.map(pull_muxes_e) match {
+ case ex: WSubField =>
+ ex.expr match {
+ case exx: Mux =>
+ Mux(
+ exx.cond,
+ WSubField(exx.tval, ex.name, ex.tpe, ex.flow),
+ WSubField(exx.fval, ex.name, ex.tpe, ex.flow),
+ ex.tpe
+ )
+ case exx: ValidIf => ValidIf(exx.cond, WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex: WSubIndex =>
+ ex.expr match {
+ case exx: Mux =>
+ Mux(
+ exx.cond,
+ WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow),
+ WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow),
+ ex.tpe
+ )
+ case exx: ValidIf => ValidIf(exx.cond, WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex: WSubAccess =>
+ ex.expr match {
+ case exx: Mux =>
+ Mux(
+ exx.cond,
+ WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow),
+ WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow),
+ ex.tpe
+ )
+ case exx: ValidIf => ValidIf(exx.cond, WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex => ex
+ }
+ def pull_muxes(s: Statement): Statement = s.map(pull_muxes).map(pull_muxes_e)
+ val modulesx = c.modules.map {
+ case (m: Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body))
+ case (m: ExtModule) => m
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
}
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index 18db5939..015346ff 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -2,7 +2,7 @@
package firrtl.passes
-import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubIndex, WSubField}
+import firrtl.{Namespace, Transform, WRef, WSubAccess, WSubField, WSubIndex}
import firrtl.PrimOps.{And, Eq}
import firrtl.ir._
import firrtl.Mappers._
@@ -17,10 +17,12 @@ import scala.collection.mutable
object RemoveAccesses extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ZeroLengthVecs),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ZeroLengthVecs),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case Uniquify | ResolveKinds | ResolveFlows => true
@@ -28,8 +30,8 @@ object RemoveAccesses extends Pass {
}
private def AND(e1: Expression, e2: Expression) =
- if(e1 == one) e2
- else if(e2 == one) e1
+ if (e1 == one) e2
+ else if (e2 == one) e1
else DoPrim(And, Seq(e1, e2), Nil, BoolType)
private def EQV(e1: Expression, e2: Expression): Expression =
@@ -45,30 +47,35 @@ object RemoveAccesses extends Pass {
* Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1)))
*/
private def getLocations(e: Expression): Seq[Location] = e match {
- case e: WRef => create_exps(e).map(Location(_,one))
+ case e: WRef => create_exps(e).map(Location(_, one))
case e: WSubIndex =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
val stride = get_size(e.expr.tpe)
- for ((l, i) <- ls.zipWithIndex
- if ((i % stride) >= start) & ((i % stride) < end)) yield l
+ for (
+ (l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)
+ ) yield l
case e: WSubField =>
val ls = getLocations(e.expr)
val start = get_point(e)
val end = start + get_size(e.tpe)
val stride = get_size(e.expr.tpe)
- for ((l, i) <- ls.zipWithIndex
- if ((i % stride) >= start) & ((i % stride) < end)) yield l
+ for (
+ (l, i) <- ls.zipWithIndex
+ if ((i % stride) >= start) & ((i % stride) < end)
+ ) yield l
case e: WSubAccess =>
val ls = getLocations(e.expr)
val stride = get_size(e.tpe)
val wrap = e.expr.tpe.asInstanceOf[VectorType].size
- ls.zipWithIndex map {case (l, i) =>
- val c = (i / stride) % wrap
- val basex = l.base
- val guardx = AND(l.guard,EQV(UIntLiteral(c),e.index))
- Location(basex,guardx)
+ ls.zipWithIndex.map {
+ case (l, i) =>
+ val c = (i / stride) % wrap
+ val basex = l.base
+ val guardx = AND(l.guard, EQV(UIntLiteral(c), e.index))
+ Location(basex, guardx)
}
}
@@ -78,10 +85,10 @@ object RemoveAccesses extends Pass {
var ret: Boolean = false
def rec_has_access(e: Expression): Expression = {
e match {
- case _ : WSubAccess => ret = true
+ case _: WSubAccess => ret = true
case _ =>
}
- e map rec_has_access
+ e.map(rec_has_access)
}
rec_has_access(e)
ret
@@ -90,7 +97,7 @@ object RemoveAccesses extends Pass {
// This improves the performance of this pass
private val createExpsCache = mutable.HashMap[Expression, Seq[Expression]]()
private def create_exps(e: Expression) =
- createExpsCache getOrElseUpdate (e, firrtl.Utils.create_exps(e))
+ createExpsCache.getOrElseUpdate(e, firrtl.Utils.create_exps(e))
def run(c: Circuit): Circuit = {
def remove_m(m: Module): Module = {
@@ -105,21 +112,21 @@ object RemoveAccesses extends Pass {
*/
val stmts = mutable.ArrayBuffer[Statement]()
def removeSource(e: Expression): Expression = e match {
- case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(e) =>
+ case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(e) =>
val rs = getLocations(e)
- rs find (x => x.guard != one) match {
+ rs.find(x => x.guard != one) match {
case None => throwInternalError(s"removeSource: shouldn't be here - $e")
case Some(_) =>
val (wire, temp) = create_temp(e)
val temps = create_exps(temp)
def getTemp(i: Int) = temps(i % temps.size)
stmts += wire
- rs.zipWithIndex foreach {
+ rs.zipWithIndex.foreach {
case (x, i) if i < temps.size =>
- stmts += IsInvalid(get_info(s),getTemp(i))
- stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += IsInvalid(get_info(s), getTemp(i))
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
case (x, i) =>
- stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += Conditionally(get_info(s), x.guard, Connect(get_info(s), getTemp(i), x.base), EmptyStmt)
}
temp
}
@@ -129,14 +136,16 @@ object RemoveAccesses extends Pass {
/** Replaces a subaccess in a given sink expression
*/
def removeSink(info: Info, loc: Expression): Expression = loc match {
- case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if hasAccess(loc) =>
+ case (_: WSubAccess | _: WSubField | _: WSubIndex | _: WRef) if hasAccess(loc) =>
val ls = getLocations(loc)
- if (ls.size == 1 & weq(ls.head.guard,one)) loc
+ if (ls.size == 1 & weq(ls.head.guard, one)) loc
else {
val (wire, temp) = create_temp(loc)
stmts += wire
- ls foreach (x => stmts +=
- Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt))
+ ls.foreach(x =>
+ stmts +=
+ Conditionally(info, x.guard, Connect(info, x.base, temp), EmptyStmt)
+ )
temp
}
case _ => loc
@@ -150,7 +159,7 @@ object RemoveAccesses extends Pass {
case w: WSubAccess => removeSource(WSubAccess(w.expr, fixSource(w.index), w.tpe, w.flow))
//case w: WSubIndex => removeSource(w)
//case w: WSubField => removeSource(w)
- case x => x map fixSource
+ case x => x.map(fixSource)
}
/** Recursively walks a sink expression and fixes all subaccesses
@@ -159,13 +168,13 @@ object RemoveAccesses extends Pass {
*/
def fixSink(e: Expression): Expression = e match {
case w: WSubAccess => WSubAccess(fixSink(w.expr), fixSource(w.index), w.tpe, w.flow)
- case x => x map fixSink
+ case x => x.map(fixSink)
}
val sx = s match {
case Connect(info, loc, exp) =>
Connect(info, removeSink(info, fixSink(loc)), fixSource(exp))
- case sxx => sxx map fixSource map onStmt
+ case sxx => sxx.map(fixSource).map(onStmt)
}
stmts += sx
if (stmts.size != 1) Block(stmts.toSeq) else stmts(0)
@@ -173,9 +182,9 @@ object RemoveAccesses extends Pass {
Module(m.info, m.name, m.ports, squashEmpty(onStmt(m.body)))
}
- c copy (modules = c.modules map {
+ c.copy(modules = c.modules.map {
case m: ExtModule => m
- case m: Module => remove_m(m)
+ case m: Module => remove_m(m)
})
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 61fd6258..624138ab 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -17,8 +17,7 @@ case class DataRef(exp: Expression, source: String, sink: String, mask: String,
object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.ChirrtlForm ++
- Seq( Dependency(passes.CInferTypes),
- Dependency(passes.CInferMDir) )
+ Seq(Dependency(passes.CInferTypes), Dependency(passes.CInferMDir))
override def invalidates(a: Transform) = false
@@ -31,10 +30,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match {
case _: GroundType => Seq(ex)
- case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex)
- case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex)
+ case t: BundleType =>
+ (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(
+ ex
+ )
+ case t: VectorType =>
+ ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) =>
+ exps ++ create_all_exps(SubIndex(ex, i, t.tpe))
+ ) ++ Seq(ex)
case UnknownType => Seq(ex)
}
@@ -42,17 +45,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
case ex: Mux =>
val e1s = create_exps(ex.tval)
val e2s = create_exps(ex.fval)
- (e1s zip e2s) map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) }
+ (e1s.zip(e2s)).map { case (e1, e2) => Mux(ex.cond, e1, e2, mux_type(e1, e2)) }
case ex: ValidIf =>
- create_exps(ex.value) map (e1 => ValidIf(ex.cond, e1, e1.tpe))
- case ex => ex.tpe match {
- case _: GroundType => Seq(ex)
- case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_exps(SubField(ex, f.name, f.tpe)))
- case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_exps(SubIndex(ex, i, t.tpe)))
- case UnknownType => Seq(ex)
- }
+ create_exps(ex.value).map(e1 => ValidIf(ex.cond, e1, e1.tpe))
+ case ex =>
+ ex.tpe match {
+ case _: GroundType => Seq(ex)
+ case t: BundleType =>
+ (t.fields.foldLeft(Seq[Expression]()))((exps, f) => exps ++ create_exps(SubField(ex, f.name, f.tpe)))
+ case t: VectorType =>
+ ((0 until t.size).foldLeft(Seq[Expression]()))((exps, i) => exps ++ create_exps(SubIndex(ex, i, t.tpe)))
+ case UnknownType => Seq(ex)
+ }
}
private def EMPs: MPorts = MPorts(ArrayBuffer[MPort](), ArrayBuffer[MPort](), ArrayBuffer[MPort]())
@@ -61,40 +65,48 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
s match {
case sx: CDefMemory if sx.seq => smems += sx.name
case sx: CDefMPort =>
- val p = mports getOrElse (sx.mem, EMPs)
+ val p = mports.getOrElse(sx.mem, EMPs)
sx.direction match {
- case MRead => p.readers += MPort(sx.name, sx.exps(1))
- case MWrite => p.writers += MPort(sx.name, sx.exps(1))
+ case MRead => p.readers += MPort(sx.name, sx.exps(1))
+ case MWrite => p.writers += MPort(sx.name, sx.exps(1))
case MReadWrite => p.readwriters += MPort(sx.name, sx.exps(1))
- case MInfer => // direction may not be inferred if it's not being used
+ case MInfer => // direction may not be inferred if it's not being used
}
mports(sx.mem) = p
case _ =>
}
- s map collect_smems_and_mports(mports, smems)
+ s.map(collect_smems_and_mports(mports, smems))
}
- def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
- refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
+ def collect_refs(
+ mports: MPortMap,
+ smems: SeqMemSet,
+ types: MPortTypeMap,
+ refs: DataRefMap,
+ raddrs: AddrMap,
+ renames: RenameMap
+ )(s: Statement
+ ): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
- val taddr = UIntType(IntWidth(1 max getUIntWidth(sx.size - 1)))
+ val taddr = UIntType(IntWidth(1.max(getUIntWidth(sx.size - 1))))
val tdata = sx.tpe
- def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap (r => Seq(
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType))
- ))
- def set_enable(vec: scala.collection.Seq[MPort], en: String) = vec.toSeq.map (r =>
- Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero)
+ def set_poison(vec: scala.collection.Seq[MPort]) = vec.toSeq.flatMap(r =>
+ Seq(
+ IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "addr", taddr)),
+ IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), "clk", ClockType))
+ )
)
+ def set_enable(vec: scala.collection.Seq[MPort], en: String) =
+ vec.toSeq.map(r => Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero))
def set_write(vec: scala.collection.Seq[MPort], data: String, mask: String) = vec.toSeq.flatMap { r =>
val tmask = createMask(sx.tpe)
val portRef = SubField(Reference(sx.name, ut), r.name, ut)
Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask)))
}
- val rds = (mports getOrElse (sx.name, EMPs)).readers
- val wrs = (mports getOrElse (sx.name, EMPs)).writers
- val rws = (mports getOrElse (sx.name, EMPs)).readwriters
+ val rds = (mports.getOrElse(sx.name, EMPs)).readers
+ val wrs = (mports.getOrElse(sx.name, EMPs)).writers
+ val rws = (mports.getOrElse(sx.name, EMPs)).readwriters
val stmts = set_poison(rds) ++
set_enable(rds, "en") ++
set_poison(wrs) ++
@@ -104,8 +116,18 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
set_enable(rws, "wmode") ++
set_enable(rws, "en") ++
set_write(rws, "wdata", "wmask")
- val mem = DefMemory(sx.info, sx.name, sx.tpe, sx.size, 1, if (sx.seq) 1 else 0,
- rds.map(_.name).toSeq, wrs.map(_.name).toSeq, rws.map(_.name).toSeq, sx.readUnderWrite)
+ val mem = DefMemory(
+ sx.info,
+ sx.name,
+ sx.tpe,
+ sx.size,
+ 1,
+ if (sx.seq) 1 else 0,
+ rds.map(_.name).toSeq,
+ wrs.map(_.name).toSeq,
+ rws.map(_.name).toSeq,
+ sx.readUnderWrite
+ )
Block(mem +: stmts)
case sx: CDefMPort =>
types.get(sx.mem) match {
@@ -130,8 +152,8 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val es = create_all_exps(WRef(sx.name, sx.tpe))
val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe))
- ((es zip rs) zip ws) map {
- case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
+ ((es.zip(rs)).zip(ws)).map {
+ case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
}
case MWrite =>
refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false)
@@ -142,7 +164,7 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
- (es zip ws) map {
+ (es.zip(ws)).map {
case (e, w) => renames.rename(e.serialize, w.serialize)
}
case MRead =>
@@ -157,63 +179,69 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
- (es zip rs) map {
+ (es.zip(rs)).map {
case (e, r) => renames.rename(e.serialize, r.serialize)
}
case MInfer => // do nothing if it's not being used
}
- Block(List() ++
- (addrs.map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
- (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
- (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++
- masks.map(lhs => Connect(sx.info, lhs, zero))
+ Block(
+ List() ++
+ (addrs.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
+ (clks.map(x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
+ (ens.map(x => Connect(sx.info, SubField(portRef, x, ut), one))) ++
+ masks.map(lhs => Connect(sx.info, lhs, zero))
)
- case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
+ case sx => sx.map(collect_refs(mports, smems, types, refs, raddrs, renames))
}
def get_mask(refs: DataRefMap)(e: Expression): Expression =
- e map get_mask(refs) match {
- case ex: Reference => refs get ex.name match {
- case None => ex
- case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe))
- }
+ e.map(get_mask(refs)) match {
+ case ex: Reference =>
+ refs.get(ex.name) match {
+ case None => ex
+ case Some(p) => SubField(p.exp, p.mask, createMask(ex.tpe))
+ }
case ex => ex
}
def remove_chirrtl_s(refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = {
var has_write_mport = false
var has_readwrite_mport: Option[Expression] = None
- var has_read_mport: Option[Expression] = None
+ var has_read_mport: Option[Expression] = None
def remove_chirrtl_e(g: Flow)(e: Expression): Expression = e match {
- case Reference(name, tpe, _, _) => refs get name match {
- case Some(p) => g match {
- case SinkFlow =>
- has_write_mport = true
- if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType))
- SubField(p.exp, p.sink, tpe)
- case SourceFlow =>
- SubField(p.exp, p.source, tpe)
- }
- case None => g match {
- case SinkFlow => raddrs get name match {
- case Some(en) => has_read_mport = Some(en) ; e
- case None => e
- }
- case SourceFlow => e
+ case Reference(name, tpe, _, _) =>
+ refs.get(name) match {
+ case Some(p) =>
+ g match {
+ case SinkFlow =>
+ has_write_mport = true
+ if (p.rdwrite) has_readwrite_mport = Some(SubField(p.exp, "wmode", BoolType))
+ SubField(p.exp, p.sink, tpe)
+ case SourceFlow =>
+ SubField(p.exp, p.source, tpe)
+ }
+ case None =>
+ g match {
+ case SinkFlow =>
+ raddrs.get(name) match {
+ case Some(en) => has_read_mport = Some(en); e
+ case None => e
+ }
+ case SourceFlow => e
+ }
}
- }
- case SubAccess(expr, index, tpe, _) => SubAccess(
- remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe)
- case ex => ex map remove_chirrtl_e(g)
- }
- s match {
+ case SubAccess(expr, index, tpe, _) =>
+ SubAccess(remove_chirrtl_e(g)(expr), remove_chirrtl_e(SourceFlow)(index), tpe)
+ case ex => ex.map(remove_chirrtl_e(g))
+ }
+ s match {
case DefNode(info, name, value) =>
val valuex = remove_chirrtl_e(SourceFlow)(value)
val sx = DefNode(info, name, valuex)
// Check node is used for read port address
remove_chirrtl_e(SinkFlow)(Reference(name, value.tpe))
has_read_mport match {
- case None => sx
+ case None => sx
case Some(en) => Block(sx, Connect(info, en, one))
}
case Connect(info, loc, expr) =>
@@ -222,14 +250,14 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val sx = Connect(info, locx, rocx)
val stmts = ArrayBuffer[Statement]()
has_read_mport match {
- case None =>
+ case None =>
case Some(en) => stmts += Connect(info, en, one)
}
if (has_write_mport) {
val locs = create_exps(get_mask(refs)(loc))
- stmts ++= (locs map (x => Connect(info, x, one)))
+ stmts ++= (locs.map(x => Connect(info, x, one)))
has_readwrite_mport match {
- case None =>
+ case None =>
case Some(wmode) => stmts += Connect(info, wmode, one)
}
}
@@ -240,20 +268,20 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val sx = PartialConnect(info, locx, rocx)
val stmts = ArrayBuffer[Statement]()
has_read_mport match {
- case None =>
+ case None =>
case Some(en) => stmts += Connect(info, en, one)
}
if (has_write_mport) {
val ls = get_valid_points(loc.tpe, expr.tpe, Default, Default)
val locs = create_exps(get_mask(refs)(loc))
- stmts ++= (ls map { case (x, _) => Connect(info, locs(x), one) })
+ stmts ++= (ls.map { case (x, _) => Connect(info, locs(x), one) })
has_readwrite_mport match {
- case None =>
+ case None =>
case Some(wmode) => stmts += Connect(info, wmode, one)
}
}
if (stmts.isEmpty) sx else Block(sx +: stmts.toSeq)
- case sx => sx map remove_chirrtl_s(refs, raddrs) map remove_chirrtl_e(SourceFlow)
+ case sx => sx.map(remove_chirrtl_s(refs, raddrs)).map(remove_chirrtl_e(SourceFlow))
}
}
@@ -264,16 +292,16 @@ object RemoveCHIRRTL extends Transform with DependencyAPIMigration {
val refs = new DataRefMap
val raddrs = new AddrMap
renames.setModule(m.name)
- (m map collect_smems_and_mports(mports, smems)
- map collect_refs(mports, smems, types, refs, raddrs, renames)
- map remove_chirrtl_s(refs, raddrs))
+ (m.map(collect_smems_and_mports(mports, smems))
+ .map(collect_refs(mports, smems, types, refs, raddrs, renames))
+ .map(remove_chirrtl_s(refs, raddrs)))
}
def execute(state: CircuitState): CircuitState = {
val c = state.circuit
val renames = RenameMap()
renames.setCircuit(c.main)
- val result = c copy (modules = c.modules map remove_chirrtl_m(renames))
+ val result = c.copy(modules = c.modules.map(remove_chirrtl_m(renames)))
state.copy(circuit = result, renames = Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveEmpty.scala b/src/main/scala/firrtl/passes/RemoveEmpty.scala
index eabf667c..eb25dcc4 100644
--- a/src/main/scala/firrtl/passes/RemoveEmpty.scala
+++ b/src/main/scala/firrtl/passes/RemoveEmpty.scala
@@ -15,7 +15,7 @@ object RemoveEmpty extends Pass with DependencyAPIMigration {
private def onModule(m: DefModule): DefModule = {
m match {
- case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body))
+ case m: Module => Module(m.info, m.name, m.ports, Utils.squashEmpty(m.body))
case m: ExtModule => m
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveIntervals.scala b/src/main/scala/firrtl/passes/RemoveIntervals.scala
index 7059526c..657b4356 100644
--- a/src/main/scala/firrtl/passes/RemoveIntervals.scala
+++ b/src/main/scala/firrtl/passes/RemoveIntervals.scala
@@ -13,14 +13,13 @@ import firrtl.options.Dependency
import scala.math.BigDecimal.RoundingMode._
class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
- extends PassException({
- val toWrap = wrap.args.head.serialize
- val toWrapTpe = wrap.args.head.tpe.serialize
- val wrapTo = wrap.args(1).serialize
- val wrapToTpe = wrap.args(1).tpe.serialize
- s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
- })
-
+ extends PassException({
+ val toWrap = wrap.args.head.serialize
+ val toWrapTpe = wrap.args.head.tpe.serialize
+ val wrapTo = wrap.args(1).serialize
+ val wrapToTpe = wrap.args(1).tpe.serialize
+ s"$info: [module $mname] Wraps with remainder currently unsupported: $toWrap:$toWrapTpe cannot be wrapped to $wrapTo's type $wrapToTpe"
+ })
/** Replaces IntervalType with SIntType, three AST walks:
* 1) Align binary points
@@ -39,48 +38,50 @@ class WrapWithRemainder(info: Info, mname: String, wrap: DoPrim)
class RemoveIntervals extends Pass {
override def prerequisites: Seq[Dependency[Transform]] =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency[ExpandWhensAndCheck] ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses),
+ Dependency[ExpandWhensAndCheck]
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(transform: Transform): Boolean = {
transform match {
case InferTypes | ResolveKinds => true
- case _ => false
+ case _ => false
}
}
def run(c: Circuit): Circuit = {
val alignedCircuit = c
val errors = new Errors()
- val wiredCircuit = alignedCircuit map makeWireModule
- val replacedCircuit = wiredCircuit map replaceModuleInterval(errors)
+ val wiredCircuit = alignedCircuit.map(makeWireModule)
+ val replacedCircuit = wiredCircuit.map(replaceModuleInterval(errors))
errors.trigger()
replacedCircuit
}
/* Replace interval types */
private def replaceModuleInterval(errors: Errors)(m: DefModule): DefModule =
- m map replaceStmtInterval(errors, m.name) map replacePortInterval
+ m.map(replaceStmtInterval(errors, m.name)).map(replacePortInterval)
private def replaceStmtInterval(errors: Errors, mname: String)(s: Statement): Statement = {
val info = s match {
case h: HasInfo => h.info
case _ => NoInfo
}
- s map replaceTypeInterval map replaceStmtInterval(errors, mname) map replaceExprInterval(errors, info, mname)
+ s.map(replaceTypeInterval).map(replaceStmtInterval(errors, mname)).map(replaceExprInterval(errors, info, mname))
}
private def replaceExprInterval(errors: Errors, info: Info, mname: String)(e: Expression): Expression = e match {
case _: WRef | _: WSubIndex | _: WSubField => e
case o =>
- o map replaceExprInterval(errors, info, mname) match {
+ o.map(replaceExprInterval(errors, info, mname)) match {
case DoPrim(AsInterval, Seq(a1), _, tpe) => DoPrim(AsSInt, Seq(a1), Seq.empty, tpe)
- case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
- case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
+ case DoPrim(IncP, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
+ case DoPrim(DecP, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
case DoPrim(Clip, Seq(a1, _), Nil, tpe: IntervalType) =>
// Output interval (pre-calculated)
val clipLo = tpe.minAdjusted.get
@@ -94,13 +95,13 @@ class RemoveIntervals extends Pass {
val ltOpt = clipLo <= inLow
(gtOpt, ltOpt) match {
// input range within output range -> no optimization
- case (true, true) => a1
+ case (true, true) => a1
case (true, false) => Mux(Lt(a1, clipLo.S), clipLo.S, a1)
case (false, true) => Mux(Gt(a1, clipHi.S), clipHi.S, a1)
- case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
+ case _ => Mux(Gt(a1, clipHi.S), clipHi.S, Mux(Lt(a1, clipLo.S), clipLo.S, a1))
}
- case sqz@DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ case sqz @ DoPrim(Squeeze, Seq(a1, a2), Nil, tpe: IntervalType) =>
// Using (conditional) reassign interval w/o adding mux
val a1tpe = a1.tpe.asInstanceOf[IntervalType]
val a2tpe = a2.tpe.asInstanceOf[IntervalType]
@@ -117,54 +118,55 @@ class RemoveIntervals extends Pass {
val bits = DoPrim(Bits, Seq(a1), Seq(w2 - 1, 0), UIntType(IntWidth(w2)))
DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(w2)))
}
- case w@DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) => a2.tpe match {
- // If a2 type is Interval wrap around range. If UInt, wrap around width
- case t: IntervalType =>
- // Need to match binary points before getting *adjusted!
- val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
- case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
- case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
- }
- val (inLo, inHi) = a1.tpe match {
- case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
- case _ => sys.error("Shouldn't be here")
- }
- // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
- val range = wrapHi - wrapLo
- val ltOpt = Add(a1, (range + 1).S)
- val gtOpt = Sub(a1, (range + 1).S)
- // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
- // If x < wl
- // output: wh - (wl - x) + 1 AKA x + r + 1
- // worst case: wh - (wl - xl) + 1 = wl
- // -> xl + wr + 1 = wl
- // If x > wh
- // output: wl + (x - wh) - 1 AKA x - r - 1
- // worst case: wl + (xh - wh) - 1 = wh
- // -> xh - wr - 1 = wh
- val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
- (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
- case (true, true, _, _) => a1
- case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
- case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
- // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
- case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
- case _ =>
- errors.append(new WrapWithRemainder(info, mname, w))
- default
- }
- case _ => sys.error("Shouldn't be here")
- }
+ case w @ DoPrim(Wrap, Seq(a1, a2), Nil, tpe: IntervalType) =>
+ a2.tpe match {
+ // If a2 type is Interval wrap around range. If UInt, wrap around width
+ case t: IntervalType =>
+ // Need to match binary points before getting *adjusted!
+ val (wrapLo, wrapHi) = t.copy(point = tpe.point) match {
+ case t: IntervalType => (t.minAdjusted.get, t.maxAdjusted.get)
+ case _ => Utils.throwInternalError(s"Illegal AST state: cannot have $e not have an IntervalType")
+ }
+ val (inLo, inHi) = a1.tpe match {
+ case t2: IntervalType => (t2.minAdjusted.get, t2.maxAdjusted.get)
+ case _ => sys.error("Shouldn't be here")
+ }
+ // If (max input) - (max wrap) + (min wrap) is less then (maxwrap), we can optimize when (max input > max wrap)
+ val range = wrapHi - wrapLo
+ val ltOpt = Add(a1, (range + 1).S)
+ val gtOpt = Sub(a1, (range + 1).S)
+ // [Angie]: This is dangerous. Would rather throw compilation error right now than allow "Rem" without the user explicitly including it.
+ // If x < wl
+ // output: wh - (wl - x) + 1 AKA x + r + 1
+ // worst case: wh - (wl - xl) + 1 = wl
+ // -> xl + wr + 1 = wl
+ // If x > wh
+ // output: wl + (x - wh) - 1 AKA x - r - 1
+ // worst case: wl + (xh - wh) - 1 = wh
+ // -> xh - wr - 1 = wh
+ val default = Add(Rem(Sub(a1, wrapLo.S), Sub(wrapHi.S, wrapLo.S)), wrapLo.S)
+ (wrapHi >= inHi, wrapLo <= inLo, (inHi - range - 1) <= wrapHi, (inLo + range + 1) >= wrapLo) match {
+ case (true, true, _, _) => a1
+ case (true, _, _, true) => Mux(Lt(a1, wrapLo.S), ltOpt, a1)
+ case (_, true, true, _) => Mux(Gt(a1, wrapHi.S), gtOpt, a1)
+ // Note: inHi - range - 1 = wrapHi can't be true when inLo + range + 1 = wrapLo (i.e. simultaneous extreme cases don't work)
+ case (_, _, true, true) => Mux(Gt(a1, wrapHi.S), gtOpt, Mux(Lt(a1, wrapLo.S), ltOpt, a1))
+ case _ =>
+ errors.append(new WrapWithRemainder(info, mname, w))
+ default
+ }
+ case _ => sys.error("Shouldn't be here")
+ }
case other => other
}
}
- private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+ private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval)
private def replaceTypeInterval(t: Type): Type = t match {
- case i@IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
+ case i @ IntervalType(l: IsKnown, u: IsKnown, p: IntWidth) => SIntType(i.width)
case i: IntervalType => sys.error(s"Shouldn't be here: $i")
- case v => v map replaceTypeInterval
+ case v => v.map(replaceTypeInterval)
}
/** Replace Interval Nodes with Interval Wires
@@ -174,15 +176,16 @@ class RemoveIntervals extends Pass {
* @param m module to replace nodes with wire + connection
* @return
*/
- private def makeWireModule(m: DefModule): DefModule = m map makeWireStmt
+ private def makeWireModule(m: DefModule): DefModule = m.map(makeWireStmt)
private def makeWireStmt(s: Statement): Statement = s match {
- case DefNode(info, name, value) => value.tpe match {
- case IntervalType(l, u, p) =>
- val newType = IntervalType(l, u, p)
- Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value)))
- case other => s
- }
- case other => other map makeWireStmt
+ case DefNode(info, name, value) =>
+ value.tpe match {
+ case IntervalType(l, u, p) =>
+ val newType = IntervalType(l, u, p)
+ Block(Seq(DefWire(info, name, newType), Connect(info, WRef(name, newType, WireKind, SinkFlow), value)))
+ case other => s
+ }
+ case other => other.map(makeWireStmt)
}
}
diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala
index 895cb10f..7e82b37b 100644
--- a/src/main/scala/firrtl/passes/RemoveValidIf.scala
+++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala
@@ -26,14 +26,13 @@ object RemoveValidIf extends Pass {
case ClockType => ClockZero
case _: FixedType => FixedZero
case AsyncResetType => AsyncZero
- case other => throwInternalError(s"Unexpected type $other")
+ case other => throwInternalError(s"Unexpected type $other")
}
override def prerequisites = firrtl.stage.Forms.LowForm
override def optionalPrerequisiteOf =
- Seq( Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
case Legalize | _: firrtl.transforms.ConstantPropagation => true
@@ -42,24 +41,25 @@ object RemoveValidIf extends Pass {
// Recursive. Removes ValidIfs
private def onExp(e: Expression): Expression = {
- e map onExp match {
+ e.map(onExp) match {
case ValidIf(_, value, _) => value
- case x => x
+ case x => x
}
}
// Recursive. Replaces IsInvalid with connecting zero
- private def onStmt(s: Statement): Statement = s map onStmt map onExp match {
- case invalid @ IsInvalid(info, loc) => loc.tpe match {
- case _: AnalogType => EmptyStmt
- case tpe => Connect(info, loc, getGroundZero(tpe))
- }
+ private def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp) match {
+ case invalid @ IsInvalid(info, loc) =>
+ loc.tpe match {
+ case _: AnalogType => EmptyStmt
+ case tpe => Connect(info, loc, getGroundZero(tpe))
+ }
case other => other
}
private def onModule(m: DefModule): DefModule = {
m match {
- case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body))
+ case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body))
case m: ExtModule => m
}
}
diff --git a/src/main/scala/firrtl/passes/ReplaceAccesses.scala b/src/main/scala/firrtl/passes/ReplaceAccesses.scala
index e31d9410..4a3cd697 100644
--- a/src/main/scala/firrtl/passes/ReplaceAccesses.scala
+++ b/src/main/scala/firrtl/passes/ReplaceAccesses.scala
@@ -18,15 +18,16 @@ object ReplaceAccesses extends Pass {
override def invalidates(a: Transform) = false
def run(c: Circuit): Circuit = {
- def onStmt(s: Statement): Statement = s map onStmt map onExp
- def onExp(e: Expression): Expression = e match {
- case WSubAccess(ex, UIntLiteral(value, _), t, g) => ex.tpe match {
- case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g)
- case _ => e map onExp
- }
- case _ => e map onExp
+ def onStmt(s: Statement): Statement = s.map(onStmt).map(onExp)
+ def onExp(e: Expression): Expression = e match {
+ case WSubAccess(ex, UIntLiteral(value, _), t, g) =>
+ ex.tpe match {
+ case VectorType(_, len) if (value < len) => WSubIndex(onExp(ex), value.toInt, t, g)
+ case _ => e.map(onExp)
+ }
+ case _ => e.map(onExp)
}
- c copy (modules = c.modules map (_ map onStmt))
+ c.copy(modules = c.modules.map(_.map(onStmt)))
}
}
diff --git a/src/main/scala/firrtl/passes/ResolveFlows.scala b/src/main/scala/firrtl/passes/ResolveFlows.scala
index 85a0a26f..48b9479c 100644
--- a/src/main/scala/firrtl/passes/ResolveFlows.scala
+++ b/src/main/scala/firrtl/passes/ResolveFlows.scala
@@ -14,17 +14,22 @@ object ResolveFlows extends Pass {
override def invalidates(a: Transform) = false
def resolve_e(g: Flow)(e: Expression): Expression = e match {
- case ex: WRef => ex copy (flow = g)
- case WSubField(exp, name, tpe, _) => WSubField(
- Utils.field_flip(exp.tpe, name) match {
- case Default => resolve_e(g)(exp)
- case Flip => resolve_e(Utils.swap(g))(exp)
- }, name, tpe, g)
+ case ex: WRef => ex.copy(flow = g)
+ case WSubField(exp, name, tpe, _) =>
+ WSubField(
+ Utils.field_flip(exp.tpe, name) match {
+ case Default => resolve_e(g)(exp)
+ case Flip => resolve_e(Utils.swap(g))(exp)
+ },
+ name,
+ tpe,
+ g
+ )
case WSubIndex(exp, value, tpe, _) =>
WSubIndex(resolve_e(g)(exp), value, tpe, g)
case WSubAccess(exp, index, tpe, _) =>
WSubAccess(resolve_e(g)(exp), resolve_e(SourceFlow)(index), tpe, g)
- case _ => e map resolve_e(g)
+ case _ => e.map(resolve_e(g))
}
def resolve_s(s: Statement): Statement = s match {
@@ -35,11 +40,11 @@ object ResolveFlows extends Pass {
Connect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr))
case PartialConnect(info, loc, expr) =>
PartialConnect(info, resolve_e(SinkFlow)(loc), resolve_e(SourceFlow)(expr))
- case sx => sx map resolve_e(SourceFlow) map resolve_s
+ case sx => sx.map(resolve_e(SourceFlow)).map(resolve_s)
}
- def resolve_flow(m: DefModule): DefModule = m map resolve_s
+ def resolve_flow(m: DefModule): DefModule = m.map(resolve_s)
def run(c: Circuit): Circuit =
- c copy (modules = c.modules map resolve_flow)
+ c.copy(modules = c.modules.map(resolve_flow))
}
diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala
index 67360b74..fcbac163 100644
--- a/src/main/scala/firrtl/passes/ResolveKinds.scala
+++ b/src/main/scala/firrtl/passes/ResolveKinds.scala
@@ -20,21 +20,21 @@ object ResolveKinds extends Pass {
}
def resolve_expr(kinds: KindMap)(e: Expression): Expression = e match {
- case ex: WRef => ex copy (kind = kinds(ex.name))
- case _ => e map resolve_expr(kinds)
+ case ex: WRef => ex.copy(kind = kinds(ex.name))
+ case _ => e.map(resolve_expr(kinds))
}
def resolve_stmt(kinds: KindMap)(s: Statement): Statement = {
s match {
- case sx: DefWire => kinds(sx.name) = WireKind
- case sx: DefNode => kinds(sx.name) = NodeKind
- case sx: DefRegister => kinds(sx.name) = RegKind
+ case sx: DefWire => kinds(sx.name) = WireKind
+ case sx: DefNode => kinds(sx.name) = NodeKind
+ case sx: DefRegister => kinds(sx.name) = RegKind
case sx: WDefInstance => kinds(sx.name) = InstanceKind
- case sx: DefMemory => kinds(sx.name) = MemKind
+ case sx: DefMemory => kinds(sx.name) = MemKind
case _ =>
}
s.map(resolve_stmt(kinds))
- .map(resolve_expr(kinds))
+ .map(resolve_expr(kinds))
}
def resolve_kinds(m: DefModule): DefModule = {
@@ -44,5 +44,5 @@ object ResolveKinds extends Pass {
}
def run(c: Circuit): Circuit =
- c copy (modules = c.modules map resolve_kinds)
+ c.copy(modules = c.modules.map(resolve_kinds))
}
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index c536cd5d..a65f8921 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -7,7 +7,7 @@ import firrtl.{SystemVerilogEmitter, Transform, VerilogEmitter}
import firrtl.ir._
import firrtl.options.Dependency
import firrtl.Mappers._
-import firrtl.Utils.{kind, flow, get_info}
+import firrtl.Utils.{flow, get_info, kind}
// Datastructures
import scala.collection.mutable
@@ -17,65 +17,63 @@ import scala.collection.mutable
object SplitExpressions extends Pass {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(firrtl.passes.RemoveValidIf),
- Dependency(firrtl.passes.memlib.VerilogMemDelays) )
+ Seq(Dependency(firrtl.passes.RemoveValidIf), Dependency(firrtl.passes.memlib.VerilogMemDelays))
override def optionalPrerequisiteOf =
- Seq( Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = a match {
case ResolveKinds => true
case _ => false
}
- private def onModule(m: Module): Module = {
- val namespace = Namespace(m)
- def onStmt(s: Statement): Statement = {
- val v = mutable.ArrayBuffer[Statement]()
- // Splits current expression if needed
- // Adds named temporaries to v
- def split(e: Expression): Expression = e match {
- case e: DoPrim =>
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), flow(e))
- case e: Mux =>
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), flow(e))
- case e: ValidIf =>
- val name = namespace.newTemp
- v += DefNode(get_info(s), name, e)
- WRef(name, e.tpe, kind(e), flow(e))
- case _ => e
- }
-
- // Recursive. Splits compound nodes
- def onExp(e: Expression): Expression =
- e map onExp match {
- case ex: DoPrim => ex map split
- case ex => ex
- }
+ private def onModule(m: Module): Module = {
+ val namespace = Namespace(m)
+ def onStmt(s: Statement): Statement = {
+ val v = mutable.ArrayBuffer[Statement]()
+ // Splits current expression if needed
+ // Adds named temporaries to v
+ def split(e: Expression): Expression = e match {
+ case e: DoPrim =>
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), flow(e))
+ case e: Mux =>
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), flow(e))
+ case e: ValidIf =>
+ val name = namespace.newTemp
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), flow(e))
+ case _ => e
+ }
- s map onExp match {
- case x: Block => x map onStmt
- case EmptyStmt => EmptyStmt
- case x =>
- v += x
- v.size match {
- case 1 => v.head
- case _ => Block(v.toSeq)
- }
+ // Recursive. Splits compound nodes
+ def onExp(e: Expression): Expression =
+ e.map(onExp) match {
+ case ex: DoPrim => ex.map(split)
+ case ex => ex
}
+
+ s.map(onExp) match {
+ case x: Block => x.map(onStmt)
+ case EmptyStmt => EmptyStmt
+ case x =>
+ v += x
+ v.size match {
+ case 1 => v.head
+ case _ => Block(v.toSeq)
+ }
}
- Module(m.info, m.name, m.ports, onStmt(m.body))
- }
- def run(c: Circuit): Circuit = {
- val modulesx = c.modules map {
- case m: Module => onModule(m)
- case m: ExtModule => m
- }
- Circuit(c.info, modulesx, c.main)
- }
+ }
+ Module(m.info, m.name, m.ports, onStmt(m.body))
+ }
+ def run(c: Circuit): Circuit = {
+ val modulesx = c.modules.map {
+ case m: Module => onModule(m)
+ case m: ExtModule => m
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
}
diff --git a/src/main/scala/firrtl/passes/ToWorkingIR.scala b/src/main/scala/firrtl/passes/ToWorkingIR.scala
index c271302a..03faaf3c 100644
--- a/src/main/scala/firrtl/passes/ToWorkingIR.scala
+++ b/src/main/scala/firrtl/passes/ToWorkingIR.scala
@@ -6,5 +6,5 @@ import firrtl.Transform
object ToWorkingIR extends Pass {
override def prerequisites = firrtl.stage.Forms.MinimalHighForm
override def invalidates(a: Transform) = false
- def run(c:Circuit): Circuit = c
+ def run(c: Circuit): Circuit = c
}
diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala
index 822a8125..0a05bd4e 100644
--- a/src/main/scala/firrtl/passes/TrimIntervals.scala
+++ b/src/main/scala/firrtl/passes/TrimIntervals.scala
@@ -23,10 +23,7 @@ import firrtl.Transform
class TrimIntervals extends Pass {
override def prerequisites =
- Seq( Dependency(ResolveKinds),
- Dependency(InferTypes),
- Dependency(ResolveFlows),
- Dependency[InferBinaryPoints] )
+ Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows), Dependency[InferBinaryPoints])
override def optionalPrerequisiteOf = Seq.empty
@@ -34,48 +31,51 @@ class TrimIntervals extends Pass {
def run(c: Circuit): Circuit = {
// Open -> closed
- val firstPass = InferTypes.run(c map replaceModuleInterval)
+ val firstPass = InferTypes.run(c.map(replaceModuleInterval))
// Align binary points and adjust range accordingly (loss of precision changes range)
- firstPass map alignModuleBP
+ firstPass.map(alignModuleBP)
}
/* Replace interval types */
- private def replaceModuleInterval(m: DefModule): DefModule = m map replaceStmtInterval map replacePortInterval
+ private def replaceModuleInterval(m: DefModule): DefModule = m.map(replaceStmtInterval).map(replacePortInterval)
- private def replaceStmtInterval(s: Statement): Statement = s map replaceTypeInterval map replaceStmtInterval
+ private def replaceStmtInterval(s: Statement): Statement = s.map(replaceTypeInterval).map(replaceStmtInterval)
- private def replacePortInterval(p: Port): Port = p map replaceTypeInterval
+ private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval)
private def replaceTypeInterval(t: Type): Type = t match {
- case i@IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) =>
+ case i @ IntervalType(l: IsKnown, u: IsKnown, IntWidth(p)) =>
IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p))
case i: IntervalType => i
- case v => v map replaceTypeInterval
+ case v => v.map(replaceTypeInterval)
}
/* Align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! */
- private def alignModuleBP(m: DefModule): DefModule = m map alignStmtBP
-
- private def alignStmtBP(s: Statement): Statement = s map alignExpBP match {
- case c@Connect(info, loc, expr) => loc.tpe match {
- case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
- case _ => c
- }
- case c@PartialConnect(info, loc, expr) => loc.tpe match {
- case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
- case _ => c
- }
- case other => other map alignStmtBP
+ private def alignModuleBP(m: DefModule): DefModule = m.map(alignStmtBP)
+
+ private def alignStmtBP(s: Statement): Statement = s.map(alignExpBP) match {
+ case c @ Connect(info, loc, expr) =>
+ loc.tpe match {
+ case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case c @ PartialConnect(info, loc, expr) =>
+ loc.tpe match {
+ case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
+ case _ => c
+ }
+ case other => other.map(alignStmtBP)
}
// Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned
// Note - Mul does not need its binary points aligned, because multiplication is cool like that
- private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq/*, Wrap, Clip, Squeeze*/)
+ private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq /*, Wrap, Clip, Squeeze*/ )
- private def alignExpBP(e: Expression): Expression = e map alignExpBP match {
+ private def alignExpBP(e: Expression): Expression = e.map(alignExpBP) match {
case DoPrim(SetP, Seq(arg), Seq(const), tpe: IntervalType) => fixBP(IntWidth(const))(arg)
- case DoPrim(o, args, consts, t) if opsToFix.contains(o) &&
- (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size =>
+ case DoPrim(o, args, consts, t)
+ if opsToFix.contains(o) &&
+ (args.map(_.tpe).collect { case x: IntervalType => x }).size == args.size =>
val maxBP = args.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)
DoPrim(o, args.map { a => fixBP(maxBP)(a) }, consts, t)
case Mux(cond, tval, fval, t: IntervalType) =>
@@ -85,9 +85,9 @@ class TrimIntervals extends Pass {
}
private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match {
case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired == current => e
- case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired)))
- case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
+ case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
val shiftAmt = current - desired
val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt)
val shiftMul = Closed(BigDecimal(1) / shiftGain)
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index b9cd32fa..10198b33 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -2,7 +2,6 @@
package firrtl.passes
-
import scala.annotation.tailrec
import firrtl._
import firrtl.ir._
@@ -35,12 +34,11 @@ import MemPortUtils.memType
object Uniquify extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(ResolveKinds),
- Dependency(InferTypes) ) ++ firrtl.stage.Forms.WorkingIR
+ Seq(Dependency(ResolveKinds), Dependency(InferTypes)) ++ firrtl.stage.Forms.WorkingIR
override def invalidates(a: Transform): Boolean = a match {
case ResolveKinds | InferTypes => true
- case _ => false
+ case _ => false
}
private case class UniquifyException(msg: String) extends FirrtlInternalException(msg)
@@ -55,12 +53,13 @@ object Uniquify extends Transform with DependencyAPIMigration {
*/
@tailrec
def findValidPrefix(
- prefix: String,
- elts: Seq[String],
- namespace: collection.mutable.HashSet[String]): String = {
- elts find (elt => namespace.contains(prefix + elt)) match {
+ prefix: String,
+ elts: Seq[String],
+ namespace: collection.mutable.HashSet[String]
+ ): String = {
+ elts.find(elt => namespace.contains(prefix + elt)) match {
case Some(_) => findValidPrefix(prefix + "_", elts, namespace)
- case None => prefix
+ case None => prefix
}
}
@@ -70,16 +69,16 @@ object Uniquify extends Transform with DependencyAPIMigration {
* => foo, foo bar, foo bar 0, foo bar 1, foo bar 0 a, foo bar 0 b, foo bar 1 a, foo bar 1 b, foo c
* }}}
*/
- private [firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match {
+ private[firrtl] def enumerateNames(tpe: Type): Seq[Seq[String]] = tpe match {
case t: BundleType =>
- t.fields flatMap { f =>
- (enumerateNames(f.tpe) map (f.name +: _)) ++ Seq(Seq(f.name))
+ t.fields.flatMap { f =>
+ (enumerateNames(f.tpe).map(f.name +: _)) ++ Seq(Seq(f.name))
}
case t: VectorType =>
- ((0 until t.size) map (i => Seq(i.toString))) ++
- ((0 until t.size) flatMap { i =>
- enumerateNames(t.tpe) map (i.toString +: _)
- })
+ ((0 until t.size).map(i => Seq(i.toString))) ++
+ ((0 until t.size).flatMap { i =>
+ enumerateNames(t.tpe).map(i.toString +: _)
+ })
case _ => Seq()
}
@@ -87,27 +86,38 @@ object Uniquify extends Transform with DependencyAPIMigration {
def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = {
// Recursive helper
def recStmtToType(s: Statement): Seq[Field] = s match {
- case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe))
+ case sx: DefWire => Seq(Field(sx.name, Default, sx.tpe))
case sx: DefRegister => Seq(Field(sx.name, Default, sx.tpe))
case sx: WDefInstance => Seq(Field(sx.name, Default, sx.tpe))
- case sx: DefMemory => sx.dataType match {
- case (_: UIntType | _: SIntType | _: FixedType) =>
- Seq(Field(sx.name, Default, memType(sx)))
- case tpe: BundleType =>
- val newFields = tpe.fields map ( f =>
- DefMemory(sx.info, f.name, f.tpe, sx.depth, sx.writeLatency,
- sx.readLatency, sx.readers, sx.writers, sx.readwriters)
- ) flatMap recStmtToType
- Seq(Field(sx.name, Default, BundleType(newFields)))
- case tpe: VectorType =>
- val newFields = (0 until tpe.size) map ( i =>
- sx.copy(name = i.toString, dataType = tpe.tpe)
- ) flatMap recStmtToType
- Seq(Field(sx.name, Default, BundleType(newFields)))
- }
- case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe))
+ case sx: DefMemory =>
+ sx.dataType match {
+ case (_: UIntType | _: SIntType | _: FixedType) =>
+ Seq(Field(sx.name, Default, memType(sx)))
+ case tpe: BundleType =>
+ val newFields = tpe.fields
+ .map(f =>
+ DefMemory(
+ sx.info,
+ f.name,
+ f.tpe,
+ sx.depth,
+ sx.writeLatency,
+ sx.readLatency,
+ sx.readers,
+ sx.writers,
+ sx.readwriters
+ )
+ )
+ .flatMap(recStmtToType)
+ Seq(Field(sx.name, Default, BundleType(newFields)))
+ case tpe: VectorType =>
+ val newFields =
+ (0 until tpe.size).map(i => sx.copy(name = i.toString, dataType = tpe.tpe)).flatMap(recStmtToType)
+ Seq(Field(sx.name, Default, BundleType(newFields)))
+ }
+ case sx: DefNode => Seq(Field(sx.name, Default, sx.value.tpe))
case sx: Conditionally => recStmtToType(sx.conseq) ++ recStmtToType(sx.alt)
- case sx: Block => (sx.stmts map recStmtToType).flatten
+ case sx: Block => (sx.stmts.map(recStmtToType)).flatten
case sx => Seq()
}
BundleType(recStmtToType(s))
@@ -116,40 +126,44 @@ object Uniquify extends Transform with DependencyAPIMigration {
// Accepts a Type and an initial namespace
// Returns new Type with uniquified names
private def uniquifyNames(
- t: BundleType,
- namespace: collection.mutable.HashSet[String])
- (implicit sinfo: Info, mname: String): BundleType = {
+ t: BundleType,
+ namespace: collection.mutable.HashSet[String]
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): BundleType = {
def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): (Type, Seq[String]) = t match {
case tx: BundleType =>
// First add everything
- val newFieldsAndElts = tx.fields map { f =>
+ val newFieldsAndElts = tx.fields.map { f =>
val newName = findValidPrefix(f.name, Seq(""), namespace)
namespace += newName
Field(newName, f.flip, f.tpe)
- } map { f => f.tpe match {
- case _: GroundType => (f, Seq[String](f.name))
- case _ =>
- val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet())
- // Need leading _ for findValidPrefix, it doesn't add _ for checks
- val eltsNames: Seq[String] = eltsx map (e => "_" + e)
- val prefix = findValidPrefix(f.name, eltsNames, namespace)
- // We added f.name in previous map, delete if we change it
- if (prefix != f.name) {
- namespace -= f.name
- namespace += prefix
- }
- val newElts: Seq[String] = eltsx map (e => LowerTypes.loweredName(prefix +: Seq(e)))
- namespace ++= newElts
- (Field(prefix, f.flip, tpe), prefix +: newElts)
+ }.map { f =>
+ f.tpe match {
+ case _: GroundType => (f, Seq[String](f.name))
+ case _ =>
+ val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet())
+ // Need leading _ for findValidPrefix, it doesn't add _ for checks
+ val eltsNames: Seq[String] = eltsx.map(e => "_" + e)
+ val prefix = findValidPrefix(f.name, eltsNames, namespace)
+ // We added f.name in previous map, delete if we change it
+ if (prefix != f.name) {
+ namespace -= f.name
+ namespace += prefix
+ }
+ val newElts: Seq[String] = eltsx.map(e => LowerTypes.loweredName(prefix +: Seq(e)))
+ namespace ++= newElts
+ (Field(prefix, f.flip, tpe), prefix +: newElts)
}
}
val (newFields, elts) = newFieldsAndElts.unzip
(BundleType(newFields), elts.flatten)
case tx: VectorType =>
val (tpe, elts) = recUniquifyNames(tx.tpe, namespace)
- val newElts = ((0 until tx.size) map (i => i.toString)) ++
- ((0 until tx.size) flatMap { i =>
- elts map (e => LowerTypes.loweredName(Seq(i.toString, e)))
+ val newElts = ((0 until tx.size).map(i => i.toString)) ++
+ ((0 until tx.size).flatMap { i =>
+ elts.map(e => LowerTypes.loweredName(Seq(i.toString, e)))
})
(VectorType(tpe, tx.size), newElts)
case tx => (tx, Nil)
@@ -164,19 +178,26 @@ object Uniquify extends Transform with DependencyAPIMigration {
// Creates a mapping from flattened references to members of $from ->
// flattened references to members of $to
private def createNameMapping(
- from: Type,
- to: Type)
- (implicit sinfo: Info, mname: String): Map[String, NameMapNode] = {
+ from: Type,
+ to: Type
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): Map[String, NameMapNode] = {
(from, to) match {
case (fromx: BundleType, tox: BundleType) =>
- (fromx.fields zip tox.fields flatMap { case (f, t) =>
- val eltsMap = createNameMapping(f.tpe, t.tpe)
- if ((f.name != t.name) || eltsMap.nonEmpty) {
- Map(f.name -> NameMapNode(t.name, eltsMap))
- } else {
- Map[String, NameMapNode]()
- }
- }).toMap
+ (fromx.fields
+ .zip(tox.fields)
+ .flatMap {
+ case (f, t) =>
+ val eltsMap = createNameMapping(f.tpe, t.tpe)
+ if ((f.name != t.name) || eltsMap.nonEmpty) {
+ Map(f.name -> NameMapNode(t.name, eltsMap))
+ } else {
+ Map[String, NameMapNode]()
+ }
+ })
+ .toMap
case (fromx: VectorType, tox: VectorType) =>
createNameMapping(fromx.tpe, tox.tpe)
case (fromx, tox) =>
@@ -187,18 +208,19 @@ object Uniquify extends Transform with DependencyAPIMigration {
// Maps names in expression to new uniquified names
private def uniquifyNamesExp(
- exp: Expression,
- map: Map[String, NameMapNode])
- (implicit sinfo: Info, mname: String): Expression = {
+ exp: Expression,
+ map: Map[String, NameMapNode]
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): Expression = {
// Recursive Helper
- def rec(exp: Expression, m: Map[String, NameMapNode]):
- (Expression, Map[String, NameMapNode]) = exp match {
+ def rec(exp: Expression, m: Map[String, NameMapNode]): (Expression, Map[String, NameMapNode]) = exp match {
case e: WRef =>
if (m.contains(e.name)) {
val node = m(e.name)
(WRef(node.name, e.tpe, e.kind, e.flow), node.elts)
- }
- else (e, Map())
+ } else (e, Map())
case e: WSubField =>
val (subExp, subMap) = rec(e.expr, m)
val (retName, retMap) =
@@ -218,18 +240,21 @@ object Uniquify extends Transform with DependencyAPIMigration {
(WSubAccess(subExp, index, e.tpe, e.flow), subMap)
case (_: UIntLiteral | _: SIntLiteral) => (exp, m)
case (_: Mux | _: ValidIf | _: DoPrim) =>
- (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m)
+ (exp.map((e: Expression) => uniquifyNamesExp(e, map)), m)
}
rec(exp, map)._1
}
// Uses map to recursively rename fields of tpe
private def uniquifyNamesType(
- tpe: Type,
- map: Map[String, NameMapNode])
- (implicit sinfo: Info, mname: String): Type = tpe match {
+ tpe: Type,
+ map: Map[String, NameMapNode]
+ )(
+ implicit sinfo: Info,
+ mname: String
+ ): Type = tpe match {
case t: BundleType =>
- val newFields = t.fields map { f =>
+ val newFields = t.fields.map { f =>
if (map.contains(f.name)) {
val node = map(f.name)
Field(node.name, f.flip, uniquifyNamesType(f.tpe, node.elts))
@@ -244,8 +269,11 @@ object Uniquify extends Transform with DependencyAPIMigration {
}
// Everything wrapped in run so that it's thread safe
- @deprecated("The functionality of Uniquify is now part of LowerTypes." +
- "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.", "Firrtl 1.4")
+ @deprecated(
+ "The functionality of Uniquify is now part of LowerTypes." +
+ "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.",
+ "Firrtl 1.4"
+ )
def execute(state: CircuitState): CircuitState = {
val c = state.circuit
val renames = RenameMap()
@@ -263,22 +291,22 @@ object Uniquify extends Transform with DependencyAPIMigration {
val nameMap = collection.mutable.HashMap[String, NameMapNode]()
def uniquifyExp(e: Expression): Expression = e match {
- case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess ) =>
+ case (_: WRef | _: WSubField | _: WSubIndex | _: WSubAccess) =>
uniquifyNamesExp(e, nameMap.toMap)
- case e: Mux => e map uniquifyExp
- case e: ValidIf => e map uniquifyExp
+ case e: Mux => e.map(uniquifyExp)
+ case e: ValidIf => e.map(uniquifyExp)
case (_: UIntLiteral | _: SIntLiteral) => e
- case e: DoPrim => e map uniquifyExp
+ case e: DoPrim => e.map(uniquifyExp)
}
def uniquifyStmt(s: Statement): Statement = {
- s map uniquifyStmt map uniquifyExp match {
+ s.map(uniquifyStmt).map(uniquifyExp) match {
case sx: DefWire =>
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
val newType = uniquifyNamesType(sx.tpe, node.elts)
- (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
DefWire(sx.info, node.name, newType)
@@ -290,7 +318,7 @@ object Uniquify extends Transform with DependencyAPIMigration {
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
val newType = uniquifyNamesType(sx.tpe, node.elts)
- (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init)
@@ -302,7 +330,7 @@ object Uniquify extends Transform with DependencyAPIMigration {
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
val newType = portTypeMap(m.name)
- (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ (Utils.create_exps(sx.name, sx.tpe).zip(Utils.create_exps(node.name, newType))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
WDefInstance(sx.info, node.name, sx.module, newType)
@@ -317,7 +345,7 @@ object Uniquify extends Transform with DependencyAPIMigration {
val mem = sx.copy(name = node.name, dataType = dataType)
// Create new mapping to handle references to memory data fields
val uniqueMemMap = createNameMapping(memType(sx), memType(mem))
- (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach {
+ (Utils.create_exps(sx.name, memType(sx)).zip(Utils.create_exps(node.name, memType(mem)))).foreach {
case (from, to) => renames.rename(from.serialize, to.serialize)
}
nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap)
@@ -329,9 +357,12 @@ object Uniquify extends Transform with DependencyAPIMigration {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach {
- case (from, to) => renames.rename(from.serialize, to.serialize)
- }
+ (Utils
+ .create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe)
+ .zip(Utils.create_exps(node.name, sx.value.tpe)))
+ .foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
DefNode(sx.info, node.name, sx.value)
} else {
sx
@@ -354,19 +385,18 @@ object Uniquify extends Transform with DependencyAPIMigration {
mname = m.name
m match {
case m: ExtModule => m
- case m: Module =>
+ case m: Module =>
// Adds port names to namespace and namemap
nameMap ++= portNameMap(m.name)
- namespace ++= create_exps("", portTypeMap(m.name)) map
- LowerTypes.loweredName map (_.tail)
- m.copy(body = uniquifyBody(m.body) )
+ namespace ++= create_exps("", portTypeMap(m.name)).map(LowerTypes.loweredName).map(_.tail)
+ m.copy(body = uniquifyBody(m.body))
}
}
def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = {
renames.setModule(m.name)
def uniquifyPorts(ports: Seq[Port]): Seq[Port] = {
- val portsType = BundleType(ports map {
+ val portsType = BundleType(ports.map {
case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
})
val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet())
@@ -374,11 +404,12 @@ object Uniquify extends Transform with DependencyAPIMigration {
portNameMap += (m.name -> localMap)
portTypeMap += (m.name -> uniquePortsType)
- ports zip uniquePortsType.fields map { case (p, f) =>
- (Utils.create_exps(p.name, p.tpe) zip Utils.create_exps(f.name, f.tpe)) foreach {
- case (from, to) => renames.rename(from.serialize, to.serialize)
- }
- Port(p.info, f.name, p.direction, f.tpe)
+ ports.zip(uniquePortsType.fields).map {
+ case (p, f) =>
+ (Utils.create_exps(p.name, p.tpe).zip(Utils.create_exps(f.name, f.tpe))).foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ Port(p.info, f.name, p.direction, f.tpe)
}
}
@@ -386,12 +417,12 @@ object Uniquify extends Transform with DependencyAPIMigration {
mname = m.name
m match {
case m: ExtModule => m.copy(ports = uniquifyPorts(m.ports))
- case m: Module => m.copy(ports = uniquifyPorts(m.ports))
+ case m: Module => m.copy(ports = uniquifyPorts(m.ports))
}
}
sinfo = c.info
- val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main)
+ val result = Circuit(c.info, c.modules.map(uniquifyPorts(renames)).map(uniquifyModule(renames)), c.main)
state.copy(circuit = result, renames = Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala
index 36eff379..0b046a5f 100644
--- a/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala
+++ b/src/main/scala/firrtl/passes/VerilogModulusCleanup.scala
@@ -12,28 +12,30 @@ import firrtl.options.Dependency
import scala.collection.mutable
/**
- * Verilog has the width of (a % b) = Max(W(a), W(b))
- * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense,
- * but nevertheless is a problem when emitting verilog
- *
- * This pass finds every instance of (a % b) and:
- * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b))
- * 2) replaces the reference to (a % b) with a bitslice of the temporary node
- * to get back down to width Min(W(a), W(b))
- *
- * This is technically incorrect firrtl, but allows the verilog emitter
- * to emit correct verilog without needing to add temporary nodes
- */
+ * Verilog has the width of (a % b) = Max(W(a), W(b))
+ * FIRRTL has the width of (a % b) = Min(W(a), W(b)), which makes more sense,
+ * but nevertheless is a problem when emitting verilog
+ *
+ * This pass finds every instance of (a % b) and:
+ * 1) adds a temporary node equal to (a % b) with width Max(W(a), W(b))
+ * 2) replaces the reference to (a % b) with a bitslice of the temporary node
+ * to get back down to width Min(W(a), W(b))
+ *
+ * This is technically incorrect firrtl, but allows the verilog emitter
+ * to emit correct verilog without needing to add temporary nodes
+ */
object VerilogModulusCleanup extends Pass {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -51,32 +53,35 @@ object VerilogModulusCleanup extends Pass {
case t => UnknownWidth
}
- def maxWidth(ws: Seq[Width]): Width = ws reduceLeft { (x,y) => (x,y) match {
- case (IntWidth(x), IntWidth(y)) => IntWidth(x max y)
- case (x, y) => UnknownWidth
- }}
+ def maxWidth(ws: Seq[Width]): Width = ws.reduceLeft { (x, y) =>
+ (x, y) match {
+ case (IntWidth(x), IntWidth(y)) => IntWidth(x.max(y))
+ case (x, y) => UnknownWidth
+ }
+ }
def verilogRemWidth(e: DoPrim)(tpe: Type): Type = {
val newWidth = maxWidth(e.args.map(exp => getWidth(exp)))
- tpe mapWidth (w => newWidth)
+ tpe.mapWidth(w => newWidth)
}
def removeRem(e: Expression): Expression = e match {
- case e: DoPrim => e.op match {
- case Rem =>
- val name = namespace.newTemp
- val newType = e mapType verilogRemWidth(e)
- v += DefNode(get_info(s), name, e mapType verilogRemWidth(e))
- val remRef = WRef(name, newType.tpe, kind(e), flow(e))
- val remWidth = bitWidth(e.tpe)
- DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe)
- case _ => e
- }
+ case e: DoPrim =>
+ e.op match {
+ case Rem =>
+ val name = namespace.newTemp
+ val newType = e.mapType(verilogRemWidth(e))
+ v += DefNode(get_info(s), name, e.mapType(verilogRemWidth(e)))
+ val remRef = WRef(name, newType.tpe, kind(e), flow(e))
+ val remWidth = bitWidth(e.tpe)
+ DoPrim(Bits, Seq(remRef), Seq(remWidth - 1, BigInt(0)), e.tpe)
+ case _ => e
+ }
case _ => e
}
- s map removeRem match {
- case x: Block => x map onStmt
+ s.map(removeRem) match {
+ case x: Block => x.map(onStmt)
case EmptyStmt => EmptyStmt
case x =>
v += x
@@ -90,8 +95,8 @@ object VerilogModulusCleanup extends Pass {
}
def run(c: Circuit): Circuit = {
- val modules = c.modules map {
- case m: Module => onModule(m)
+ val modules = c.modules.map {
+ case m: Module => onModule(m)
case m: ExtModule => m
}
Circuit(c.info, modules, c.main)
diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala
index 03d47cfc..eeb34fa9 100644
--- a/src/main/scala/firrtl/passes/VerilogPrep.scala
+++ b/src/main/scala/firrtl/passes/VerilogPrep.scala
@@ -21,15 +21,17 @@ import scala.collection.mutable
object VerilogPrep extends Pass {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -46,9 +48,9 @@ object VerilogPrep extends Pass {
val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression]
lazy val namespace = Namespace(m)
- def onStmt(stmt: Statement): Statement = stmt map onStmt match {
+ def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
case attach: Attach =>
- val wires = attach.exprs groupBy kind
+ val wires = attach.exprs.groupBy(kind)
val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty)
val instPorts = wires.getOrElse(InstanceKind, Seq.empty)
// Sanity check (Should be caught by CheckTypes)
@@ -71,14 +73,14 @@ object VerilogPrep extends Pass {
case s => s
}
- (m map onStmt, sourceMap.toMap)
+ (m.map(onStmt), sourceMap.toMap)
}
def run(c: Circuit): Circuit = {
def lowerE(e: Expression): Expression = e match {
case (_: WRef | _: WSubField) if kind(e) == InstanceKind =>
WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e))
- case _ => e map lowerE
+ case _ => e.map(lowerE)
}
def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match {
@@ -96,12 +98,12 @@ object VerilogPrep extends Pass {
}.unzip
val newInst = WDefInstanceConnector(info, name, module, tpe, portCons)
Block(wires.flatten :+ newInst)
- case other => other map lowerS(attachMap) map lowerE
+ case other => other.map(lowerS(attachMap)).map(lowerE)
}
- val modulesx = c.modules map { mod =>
+ val modulesx = c.modules.map { mod =>
val (modx, attachMap) = collectAndRemoveAttach(mod)
- modx map lowerS(attachMap)
+ modx.map(lowerS(attachMap))
}
c.copy(modules = modulesx)
}
diff --git a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
index 39c127de..e61780a4 100644
--- a/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
+++ b/src/main/scala/firrtl/passes/ZeroLengthVecs.scala
@@ -17,10 +17,7 @@ import firrtl.options.Dependency
*/
object ZeroLengthVecs extends Pass {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ResolveKinds),
- Dependency(InferTypes),
- Dependency(ExpandConnects) )
+ Seq(Dependency(PullMuxes), Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ExpandConnects))
override def invalidates(a: Transform) = false
@@ -28,8 +25,8 @@ object ZeroLengthVecs extends Pass {
// interval type with the type alone unless you declare a component
private def replaceWithDontCare(toReplace: Expression): Expression = {
val default = toReplace.tpe match {
- case UIntType(w) => UIntLiteral(0, w)
- case SIntType(w) => SIntLiteral(0, w)
+ case UIntType(w) => UIntLiteral(0, w)
+ case SIntType(w) => SIntLiteral(0, w)
case FixedType(w, p) => FixedLiteral(0, w, p)
case it: IntervalType =>
val zeroType = IntervalType(Closed(0), Closed(0), IntWidth(0))
@@ -40,11 +37,11 @@ object ZeroLengthVecs extends Pass {
}
private def zeroLenDerivedRefLike(expr: Expression): Boolean = (expr, expr.tpe) match {
- case (_, VectorType(_, 0)) => true
- case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e)
+ case (_, VectorType(_, 0)) => true
+ case (WSubIndex(e, _, _, _), _) => zeroLenDerivedRefLike(e)
case (WSubAccess(e, _, _, _), _) => zeroLenDerivedRefLike(e)
- case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e)
- case _ => false
+ case (WSubField(e, _, _, _), _) => zeroLenDerivedRefLike(e)
+ case _ => false
}
// The connects have all been lowered, so all aggregate-typed expressions are "grounded" by WSubField/WSubAccess/WSubIndex
@@ -52,13 +49,13 @@ object ZeroLengthVecs extends Pass {
private def dropZeroLenSubAccesses(expr: Expression): Expression = expr match {
case _: WSubIndex | _: WSubAccess | _: WSubField =>
if (zeroLenDerivedRefLike(expr)) replaceWithDontCare(expr) else expr
- case e => e map dropZeroLenSubAccesses
+ case e => e.map(dropZeroLenSubAccesses)
}
// Attach semantics: drop all zero-length-derived members of attach group, drop stmt if trivial
private def onStmt(stmt: Statement): Statement = stmt match {
case Connect(_, sink, _) if zeroLenDerivedRefLike(sink) => EmptyStmt
- case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt
+ case IsInvalid(_, sink) if zeroLenDerivedRefLike(sink) => EmptyStmt
case Attach(info, sinks) =>
val filtered = Attach(info, sinks.filterNot(zeroLenDerivedRefLike))
if (filtered.exprs.length < 2) EmptyStmt else filtered
diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala
index 56d66ef0..82321f95 100644
--- a/src/main/scala/firrtl/passes/ZeroWidth.scala
+++ b/src/main/scala/firrtl/passes/ZeroWidth.scala
@@ -11,12 +11,14 @@ import firrtl.options.Dependency
object ZeroWidth extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses),
- Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency[ExpandWhensAndCheck],
- Dependency(ConvertFixedToSInt) ) ++ firrtl.stage.Forms.Deduped
+ Seq(
+ Dependency(PullMuxes),
+ Dependency(ReplaceAccesses),
+ Dependency(ExpandConnects),
+ Dependency(RemoveAccesses),
+ Dependency[ExpandWhensAndCheck],
+ Dependency(ConvertFixedToSInt)
+ ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case InferTypes => true
@@ -24,30 +26,41 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
}
private def makeEmptyMemBundle(name: String): Field =
- Field(name, Flip, BundleType(Seq(
- Field("addr", Default, UIntType(IntWidth(0))),
- Field("en", Default, UIntType(IntWidth(0))),
- Field("clk", Default, UIntType(IntWidth(0))),
- Field("data", Flip, UIntType(IntWidth(0)))
- )))
+ Field(
+ name,
+ Flip,
+ BundleType(
+ Seq(
+ Field("addr", Default, UIntType(IntWidth(0))),
+ Field("en", Default, UIntType(IntWidth(0))),
+ Field("clk", Default, UIntType(IntWidth(0))),
+ Field("data", Flip, UIntType(IntWidth(0)))
+ )
+ )
+ )
private def onEmptyMemStmt(s: Statement): Statement = s match {
- case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => removeZero(tpe) match {
- case None =>
- DefWire(info, name, BundleType(
- rs.map(r => makeEmptyMemBundle(r)) ++
- ws.map(w => makeEmptyMemBundle(w)) ++
- rws.map(rw => makeEmptyMemBundle(rw))
- ))
- case Some(_) => d
- }
- case sx => sx map onEmptyMemStmt
+ case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) =>
+ removeZero(tpe) match {
+ case None =>
+ DefWire(
+ info,
+ name,
+ BundleType(
+ rs.map(r => makeEmptyMemBundle(r)) ++
+ ws.map(w => makeEmptyMemBundle(w)) ++
+ rws.map(rw => makeEmptyMemBundle(rw))
+ )
+ )
+ case Some(_) => d
+ }
+ case sx => sx.map(onEmptyMemStmt)
}
private def onModuleEmptyMemStmt(m: DefModule): DefModule = {
m match {
case ext: ExtModule => ext
- case in: Module => in.copy(body = onEmptyMemStmt(in.body))
+ case in: Module => in.copy(body = onEmptyMemStmt(in.body))
}
}
@@ -59,20 +72,20 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
* This replaces memories with a DefWire() bundle that contains the address, en,
* clk, and data fields implemented as zero width wires. Running the rest of the ZeroWidth
* transform will remove these dangling references properly.
- *
*/
def executeEmptyMemStmt(state: CircuitState): CircuitState = {
val c = state.circuit
- val result = c.copy(modules = c.modules map onModuleEmptyMemStmt)
+ val result = c.copy(modules = c.modules.map(onModuleEmptyMemStmt))
state.copy(circuit = result)
}
// This is slightly different and specialized version of create_exps, TODO unify?
private def findRemovable(expr: => Expression, tpe: Type): Seq[Expression] = tpe match {
- case GroundType(width) => width match {
- case IntWidth(ZERO) => List(expr)
- case _ => List.empty
- }
+ case GroundType(width) =>
+ width match {
+ case IntWidth(ZERO) => List(expr)
+ case _ => List.empty
+ }
case BundleType(fields) =>
if (fields.isEmpty) List(expr)
else fields.flatMap(f => findRemovable(WSubField(expr, f.name, f.tpe, SourceFlow), f.tpe))
@@ -95,7 +108,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
t
}
x match {
- case s: Statement => s map onType(s.name)
+ case s: Statement => s.map(onType(s.name))
case Port(_, name, _, t) => onType(name)(t)
}
removedNames
@@ -103,14 +116,14 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
private[passes] def removeZero(t: Type): Option[Type] = t match {
case GroundType(IntWidth(ZERO)) => None
case BundleType(fields) =>
- fields map (f => (f, removeZero(f.tpe))) collect {
+ fields.map(f => (f, removeZero(f.tpe))).collect {
case (Field(name, flip, _), Some(t)) => Field(name, flip, t)
} match {
case Nil => None
case seq => Some(BundleType(seq))
}
- case VectorType(t, size) => removeZero(t) map (VectorType(_, size))
- case x => Some(x)
+ case VectorType(t, size) => removeZero(t).map(VectorType(_, size))
+ case x => Some(x)
}
private def onExp(e: Expression): Expression = e match {
case DoPrim(Cat, args, consts, tpe) =>
@@ -118,26 +131,27 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
x.tpe match {
case UIntType(IntWidth(ZERO)) => Seq.empty[Expression]
case SIntType(IntWidth(ZERO)) => Seq.empty[Expression]
- case other => Seq(x)
+ case other => Seq(x)
}
}
nonZeros match {
- case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1)))
+ case Nil => UIntLiteral(ZERO, IntWidth(BigInt(1)))
case Seq(x) => x
- case seq => DoPrim(Cat, seq, consts, tpe) map onExp
+ case seq => DoPrim(Cat, seq, consts, tpe).map(onExp)
}
case DoPrim(Andr, Seq(x), _, _) if (bitWidth(x.tpe) == 0) => UIntLiteral(1) // nothing false
- case other => other.tpe match {
- case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1)))
- case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1)))
- case _ => e map onExp
- }
+ case other =>
+ other.tpe match {
+ case UIntType(IntWidth(ZERO)) => UIntLiteral(ZERO, IntWidth(BigInt(1)))
+ case SIntType(IntWidth(ZERO)) => SIntLiteral(ZERO, IntWidth(BigInt(1)))
+ case _ => e.map(onExp)
+ }
}
private def onStmt(renames: RenameMap)(s: Statement): Statement = s match {
case d @ DefWire(info, name, tpe) =>
renames.delete(getRemoved(d))
removeZero(tpe) match {
- case None => EmptyStmt
+ case None => EmptyStmt
case Some(t) => DefWire(info, name, t)
}
case d @ DefRegister(info, name, tpe, clock, reset, init) =>
@@ -145,7 +159,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
removeZero(tpe) match {
case None => EmptyStmt
case Some(t) =>
- DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init))
+ DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init))
}
case d: DefMemory =>
renames.delete(getRemoved(d))
@@ -154,25 +168,28 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
Utils.throwInternalError(s"private pass ZeroWidthMemRemove should have removed this memory: $d")
case Some(t) => d.copy(dataType = t)
}
- case Connect(info, loc, exp) => removeZero(loc.tpe) match {
- case None => EmptyStmt
- case Some(t) => Connect(info, loc, onExp(exp))
- }
- case IsInvalid(info, exp) => removeZero(exp.tpe) match {
- case None => EmptyStmt
- case Some(t) => IsInvalid(info, onExp(exp))
- }
- case DefNode(info, name, value) => removeZero(value.tpe) match {
- case None => EmptyStmt
- case Some(t) => DefNode(info, name, onExp(value))
- }
- case sx => sx map onStmt(renames) map onExp
+ case Connect(info, loc, exp) =>
+ removeZero(loc.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => Connect(info, loc, onExp(exp))
+ }
+ case IsInvalid(info, exp) =>
+ removeZero(exp.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => IsInvalid(info, onExp(exp))
+ }
+ case DefNode(info, name, value) =>
+ removeZero(value.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => DefNode(info, name, onExp(value))
+ }
+ case sx => sx.map(onStmt(renames)).map(onExp)
}
private def onModule(renames: RenameMap)(m: DefModule): DefModule = {
renames.setModule(m.name)
// For each port, record deleted subcomponents
- m.ports.foreach{p => renames.delete(getRemoved(p))}
- val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap {
+ m.ports.foreach { p => renames.delete(getRemoved(p)) }
+ val ports = m.ports.map(p => (p, removeZero(p.tpe))).flatMap {
case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t))
case (Port(_, name, _, _), None) =>
renames.delete(name)
@@ -180,7 +197,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
}
m match {
case ext: ExtModule => ext.copy(ports = ports)
- case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body))
+ case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body))
}
}
def execute(state: CircuitState): CircuitState = {
@@ -189,7 +206,7 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
val c = InferTypes.run(executeEmptyMemStmt(state).circuit)
val renames = RenameMap()
renames.setCircuit(c.main)
- val result = c.copy(modules = c.modules map onModule(renames))
+ val result = c.copy(modules = c.modules.map(onModule(renames)))
CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockList.scala b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
index c2323d4c..bfc03b51 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockList.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockList.scala
@@ -13,8 +13,8 @@ import Utils._
import memlib.AnalysisUtils._
/** Starting with a top module, determine the clock origins of each child instance.
- * Write the result to writer.
- */
+ * Write the result to writer.
+ */
class ClockList(top: String, writer: Writer) extends Pass {
def run(c: Circuit): Circuit = {
// Build useful datastructures
@@ -29,7 +29,7 @@ class ClockList(top: String, writer: Writer) extends Pass {
// Clock sources must be blackbox outputs and top's clock
val partialSourceList = getSourceList(moduleMap)(lineages)
- val sourceList = partialSourceList ++ moduleMap(top).ports.collect{ case Port(i, n, Input, ClockType) => n }
+ val sourceList = partialSourceList ++ moduleMap(top).ports.collect { case Port(i, n, Input, ClockType) => n }
writer.append(s"Sourcelist: $sourceList \n")
// Remove everything from the circuit, unless it has a clock type
@@ -37,8 +37,9 @@ class ClockList(top: String, writer: Writer) extends Pass {
val onlyClockCircuit = RemoveAllButClocks.run(c)
// Inline the clock-only circuit up to the specified top module
- val modulesToInline = (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet
- val inlineTransform = new InlineInstances{ override val inlineDelim = "$" }
+ val modulesToInline =
+ (c.modules.collect { case Module(_, n, _, _) if n != top => ModuleName(n, CircuitName(c.main)) }).toSet
+ val inlineTransform = new InlineInstances { override val inlineDelim = "$" }
val inlinedCircuit = inlineTransform.run(onlyClockCircuit, modulesToInline, Set(), Seq()).circuit
val topModule = inlinedCircuit.modules.find(_.name == top).getOrElse(throwInternalError("no top module"))
@@ -49,13 +50,14 @@ class ClockList(top: String, writer: Writer) extends Pass {
val origins = getOrigins(connects, "", moduleMap)(lineages)
// If the clock origin is contained in the source list, label good (otherwise bad)
- origins.foreach { case (instance, origin) =>
- val sep = if(instance == "") "" else "."
- if(!sourceList.contains(origin.replace('.','$'))){
- outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n")
- } else {
- outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n")
- }
+ origins.foreach {
+ case (instance, origin) =>
+ val sep = if (instance == "") "" else "."
+ if (!sourceList.contains(origin.replace('.', '$'))) {
+ outputBuffer.append(s"Bad Origin of $instance${sep}clock is $origin\n")
+ } else {
+ outputBuffer.append(s"Good Origin of $instance${sep}clock is $origin\n")
+ }
}
// Write to output file
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
index e6617857..468ba905 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListTransform.scala
@@ -12,8 +12,7 @@ import memlib._
import firrtl.options.{RegisteredTransform, ShellOption}
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
-case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends
- SingleTargetAnnotation[ModuleName] {
+case class ClockListAnnotation(target: ModuleName, outputConfig: String) extends SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = ClockListAnnotation(n, outputConfig)
}
@@ -44,7 +43,7 @@ Usage:
)
passOptions.get(InputConfigFileName) match {
case Some(x) => error("Unneeded input config file name!" + usage)
- case None =>
+ case None =>
}
val target = ModuleName(passModule, CircuitName(passCircuit))
ClockListAnnotation(target, outputConfig)
@@ -53,18 +52,20 @@ Usage:
class ClockListTransform extends Transform with DependencyAPIMigration with RegisteredTransform {
- override def prerequisites = Forms.LowForm
- override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Forms.LowEmitters
+ override def prerequisites = Forms.LowForm
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Forms.LowEmitters
val options = Seq(
new ShellOption[String](
longOption = "list-clocks",
- toAnnotationSeq = (a: String) => Seq( passes.clocklist.ClockListAnnotation.parse(a),
- RunFirrtlTransformAnnotation(new ClockListTransform) ),
+ toAnnotationSeq = (a: String) =>
+ Seq(passes.clocklist.ClockListAnnotation.parse(a), RunFirrtlTransformAnnotation(new ClockListTransform)),
helpText = "List which signal drives each clock of every descendent of specified modules",
shortOption = Some("clks"),
- helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>") ) )
+ helpValueName = Some("-c:<circuit>:-m:<module>:-o:<filename>")
+ )
+ )
def passSeq(top: String, writer: Writer): Seq[Pass] =
Seq(new ClockList(top, writer))
diff --git a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
index b77629fc..00e07588 100644
--- a/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
+++ b/src/main/scala/firrtl/passes/clocklist/ClockListUtils.scala
@@ -10,45 +10,56 @@ import Utils._
import memlib.AnalysisUtils._
object ClockListUtils {
+
/** Returns a list of clock outputs from instances of external modules
- */
+ */
def getSourceList(moduleMap: Map[String, DefModule])(lin: Lineage): Seq[String] = {
- val s = lin.foldLeft(Seq[String]()){case (sL, (i, l)) =>
- val sLx = getSourceList(moduleMap)(l)
- val sLxx = sLx map (i + "$" + _)
- sL ++ sLxx
+ val s = lin.foldLeft(Seq[String]()) {
+ case (sL, (i, l)) =>
+ val sLx = getSourceList(moduleMap)(l)
+ val sLxx = sLx.map(i + "$" + _)
+ sL ++ sLxx
}
val sourceList = moduleMap(lin.name) match {
case ExtModule(i, n, ports, dn, p) =>
- val portExps = ports.flatMap{p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction)))}
+ val portExps = ports.flatMap { p => create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction))) }
portExps.filter(e => (e.tpe == ClockType) && (flow(e) == SinkFlow)).map(_.serialize)
case _ => Nil
}
val sx = sourceList ++ s
sx
}
+
/** Returns a map from instance name to its clock origin.
- * Child instances are not included if they share the same clock as their parent
- */
- def getOrigins(connects: Connects, me: String, moduleMap: Map[String, DefModule])(lin: Lineage): Map[String, String] = {
- val sep = if(me == "") "" else "$"
+ * Child instances are not included if they share the same clock as their parent
+ */
+ def getOrigins(
+ connects: Connects,
+ me: String,
+ moduleMap: Map[String, DefModule]
+ )(lin: Lineage
+ ): Map[String, String] = {
+ val sep = if (me == "") "" else "$"
// Get origins from all children
- val childrenOrigins = lin.foldLeft(Map[String, String]()){case (o, (i, l)) =>
- o ++ getOrigins(connects, me + sep + i, moduleMap)(l)
+ val childrenOrigins = lin.foldLeft(Map[String, String]()) {
+ case (o, (i, l)) =>
+ o ++ getOrigins(connects, me + sep + i, moduleMap)(l)
}
// If I have a clock, get it
val clockOpt = moduleMap(lin.name) match {
- case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" }
+ case Module(i, n, ports, b) => ports.collectFirst { case p if p.name == "clock" => me + sep + "clock" }
case ExtModule(i, n, ports, dn, p) => None
}
// Return new origins with direct children removed, if they match my clock
clockOpt match {
case Some(clock) =>
val myOrigin = getOrigin(connects, clock).serialize
- childrenOrigins.foldLeft(Map(me -> myOrigin)) { case (o, (childInstance, childOrigin)) =>
- val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance }
- // If direct child shares my origin, omit it
- if(childOrigin == myOrigin && childrenInstances.contains(childInstance)) o else o + (childInstance -> childOrigin)
+ childrenOrigins.foldLeft(Map(me -> myOrigin)) {
+ case (o, (childInstance, childOrigin)) =>
+ val childrenInstances = lin.children.map { case (instance, _) => me + sep + instance }
+ // If direct child shares my origin, omit it
+ if (childOrigin == myOrigin && childrenInstances.contains(childInstance)) o
+ else o + (childInstance -> childOrigin)
}
case None => childrenOrigins
}
diff --git a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
index 6eb8c138..d72bc293 100644
--- a/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
+++ b/src/main/scala/firrtl/passes/clocklist/RemoveAllButClocks.scala
@@ -9,22 +9,22 @@ import Utils._
import Mappers._
/** Remove all statements and ports (except instances/whens/blocks) whose
- * expressions do not relate to ground types.
- */
+ * expressions do not relate to ground types.
+ */
object RemoveAllButClocks extends Pass {
- def onStmt(s: Statement): Statement = (s map onStmt) match {
- case DefWire(i, n, ClockType) => s
+ def onStmt(s: Statement): Statement = (s.map(onStmt)) match {
+ case DefWire(i, n, ClockType) => s
case DefNode(i, n, value) if value.tpe == ClockType => s
- case Connect(i, l, r) if l.tpe == ClockType => s
- case sx: WDefInstance => sx
- case sx: DefInstance => sx
- case sx: Block => sx
+ case Connect(i, l, r) if l.tpe == ClockType => s
+ case sx: WDefInstance => sx
+ case sx: DefInstance => sx
+ case sx: Block => sx
case sx: Conditionally => sx
case _ => EmptyStmt
}
def onModule(m: DefModule): DefModule = m match {
- case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b)))
+ case Module(i, n, ps, b) => Module(i, n, ps.filter(_.tpe == ClockType), squashEmpty(onStmt(b)))
case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps.filter(_.tpe == ClockType), dn, p)
}
- def run(c: Circuit): Circuit = c.copy(modules = c.modules map onModule)
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(onModule))
}
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
index 14bd9e44..d237c36a 100644
--- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
+++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
@@ -19,8 +19,9 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform
import CustomYAMLProtocol._
val configs = r.parse[Config]
val oldAnnos = state.annotations
- val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) =>
- (annos, pins :+ config.pin.name)
+ val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) {
+ case ((annos, pins), config) =>
+ (annos, pins :+ config.pin.name)
}
state.copy(annotations = PinAnnotation(pins.toSeq) +: as)
}
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 4847a698..e290633e 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -10,12 +10,11 @@ import firrtl.PrimOps._
import firrtl.Utils.{one, zero, BoolType}
import firrtl.options.{HasShellOptions, ShellOption}
import MemPortUtils.memPortField
-import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin}
+import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects}
import WrappedExpression.weq
import annotations._
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
-
case object InferReadWriteAnnotation extends NoTargetAnnotation
// This pass examine the enable signals of the read & write ports of memories
@@ -40,12 +39,13 @@ object InferReadWritePass extends Pass {
getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval)
// Visit each term of AND operation
case DoPrim(op, args, consts, tpe) if op == And =>
- e +: (args flatMap getProductTerms(connects))
+ e +: (args.flatMap(getProductTerms(connects)))
// Visit connected nodes to references
- case _: WRef | _: WSubField | _: WSubIndex => connects get e match {
- case None => Seq(e)
- case Some(ex) => e +: getProductTerms(connects)(ex)
- }
+ case _: WRef | _: WSubField | _: WSubIndex =>
+ connects.get(e) match {
+ case None => Seq(e)
+ case Some(ex) => e +: getProductTerms(connects)(ex)
+ }
// Otherwise just return itself
case _ => Seq(e)
}
@@ -58,96 +58,103 @@ object InferReadWritePass extends Pass {
// b ?= Eq(a, 0) or b ?= Eq(0, a)
case (_, DoPrim(Eq, args, _, _)) =>
weq(args.head, a) && weq(args(1), zero) ||
- weq(args(1), a) && weq(args.head, zero)
+ weq(args(1), a) && weq(args.head, zero)
// a ?= Eq(b, 0) or b ?= Eq(0, a)
case (DoPrim(Eq, args, _, _), _) =>
weq(args.head, b) && weq(args(1), zero) ||
- weq(args(1), b) && weq(args.head, zero)
+ weq(args(1), b) && weq(args.head, zero)
case _ => false
}
-
def replaceExp(repl: Netlist)(e: Expression): Expression =
- e map replaceExp(repl) match {
- case ex: WSubField => repl getOrElse (ex.serialize, ex)
+ e.map(replaceExp(repl)) match {
+ case ex: WSubField => repl.getOrElse(ex.serialize, ex)
case ex => ex
}
def replaceStmt(repl: Netlist)(s: Statement): Statement =
- s map replaceStmt(repl) map replaceExp(repl) match {
+ s.map(replaceStmt(repl)).map(replaceExp(repl)) match {
case Connect(_, EmptyExpression, _) => EmptyStmt
- case sx => sx
+ case sx => sx
}
- def inferReadWriteStmt(connects: Connects,
- repl: Netlist,
- stmts: Statements)
- (s: Statement): Statement = s match {
+ def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match {
// infer readwrite ports only for non combinational memories
case mem: DefMemory if mem.readLatency > 0 =>
val readers = new PortSet
val writers = new PortSet
val readwriters = collection.mutable.ArrayBuffer[String]()
val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
- for (w <- mem.writers ; r <- mem.readers) {
+ for {
+ w <- mem.writers
+ r <- mem.readers
+ } {
val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en"))
val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en"))
- val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b)))
+ val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b)))
val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) {
- val rw = namespace newName "rw"
+ val rw = namespace.newName("rw")
val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
readers += r
writers += w
- repl(memPortField(mem, r, "clk")) = EmptyExpression
- repl(memPortField(mem, r, "en")) = EmptyExpression
+ repl(memPortField(mem, r, "clk")) = EmptyExpression
+ repl(memPortField(mem, r, "en")) = EmptyExpression
repl(memPortField(mem, r, "addr")) = EmptyExpression
repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
- repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = EmptyExpression
+ repl(memPortField(mem, w, "clk")) = EmptyExpression
+ repl(memPortField(mem, w, "en")) = EmptyExpression
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get)
stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
- stmts += Connect(NoInfo, WSubField(rwExp, "en"),
- DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
- connects(memPortField(mem, w, "en"))), Nil, BoolType))
- stmts += Connect(NoInfo, WSubField(rwExp, "addr"),
- Mux(connects(memPortField(mem, w, "en")),
- connects(memPortField(mem, w, "addr")),
- connects(memPortField(mem, r, "addr")), UnknownType))
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "en"),
+ DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType)
+ )
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "addr"),
+ Mux(
+ connects(memPortField(mem, w, "en")),
+ connects(memPortField(mem, w, "addr")),
+ connects(memPortField(mem, r, "addr")),
+ UnknownType
+ )
+ )
}
}
- if (readwriters.isEmpty) mem else mem copy (
- readers = mem.readers filterNot readers,
- writers = mem.writers filterNot writers,
- readwriters = mem.readwriters ++ readwriters)
- case sx => sx map inferReadWriteStmt(connects, repl, stmts)
+ if (readwriters.isEmpty) mem
+ else
+ mem.copy(
+ readers = mem.readers.filterNot(readers),
+ writers = mem.writers.filterNot(writers),
+ readwriters = mem.readwriters ++ readwriters
+ )
+ case sx => sx.map(inferReadWriteStmt(connects, repl, stmts))
}
def inferReadWrite(m: DefModule) = {
val connects = getConnects(m)
val repl = new Netlist
val stmts = new Statements
- (m map inferReadWriteStmt(connects, repl, stmts)
- map replaceStmt(repl)) match {
+ (m.map(inferReadWriteStmt(connects, repl, stmts))
+ .map(replaceStmt(repl))) match {
case m: ExtModule => m
- case m: Module => m copy (body = Block(m.body +: stmts.toSeq))
+ case m: Module => m.copy(body = Block(m.body +: stmts.toSeq))
}
}
- def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite))
}
// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
// To use this transform, circuit name should be annotated with its TransId.
-class InferReadWrite extends Transform
- with DependencyAPIMigration
- with SeqTransformBased
- with HasShellOptions {
+class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions {
override def prerequisites = Forms.MidForm
override def optionalPrerequisites = Seq.empty
@@ -159,7 +166,9 @@ class InferReadWrite extends Transform
longOption = "infer-rw",
toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)),
helpText = "Enable read/write port inference for memories",
- shortOption = Some("firw") ) )
+ shortOption = Some("firw")
+ )
+ )
def transforms = Seq(
InferReadWritePass,
diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala
index 3809c47c..871a1093 100644
--- a/src/main/scala/firrtl/passes/memlib/MemConf.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala
@@ -3,7 +3,6 @@
package firrtl.passes
package memlib
-
sealed abstract class MemPort(val name: String) { override def toString = name }
case object ReadPort extends MemPort("read")
@@ -19,22 +18,27 @@ object MemPort {
def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s)
def fromString(s: String): Map[MemPort, Int] = {
- s.split(",").toSeq.map(MemPort.apply).map(_ match {
- case Some(x) => x
- case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
- }).groupBy(identity).mapValues(_.size).toMap
+ s.split(",")
+ .toSeq
+ .map(MemPort.apply)
+ .map(_ match {
+ case Some(x) => x
+ case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
+ })
+ .groupBy(identity)
+ .mapValues(_.size)
+ .toMap
}
}
case class MemConf(
- name: String,
- depth: BigInt,
- width: Int,
- ports: Map[MemPort, Int],
- maskGranularity: Option[Int]
-) {
+ name: String,
+ depth: BigInt,
+ width: Int,
+ ports: Map[MemPort, Int],
+ maskGranularity: Option[Int]) {
- private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",")
+ private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",")
private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("")
// Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes
@@ -49,21 +53,34 @@ object MemConf {
val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r
def fromString(s: String): Seq[MemConf] = {
- s.split("\n").toSeq.map(_ match {
- case MemConf.regex(name, depth, width, ports, maskGran) => Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt)))
- case "" => None
- case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
- }).flatten
+ s.split("\n")
+ .toSeq
+ .map(_ match {
+ case MemConf.regex(name, depth, width, ports, maskGran) =>
+ Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt)))
+ case "" => None
+ case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
+ })
+ .flatten
}
- def apply(name: String, depth: BigInt, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = {
+ def apply(
+ name: String,
+ depth: BigInt,
+ width: Int,
+ readPorts: Int,
+ writePorts: Int,
+ readWritePorts: Int,
+ maskGranularity: Option[Int]
+ ): MemConf = {
val ports: Seq[(MemPort, Int)] = (if (maskGranularity.isEmpty) {
- (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++
- (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts))
- } else {
- (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++
- (if (readWritePorts == 0) Seq() else Seq(MaskedReadWritePort -> readWritePorts))
- }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts))
+ (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts))
+ } else {
+ (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Seq()
+ else Seq(MaskedReadWritePort -> readWritePorts))
+ }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts))
new MemConf(name, depth, width, ports.toMap, maskGranularity)
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
index 3731ea86..c8cd3e8d 100644
--- a/src/main/scala/firrtl/passes/memlib/MemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -19,38 +19,38 @@ object DefAnnotatedMemory {
m.readwriters,
m.readUnderWrite,
None, // mask granularity annotation
- None // No reference yet to another memory
+ None // No reference yet to another memory
)
}
}
case class DefAnnotatedMemory(
- info: Info,
- name: String,
- dataType: Type,
- depth: BigInt,
- writeLatency: Int,
- readLatency: Int,
- readers: Seq[String],
- writers: Seq[String],
- readwriters: Seq[String],
- readUnderWrite: ReadUnderWrite.Value,
- maskGran: Option[BigInt],
- memRef: Option[(String, String)] /* (Module, Mem) */
- //pins: Seq[Pin],
- ) extends Statement with IsDeclaration {
+ info: Info,
+ name: String,
+ dataType: Type,
+ depth: BigInt,
+ writeLatency: Int,
+ readLatency: Int,
+ readers: Seq[String],
+ writers: Seq[String],
+ readwriters: Seq[String],
+ readUnderWrite: ReadUnderWrite.Value,
+ maskGran: Option[BigInt],
+ memRef: Option[(String, String)] /* (Module, Mem) */
+ //pins: Seq[Pin],
+) extends Statement
+ with IsDeclaration {
override def serialize: String = this.toMem.serialize
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def toMem = DefMemory(info, name, dataType, depth,
- writeLatency, readLatency, readers, writers,
- readwriters, readUnderWrite)
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(dataType)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def toMem =
+ DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, readwriters, readUnderWrite)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(dataType)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
index f0c9ebf4..1db132f7 100644
--- a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
@@ -7,8 +7,7 @@ import firrtl.options.{RegisteredLibrary, ShellOption}
class MemLibOptions extends RegisteredLibrary {
val name: String = "MemLib Options"
- val options: Seq[ShellOption[_]] = Seq( new InferReadWrite,
- new ReplSeqMem )
+ val options: Seq[ShellOption[_]] = Seq(new InferReadWrite, new ReplSeqMem)
.flatMap(_.options)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
index b6a9a23d..f153fa2b 100644
--- a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
@@ -11,12 +11,12 @@ import MemPortUtils.{MemPortMap}
object MemTransformUtils {
/** Replaces references to old memory port names with new memory port names
- */
+ */
def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
//TODO(izraelevitz): check speed
def updateRef(e: Expression): Expression = {
- val ex = e map updateRef
- repl getOrElse (ex.serialize, ex)
+ val ex = e.map(updateRef)
+ repl.getOrElse(ex.serialize, ex)
}
def hasEmptyExpr(stmt: Statement): Boolean = {
@@ -24,16 +24,16 @@ object MemTransformUtils {
def testEmptyExpr(e: Expression): Expression = {
e match {
case EmptyExpression => foundEmpty = true
- case _ =>
+ case _ =>
}
- e map testEmptyExpr // map must return; no foreach
+ e.map(testEmptyExpr) // map must return; no foreach
}
- stmt map testEmptyExpr
+ stmt.map(testEmptyExpr)
foundEmpty
}
def updateStmtRefs(s: Statement): Statement =
- s map updateStmtRefs map updateRef match {
+ s.map(updateStmtRefs).map(updateRef) match {
case c: Connect if hasEmptyExpr(c) => EmptyStmt
case s => s
}
@@ -42,6 +42,6 @@ object MemTransformUtils {
}
def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem)
- def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField =
+ def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField =
MemPortUtils.memPortField(s.toMem, p, f)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
index 69c6b284..f325c0ba 100644
--- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
@@ -7,19 +7,19 @@ import firrtl.ir._
import firrtl.Utils._
/** Given a mask, return a bitmask corresponding to the desired datatype.
- * Requirements:
- * - The mask type and datatype must be equivalent, except any ground type in
- * datatype must be matched by a 1-bit wide UIntType.
- * - The mask must be a reference, subfield, or subindex
- * The bitmask is a series of concatenations of the single mask bit over the
- * length of the corresponding ground type, e.g.:
- *{{{
- * wire mask: {x: UInt<1>, y: UInt<1>}
- * wire data: {x: UInt<2>, y: SInt<2>}
- * // this would return:
- * cat(cat(mask.x, mask.x), cat(mask.y, mask.y))
- * }}}
- */
+ * Requirements:
+ * - The mask type and datatype must be equivalent, except any ground type in
+ * datatype must be matched by a 1-bit wide UIntType.
+ * - The mask must be a reference, subfield, or subindex
+ * The bitmask is a series of concatenations of the single mask bit over the
+ * length of the corresponding ground type, e.g.:
+ * {{{
+ * wire mask: {x: UInt<1>, y: UInt<1>}
+ * wire data: {x: UInt<2>, y: SInt<2>}
+ * // this would return:
+ * cat(cat(mask.x, mask.x), cat(mask.y, mask.y))
+ * }}}
+ */
object toBitMask {
def apply(mask: Expression, dataType: Type): Expression = mask match {
case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType)
@@ -28,12 +28,13 @@ object toBitMask {
private def hiermask(mask: Expression, dataType: Type): Expression =
(mask.tpe, dataType) match {
case (mt: VectorType, dt: VectorType) =>
- seqCat((0 until mt.size).reverse map { i =>
+ seqCat((0 until mt.size).reverse.map { i =>
hiermask(WSubIndex(mask, i, mt.tpe, UnknownFlow), dt.tpe)
})
case (mt: BundleType, dt: BundleType) =>
- seqCat((mt.fields zip dt.fields) map { case (mf, df) =>
- hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe)
+ seqCat((mt.fields.zip(dt.fields)).map {
+ case (mf, df) =>
+ hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe)
})
case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) =>
seqCat(List.fill(bitWidth(dt).intValue)(mask))
@@ -44,7 +45,7 @@ object toBitMask {
object createMask {
def apply(dt: Type): Type = dt match {
case t: VectorType => VectorType(apply(t.tpe), t.size)
- case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe))))
+ case t: BundleType => BundleType(t.fields.map(f => f.copy(tpe = apply(f.tpe))))
case GroundType(w) if w == IntWidth(0) => UIntType(IntWidth(0))
case t: GroundType => BoolType
}
@@ -56,27 +57,33 @@ object MemPortUtils {
type Modules = collection.mutable.ArrayBuffer[DefModule]
def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq(
- Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))),
+ Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1).max(1)))),
Field("en", Default, BoolType),
Field("clk", Default, ClockType)
)
// Todo: merge it with memToBundle
def memType(mem: DefMemory): BundleType = {
- val rType = BundleType(defaultPortSeq(mem) :+
- Field("data", Flip, mem.dataType))
- val wType = BundleType(defaultPortSeq(mem) ++ Seq(
- Field("data", Default, mem.dataType),
- Field("mask", Default, createMask(mem.dataType))))
- val rwType = BundleType(defaultPortSeq(mem) ++ Seq(
- Field("rdata", Flip, mem.dataType),
- Field("wmode", Default, BoolType),
- Field("wdata", Default, mem.dataType),
- Field("wmask", Default, createMask(mem.dataType))))
+ val rType = BundleType(
+ defaultPortSeq(mem) :+
+ Field("data", Flip, mem.dataType)
+ )
+ val wType = BundleType(
+ defaultPortSeq(mem) ++ Seq(Field("data", Default, mem.dataType), Field("mask", Default, createMask(mem.dataType)))
+ )
+ val rwType = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("rdata", Flip, mem.dataType),
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, mem.dataType),
+ Field("wmask", Default, createMask(mem.dataType))
+ )
+ )
BundleType(
- (mem.readers map (Field(_, Flip, rType))) ++
- (mem.writers map (Field(_, Flip, wType))) ++
- (mem.readwriters map (Field(_, Flip, rwType))))
+ (mem.readers.map(Field(_, Flip, rType))) ++
+ (mem.writers.map(Field(_, Flip, wType))) ++
+ (mem.readwriters.map(Field(_, Flip, rwType)))
+ )
}
def memPortField(s: DefMemory, p: String, f: String): WSubField = {
diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
index c51a0adc..30529119 100644
--- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
+++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
@@ -9,27 +9,27 @@ import firrtl.Mappers._
import MemPortUtils._
import MemTransformUtils._
-
/** Changes memory port names to standard port names (i.e. RW0 instead T_408)
- */
+ */
object RenameAnnotatedMemoryPorts extends Pass {
+
/** Renames memory ports to a standard naming scheme:
- * - R0, R1, ... for each read port
- * - W0, W1, ... for each write port
- * - RW0, RW1, ... for each readwrite port
- */
+ * - R0, R1, ... for each read port
+ * - W0, W1, ... for each write port
+ * - RW0, RW1, ... for each readwrite port
+ */
def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = {
- val rports = m.readers.indices map (i => s"R$i")
- val wports = m.writers.indices map (i => s"W$i")
- val rwports = m.readwriters.indices map (i => s"RW$i")
- m copy (readers = rports, writers = wports, readwriters = rwports)
+ val rports = m.readers.indices.map(i => s"R$i")
+ val wports = m.writers.indices.map(i => s"W$i")
+ val rwports = m.readwriters.indices.map(i => s"RW$i")
+ m.copy(readers = rports, writers = wports, readwriters = rwports)
}
/** Maps the serialized form of all memory port field names to the
- * corresponding new memory port field Expression.
- * E.g.:
- * - ("m.read.addr") becomes (m.R0.addr)
- */
+ * corresponding new memory port field Expression.
+ * E.g.:
+ * - ("m.read.addr") becomes (m.R0.addr)
+ */
def getMemPortMap(m: DefAnnotatedMemory, memPortMap: MemPortMap): Unit = {
val defaultFields = Seq("addr", "en", "clk")
val rFields = defaultFields :+ "data"
@@ -37,7 +37,10 @@ object RenameAnnotatedMemoryPorts extends Pass {
val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")
def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit =
- for ((p, i) <- ports.zipWithIndex; f <- fields) {
+ for {
+ (p, i) <- ports.zipWithIndex
+ f <- fields
+ } {
val newPort = WSubField(WRef(m.name), newPortKind + i)
val field = WSubField(newPort, f)
memPortMap(s"${m.name}.$p.$f") = field
@@ -55,16 +58,16 @@ object RenameAnnotatedMemoryPorts extends Pass {
val updatedMem = createMemProto(m)
getMemPortMap(m, memPortMap)
updatedMem
- case s => s map updateMemStmts(memPortMap)
+ case s => s.map(updateMemStmts(memPortMap))
}
/** Replaces candidate memories and their references with standard port names
- */
+ */
def updateMemMods(m: DefModule) = {
val memPortMap = new MemPortMap
- (m map updateMemStmts(memPortMap)
- map updateStmtRefs(memPortMap))
+ (m.map(updateMemStmts(memPortMap))
+ .map(updateStmtRefs(memPortMap)))
}
- def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(updateMemMods))
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
index bfbc163a..fc381e88 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -13,7 +13,6 @@ import firrtl.annotations._
import firrtl.stage.Forms
import wiring._
-
/** Annotates the name of the pins to add for WiringTransform */
case class PinAnnotation(pins: Seq[String]) extends NoTargetAnnotation
@@ -35,14 +34,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
/** Return true if mask granularity is per bit, false if per byte or unspecified
*/
private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match {
- case None => false
+ case None => false
case Some(v) => v == 1
}
private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)
+ )
private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))
+ )
/** Catch incorrect memory instantiations when there are masked memories with unsupported aggregate types.
*
@@ -82,7 +83,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
)
private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
(defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match {
- case None => Nil
+ case None => Nil
case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType)))
case Some(_) => {
checkMaskDatatype(mem)
@@ -111,7 +112,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
Field("wdata", Default, flattenType(mem.dataType)),
Field("rdata", Flip, flattenType(mem.dataType))
) ++ (mem.maskGran match {
- case None => Nil
+ case None => Nil
case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType)))
case Some(_) => {
checkMaskDatatype(mem)
@@ -122,32 +123,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
def memToBundle(s: DefAnnotatedMemory) = BundleType(
s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
+ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))
+ )
def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType(
s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
+ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))
+ )
/** Creates a wrapper module and external module to replace a candidate memory
- * The wrapper module has the same type as the memory it replaces
- * The external module
- */
+ * The wrapper module has the same type as the memory it replaces
+ * The external module
+ */
def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = {
assert(m.dataType != UnknownType)
val wrapperIoType = memToBundle(m)
- val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val wrapperIoPorts = wrapperIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
// Creates a type with the write/readwrite masks omitted if necessary
val bbIoType = memToFlattenBundle(m)
- val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbIoPorts = bbIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
val bbRef = WRef(m.name, bbIoType)
val hasMask = m.maskGran.isDefined
val fillMask = getFillWMask(m)
def portRef(p: String) = WRef(p, field_type(wrapperIoType, p))
val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
- (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++
- (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++
- (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask)))
+ (m.readers.flatMap(r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++
+ (m.writers.flatMap(w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++
+ (m.readwriters.flatMap(rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask)))
val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty)
// TODO: Annotate? -- use actual annotation map
@@ -160,16 +163,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
// TODO(shunshou): get rid of copy pasta
// Connects the clk, en, and addr fields from the wrapperPort to the bbPort
def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] =
- Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))
+ Seq("clk", "en", "addr").map(f => connectFields(bbPort, f, wrapperPort, f))
// Generates mask bits (concatenates an aggregate to ground type)
// depending on mask granularity (# bits = data width / mask granularity)
def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression =
if (fillMask) toBitMask(mask, dataType) else toBits(mask)
- def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
+ def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
defaultConnects(wrapperPort, bbPort) :+
- fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data"))
+ fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data"))
def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
val wrapperData = WSubField(wrapperPort, "data")
@@ -177,11 +180,12 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData))
hasMask match {
case false => defaultSeq
- case true => defaultSeq :+ Connect(
- NoInfo,
- WSubField(bbPort, "mask"),
- maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
- )
+ case true =>
+ defaultSeq :+ Connect(
+ NoInfo,
+ WSubField(bbPort, "mask"),
+ maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
+ )
}
}
@@ -190,61 +194,67 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")),
connectFields(bbPort, "wmode", wrapperPort, "wmode"),
- Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData)))
+ Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData))
+ )
hasMask match {
case false => defaultSeq
- case true => defaultSeq :+ Connect(
- NoInfo,
- WSubField(bbPort, "wmask"),
- maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
- )
+ case true =>
+ defaultSeq :+ Connect(
+ NoInfo,
+ WSubField(bbPort, "wmask"),
+ maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
+ )
}
}
/** Mapping from (module, memory name) pairs to blackbox names */
private type NameMap = collection.mutable.HashMap[(String, String), String]
+
/** Construct NameMap by assigning unique names for each memory blackbox */
def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = {
s match {
- case m: DefAnnotatedMemory => m.memRef match {
- case None => nameMap(mname -> m.name) = namespace newName m.name
- case Some(_) =>
- }
+ case m: DefAnnotatedMemory =>
+ m.memRef match {
+ case None => nameMap(mname -> m.name) = namespace.newName(m.name)
+ case Some(_) =>
+ }
case _ =>
}
- s map constructNameMap(namespace, nameMap, mname)
+ s.map(constructNameMap(namespace, nameMap, mname))
}
- def updateMemStmts(namespace: Namespace,
- nameMap: NameMap,
- mname: String,
- memPortMap: MemPortMap,
- memMods: Modules)
- (s: Statement): Statement = s match {
+ def updateMemStmts(
+ namespace: Namespace,
+ nameMap: NameMap,
+ mname: String,
+ memPortMap: MemPortMap,
+ memMods: Modules
+ )(s: Statement
+ ): Statement = s match {
case m: DefAnnotatedMemory =>
if (m.maskGran.isEmpty) {
- m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
- m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
+ m.writers.foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
+ m.readwriters.foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
}
m.memRef match {
case None =>
// prototype mem
val newWrapperName = nameMap(mname -> m.name)
- val newMemBBName = namespace newName s"${newWrapperName}_ext"
- val newMem = m copy (name = newMemBBName)
+ val newMemBBName = namespace.newName(s"${newWrapperName}_ext")
+ val newMem = m.copy(name = newMemBBName)
memMods ++= createMemModule(newMem, newWrapperName)
WDefInstance(m.info, m.name, newWrapperName, UnknownType)
case Some((module, mem)) =>
WDefInstance(m.info, m.name, nameMap(module -> mem), UnknownType)
}
- case sx => sx map updateMemStmts(namespace, nameMap, mname, memPortMap, memMods)
+ case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods))
}
def updateMemMods(namespace: Namespace, nameMap: NameMap, memMods: Modules)(m: DefModule) = {
val memPortMap = new MemPortMap
- (m map updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods)
- map updateStmtRefs(memPortMap))
+ (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods))
+ .map(updateStmtRefs(memPortMap)))
}
def execute(state: CircuitState): CircuitState = {
@@ -252,15 +262,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
val namespace = Namespace(c)
val memMods = new Modules
val nameMap = new NameMap
- c.modules map (m => m map constructNameMap(namespace, nameMap, m.name))
- val modules = c.modules map updateMemMods(namespace, nameMap, memMods)
+ c.modules.map(m => m.map(constructNameMap(namespace, nameMap, m.name)))
+ val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods))
// print conf
writer.serialize()
val pannos = state.annotations.collect { case a: PinAnnotation => a }
val pins = pannos match {
- case Seq() => Nil
+ case Seq() => Nil
case Seq(PinAnnotation(pins)) => pins
- case _ => throwInternalError("Something went wrong")
+ case _ => throwInternalError("Something went wrong")
}
val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) =>
seq ++ memMods.collect {
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index 87321ea0..79e07640 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -7,7 +7,7 @@ import firrtl._
import firrtl.annotations._
import firrtl.options.{HasShellOptions, ShellOption}
import Utils.error
-import java.io.{File, CharArrayWriter, PrintWriter}
+import java.io.{CharArrayWriter, File, PrintWriter}
import wiring._
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
@@ -50,7 +50,15 @@ class ConfWriter(filename: String) {
// assert that we don't overflow going from BigInt to Int conversion
require(bitWidth(m.dataType) <= Int.MaxValue)
m.maskGran.foreach { case x => require(x <= Int.MaxValue) }
- val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt))
+ val conf = MemConf(
+ m.name,
+ m.depth,
+ bitWidth(m.dataType).toInt,
+ m.readers.length,
+ m.writers.length,
+ m.readwriters.length,
+ m.maskGran.map(_.toInt)
+ )
outputBuffer.append(conf.toString)
}
def serialize() = {
@@ -113,27 +121,31 @@ class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigrat
val options = Seq(
new ShellOption[String](
longOption = "repl-seq-mem",
- toAnnotationSeq = (a: String) => Seq( passes.memlib.ReplSeqMemAnnotation.parse(a),
- RunFirrtlTransformAnnotation(new ReplSeqMem) ),
+ toAnnotationSeq =
+ (a: String) => Seq(passes.memlib.ReplSeqMemAnnotation.parse(a), RunFirrtlTransformAnnotation(new ReplSeqMem)),
helpText = "Blackbox and emit a configuration file for each sequential memory",
shortOption = Some("frsq"),
- helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") ) )
+ helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>")
+ )
+ )
def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] =
- Seq(new SimpleMidTransform(Legalize),
- new SimpleMidTransform(ToMemIR),
- new SimpleMidTransform(ResolveMaskGranularity),
- new SimpleMidTransform(RenameAnnotatedMemoryPorts),
- new ResolveMemoryReference,
- new CreateMemoryAnnotations(inConfigFile),
- new ReplaceMemMacros(outConfigFile),
- new WiringTransform,
- new SimpleMidTransform(RemoveEmpty),
- new SimpleMidTransform(CheckInitialization),
- new SimpleMidTransform(InferTypes),
- Uniquify,
- new SimpleMidTransform(ResolveKinds),
- new SimpleMidTransform(ResolveFlows))
+ Seq(
+ new SimpleMidTransform(Legalize),
+ new SimpleMidTransform(ToMemIR),
+ new SimpleMidTransform(ResolveMaskGranularity),
+ new SimpleMidTransform(RenameAnnotatedMemoryPorts),
+ new ResolveMemoryReference,
+ new CreateMemoryAnnotations(inConfigFile),
+ new ReplaceMemMacros(outConfigFile),
+ new WiringTransform,
+ new SimpleMidTransform(RemoveEmpty),
+ new SimpleMidTransform(CheckInitialization),
+ new SimpleMidTransform(InferTypes),
+ Uniquify,
+ new SimpleMidTransform(ResolveKinds),
+ new SimpleMidTransform(ResolveFlows)
+ )
def execute(state: CircuitState): CircuitState = {
val annos = state.annotations.collect { case a: ReplSeqMemAnnotation => a }
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 41c47dce..434c7602 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -28,10 +28,10 @@ object AnalysisUtils {
connects(value.serialize) = WInvalid
case _ => // do nothing
}
- s map getConnects(connects)
+ s.map(getConnects(connects))
}
val connects = new Connects
- m map getConnects(connects)
+ m.map(getConnects(connects))
connects
}
@@ -56,8 +56,8 @@ object AnalysisUtils {
else if (weq(tvOrigin, fvOrigin)) tvOrigin
else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args.exists(weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args.exists(weq(_, zero)) => zero
case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
val extractionWidth = (msb - lsb) + 1
val nodeWidth = bitWidth(args.head.tpe)
@@ -69,10 +69,10 @@ object AnalysisUtils {
case ValidIf(cond, value, _) => getOrigin(connects)(value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
- connects get e.serialize match {
- case Some(ex) => getOrigin(connects)(ex)
- case None => e
- }
+ connects.get(e.serialize) match {
+ case Some(ex) => getOrigin(connects)(ex)
+ case None => e
+ }
case _ => e
}
}
@@ -90,10 +90,9 @@ object ResolveMaskGranularity extends Pass {
*/
def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
val wenOrigin = getOrigin(connects)(wen)
- val wmaskOrigin = connects.keys filter
- (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)}
+ val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) }
// all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one))
if (redundantMask) None else Some(wmaskOrigin.size)
}
@@ -103,18 +102,17 @@ object ResolveMaskGranularity extends Pass {
def updateStmts(connects: Connects)(s: Statement): Statement = s match {
case m: DefAnnotatedMemory =>
val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (rw =>
- getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
- val wMasks = m.writers map (w =>
- getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val rwMasks =
+ m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
val maskGran = (rwMasks ++ wMasks).head match {
- case None => None
+ case None => None
case Some(maskBits) => Some(dataBits / maskBits)
}
m.copy(maskGran = maskGran)
- case sx => sx map updateStmts(connects)
+ case sx => sx.map(updateStmts(connects))
}
- def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m))
- def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems)
+ def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m)))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems))
}
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
index b5ff10c6..e80e0c4a 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
@@ -14,7 +14,7 @@ case class NoDedupMemAnnotation(target: ComponentName) extends SingleTargetAnnot
}
/** Resolves annotation ref to memories that exactly match (except name) another memory
- */
+ */
class ResolveMemoryReference extends Transform with DependencyAPIMigration {
override def prerequisites = Forms.MidForm
@@ -45,10 +45,12 @@ class ResolveMemoryReference extends Transform with DependencyAPIMigration {
/** If a candidate memory is identical except for name to another, add an
* annotation that references the name of the other memory.
*/
- def updateMemStmts(mname: String,
- existingMems: AnnotatedMemories,
- noDedupMap: Map[String, Set[String]])
- (s: Statement): Statement = s match {
+ def updateMemStmts(
+ mname: String,
+ existingMems: AnnotatedMemories,
+ noDedupMap: Map[String, Set[String]]
+ )(s: Statement
+ ): Statement = s match {
// If not dedupable, no need to add to existing (since nothing can dedup with it)
// We just return the DefAnnotatedMemory as is in the default case below
case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) =>
diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
index 554a3572..9fe7f852 100644
--- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
@@ -14,16 +14,17 @@ import firrtl.ir._
* - undefined read-under-write behavior
*/
object ToMemIR extends Pass {
+
/** Only annotate memories that are candidates for memory macro replacements
* i.e. rw, w + r (read, write 1 cycle delay) and read-under-write "undefined."
*/
import ReadUnderWrite._
def updateStmts(s: Statement): Statement = s match {
- case m @ DefMemory(_,_,_,_,1,1,r,w,rw,Undefined) if (w.length + rw.length) == 1 && r.length <= 1 =>
+ case m @ DefMemory(_, _, _, _, 1, 1, r, w, rw, Undefined) if (w.length + rw.length) == 1 && r.length <= 1 =>
DefAnnotatedMemory(m)
- case sx => sx map updateStmts
+ case sx => sx.map(updateStmts)
}
- def annotateModMems(m: DefModule) = m map updateStmts
- def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+ def annotateModMems(m: DefModule) = m.map(updateStmts)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(annotateModMems))
}
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
index dd644323..a2b14343 100644
--- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
+++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
@@ -24,19 +24,19 @@ object MemDelayAndReadwriteTransformer {
case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect])
// Utilities for generating hardware
- def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType)
- def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType)
- def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r)
- def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe))
+ def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType)
+ def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType)
+ def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r)
+ def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe))
// Utilities for working with WithValid groups
def connect(l: WithValid, r: WithValid): Seq[Connect] = {
- val paired = (l.valid +: l.payload) zip (r.valid +: r.payload)
+ val paired = (l.valid +: l.payload).zip(r.valid +: r.payload)
paired.map { case (le, re) => connect(le, re) }
}
def condConnect(l: WithValid, r: WithValid): Seq[Connect] = {
- connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) }
+ connect(l.valid, r.valid) +: (l.payload.zip(r.payload)).map { case (le, re) => condConnect(r.valid)(le, re) }
}
// Internal representation of a pipeline stage with an associated valid signal
@@ -47,20 +47,23 @@ object MemDelayAndReadwriteTransformer {
private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_")
// Pipeline a group of signals with an associated valid signal. Gate registers when possible.
- def pipelineWithValid(ns: Namespace)(
- clock: Expression,
- depth: Int,
- src: WithValid,
- nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = {
+ def pipelineWithValid(
+ ns: Namespace
+ )(clock: Expression,
+ depth: Int,
+ src: WithValid,
+ nameTemplate: Option[WithValid] = None
+ ): (WithValid, Seq[Statement], Seq[Connect]) = {
def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e)
val template = nameTemplate.getOrElse(src)
- val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev =>
- def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind)
- val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef))
- val regs = (ref.valid +: ref.payload).map(asReg)
- PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref)))
+ val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) {
+ case prev =>
+ def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind)
+ val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef))
+ val regs = (ref.valid +: ref.payload).map(asReg)
+ PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref)))
}
(stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns))
}
@@ -84,10 +87,10 @@ class MemDelayAndReadwriteTransformer(m: DefModule) {
private def findMemConns(s: Statement): Unit = s match {
case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr
- case _ => s.foreach(findMemConns)
+ case _ => s.foreach(findMemConns)
}
- private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match {
+ private def swapMemRefs(e: Expression): Expression = e.map(swapMemRefs) match {
case sf: WSubField => exprReplacements.getOrElse(we(sf), sf)
case ex => ex
}
@@ -105,51 +108,57 @@ class MemDelayAndReadwriteTransformer(m: DefModule) {
val rRespDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) mem.readLatency else 0
val wCmdDelay = mem.writeLatency - 1
- val readStmts = (mem.readers ++ mem.readwriters).map { case r =>
- def oldDriver(f: String) = netlist(we(memPortField(mem, r, f)))
- def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f)
- val clk = oldDriver("clk")
-
- // Pack sources of read command inputs into WithValid object -> different for readwriter
- val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en")
- val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr")))
- val cmdSink = WithValid(newField("en"), Seq(newField("addr")))
- val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
- val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
-
- // Pipeline read response using *last* command pipe stage enable as the valid signal
- val resp = WithValid(cmdPiped.valid, Seq(newField("data")))
- val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names
- val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate)
-
- // Make sure references to the read data get appropriately substituted
- val oldRDataName = if (rMap.contains(r)) "rdata" else "data"
- exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head
-
- // Return all statements; they're separated so connects can go after all declarations
- SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns)
+ val readStmts = (mem.readers ++ mem.readwriters).map {
+ case r =>
+ def oldDriver(f: String) = netlist(we(memPortField(mem, r, f)))
+ def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f)
+ val clk = oldDriver("clk")
+
+ // Pack sources of read command inputs into WithValid object -> different for readwriter
+ val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en")
+ val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr")))
+ val cmdSink = WithValid(newField("en"), Seq(newField("addr")))
+ val (cmdPiped, cmdDecls, cmdConns) =
+ pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
+ val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
+
+ // Pipeline read response using *last* command pipe stage enable as the valid signal
+ val resp = WithValid(cmdPiped.valid, Seq(newField("data")))
+ val respPipeNameTemplate =
+ Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names
+ val (respPiped, respDecls, respConns) =
+ pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate)
+
+ // Make sure references to the read data get appropriately substituted
+ val oldRDataName = if (rMap.contains(r)) "rdata" else "data"
+ exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head
+
+ // Return all statements; they're separated so connects can go after all declarations
+ SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns)
}
- val writeStmts = (mem.writers ++ mem.readwriters).map { case w =>
- def oldDriver(f: String) = netlist(we(memPortField(mem, w, f)))
- def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f)
- val clk = oldDriver("clk")
-
- // Pack sources of write command inputs into WithValid object -> different for readwriter
- val cmdSrc = if (wMap.contains(w)) {
- val en = AND(oldDriver("en"), oldDriver("wmode"))
- WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata")))
- } else {
- WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data")))
- }
-
- // Pipeline write command, connect to memory
- val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data")))
- val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
- val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
-
- // Return all statements; they're separated so connects can go after all declarations
- SplitStatements(cmdDecls, cmdConns ++ cmdPortConns)
+ val writeStmts = (mem.writers ++ mem.readwriters).map {
+ case w =>
+ def oldDriver(f: String) = netlist(we(memPortField(mem, w, f)))
+ def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f)
+ val clk = oldDriver("clk")
+
+ // Pack sources of write command inputs into WithValid object -> different for readwriter
+ val cmdSrc = if (wMap.contains(w)) {
+ val en = AND(oldDriver("en"), oldDriver("wmode"))
+ WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata")))
+ } else {
+ WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data")))
+ }
+
+ // Pipeline write command, connect to memory
+ val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data")))
+ val (cmdPiped, cmdDecls, cmdConns) =
+ pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
+ val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
+
+ // Return all statements; they're separated so connects can go after all declarations
+ SplitStatements(cmdDecls, cmdConns ++ cmdPortConns)
}
newConns ++= (readStmts ++ writeStmts).flatMap(_.conns)
@@ -171,8 +180,7 @@ object VerilogMemDelays extends Pass {
override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf)
override val optionalPrerequisiteOf =
- Seq( Dependency[VerilogEmitter],
- Dependency[SystemVerilogEmitter] )
+ Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
case _: transforms.ConstantPropagation | ResolveFlows => true
@@ -180,5 +188,5 @@ object VerilogMemDelays extends Pass {
}
def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed
- def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform))
}
diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
index a43adfe2..b5f91e7b 100644
--- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
@@ -6,7 +6,6 @@ import net.jcazevedo.moultingyaml._
import java.io.{CharArrayWriter, File, PrintWriter}
import firrtl.FileUtils
-
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
implicit val _pin = yamlFormat1(Pin)
@@ -20,17 +19,15 @@ case class Source(name: String, module: String)
case class Top(name: String)
case class Config(pin: Pin, source: Source, top: Top)
-
class YamlFileReader(file: String) {
- def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
+ def parse[A](implicit reader: YamlReader[A]): Seq[A] = {
if (new File(file).exists) {
val yamlString = FileUtils.getText(file)
- yamlString.parseYamls flatMap (x =>
- try Some(reader read x)
+ yamlString.parseYamls.flatMap(x =>
+ try Some(reader.read(x))
catch { case e: Exception => None }
)
- }
- else sys.error("Yaml file doesn't exist!")
+ } else sys.error("Yaml file doesn't exist!")
}
}
@@ -38,11 +35,11 @@ class YamlFileWriter(file: String) {
val outputBuffer = new CharArrayWriter
val separator = "--- \n"
def append(in: YamlValue): Unit = {
- outputBuffer append s"$separator${in.prettyPrint}"
+ outputBuffer.append(s"$separator${in.prettyPrint}")
}
def dump(): Unit = {
val outputFile = new PrintWriter(file)
- outputFile write outputBuffer.toString
+ outputFile.write(outputBuffer.toString)
outputFile.close()
}
}
diff --git a/src/main/scala/firrtl/passes/wiring/Wiring.scala b/src/main/scala/firrtl/passes/wiring/Wiring.scala
index 3f74e5d2..a69b7797 100644
--- a/src/main/scala/firrtl/passes/wiring/Wiring.scala
+++ b/src/main/scala/firrtl/passes/wiring/Wiring.scala
@@ -18,8 +18,7 @@ import firrtl.graph.EulerTour
case class WiringInfo(source: ComponentName, sinks: Seq[Named], pin: String)
/** A data store of wiring names */
-case class WiringNames(compName: String, source: String, sinks: Seq[Named],
- pin: String)
+case class WiringNames(compName: String, source: String, sinks: Seq[Named], pin: String)
/** Pass that computes and applies a sequence of wiring modifications
*
@@ -28,31 +27,39 @@ case class WiringNames(compName: String, source: String, sinks: Seq[Named],
*/
class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
def run(c: Circuit): Circuit = analyze(c)
- .foldLeft(c){
- case (cx, (tpe, modsMap)) => cx.copy(
- modules = cx.modules map onModule(tpe, modsMap)) }
+ .foldLeft(c) {
+ case (cx, (tpe, modsMap)) => cx.copy(modules = cx.modules.map(onModule(tpe, modsMap)))
+ }
/** Converts multiple units of wiring information to module modifications */
private def analyze(c: Circuit): Seq[(Type, Map[String, Modifications])] = {
val names = wiSeq
- .map ( wi => (wi.source, wi.sinks, wi.pin) match {
- case (ComponentName(comp, ModuleName(source,_)), sinks, pin) =>
- WiringNames(comp, source, sinks, pin) })
+ .map(wi =>
+ (wi.source, wi.sinks, wi.pin) match {
+ case (ComponentName(comp, ModuleName(source, _)), sinks, pin) =>
+ WiringNames(comp, source, sinks, pin)
+ }
+ )
val portNames = mutable.Seq.fill(names.size)(Map[String, String]())
- c.modules.foreach{ m =>
+ c.modules.foreach { m =>
val ns = Namespace(m)
- names.zipWithIndex.foreach{ case (WiringNames(c, so, si, p), i) =>
- portNames(i) = portNames(i) +
- ( m.name -> {
- if (si.exists(getModuleName(_) == m.name)) ns.newName(p)
- else ns.newName(tokenize(c) filterNot ("[]." contains _) mkString "_")
- })}}
+ names.zipWithIndex.foreach {
+ case (WiringNames(c, so, si, p), i) =>
+ portNames(i) = portNames(i) +
+ (m.name -> {
+ if (si.exists(getModuleName(_) == m.name)) ns.newName(p)
+ else ns.newName(tokenize(c).filterNot("[]." contains _).mkString("_"))
+ })
+ }
+ }
val iGraph = InstanceKeyGraph(c)
- names.zip(portNames).map{ case(WiringNames(comp, so, si, _), pn) =>
- computeModifications(c, iGraph, comp, so, si, pn) }
+ names.zip(portNames).map {
+ case (WiringNames(comp, so, si, _), pn) =>
+ computeModifications(c, iGraph, comp, so, si, pn)
+ }
}
/** Converts a single unit of wiring information to module modifications
@@ -69,19 +76,20 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
* @return a tuple of the component type and a map of module names
* to pending modifications
*/
- private def computeModifications(c: Circuit,
- iGraph: InstanceKeyGraph,
- compName: String,
- source: String,
- sinks: Seq[Named],
- portNames: Map[String, String]):
- (Type, Map[String, Modifications]) = {
+ private def computeModifications(
+ c: Circuit,
+ iGraph: InstanceKeyGraph,
+ compName: String,
+ source: String,
+ sinks: Seq[Named],
+ portNames: Map[String, String]
+ ): (Type, Map[String, Modifications]) = {
val sourceComponentType = getType(c, source, compName)
- val sinkComponents: Map[String, Seq[String]] = sinks
- .collect{ case ComponentName(c, ModuleName(m, _)) => (c, m) }
- .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]){
- case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)) ) }
+ val sinkComponents: Map[String, Seq[String]] = sinks.collect { case ComponentName(c, ModuleName(m, _)) => (c, m) }
+ .foldLeft(new scala.collection.immutable.HashMap[String, Seq[String]]) {
+ case (a, (c, m)) => a ++ Map(m -> (Seq(c) ++ a.getOrElse(m, Nil)))
+ }
// Determine "ownership" of sources to sinks via minimum distance
val owners = sinksToSourcesSeq(sinks, source, iGraph)
@@ -95,86 +103,88 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
def makeWire(m: Modifications, portName: String): Modifications =
m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))))
def makeWireC(m: Modifications, portName: String, c: (String, String)): Modifications =
- m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct )
+ m.copy(addPortOrWire = Some(m.addPortOrWire.getOrElse((portName, DecWire))), cons = (m.cons :+ c).distinct)
val tour = EulerTour(iGraph.graph, iGraph.top)
// Finds the lowest common ancestor instances for two module names in a design
def lowestCommonAncestor(moduleA: Seq[InstanceKey], moduleB: Seq[InstanceKey]): Seq[InstanceKey] =
tour.rmq(moduleA, moduleB)
- owners.foreach { case (sink, source) =>
- val lca = lowestCommonAncestor(sink, source)
-
- // Compute metadata along Sink to LCA paths.
- sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach {
- case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) =>
- val to = s"$ci.${portNames(cm)}"
- val from = s"${portNames(pm)}"
- meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
- meta(cm) = meta(cm).copy(
- addPortOrWire = Some((portNames(cm), DecInput))
- )
- // Case where the sink is the LCA
- case Seq(InstanceKey(_,pm)) =>
- // Case where the source is also the LCA
- if (source.drop(lca.size).isEmpty) {
- meta(pm) = makeWire(meta(pm), portNames(pm))
- } else {
- val InstanceKey(ci,cm) = source.drop(lca.size).head
- val to = s"${portNames(pm)}"
- val from = s"$ci.${portNames(cm)}"
- meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
- }
- }
+ owners.foreach {
+ case (sink, source) =>
+ val lca = lowestCommonAncestor(sink, source)
- // Compute metadata for the Sink
- sink.last match { case InstanceKey( _, m) =>
- if (sinkComponents.contains(m)) {
- val from = s"${portNames(m)}"
- sinkComponents(m).foreach( to =>
- meta(m) = meta(m).copy(
- cons = (meta(m).cons :+( (to, from) )).distinct
+ // Compute metadata along Sink to LCA paths.
+ sink.drop(lca.size - 1).sliding(2).toList.reverse.foreach {
+ case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) =>
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecInput))
)
- )
+ // Case where the sink is the LCA
+ case Seq(InstanceKey(_, pm)) =>
+ // Case where the source is also the LCA
+ if (source.drop(lca.size).isEmpty) {
+ meta(pm) = makeWire(meta(pm), portNames(pm))
+ } else {
+ val InstanceKey(ci, cm) = source.drop(lca.size).head
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
+ meta(pm) = makeWireC(meta(pm), portNames(pm), (to, from))
+ }
}
- }
- // Compute metadata for the Source
- source.last match { case InstanceKey( _, m) =>
- val to = s"${portNames(m)}"
- val from = compName
- meta(m) = meta(m).copy(
- cons = (meta(m).cons :+( (to, from) )).distinct
- )
- }
+ // Compute metadata for the Sink
+ sink.last match {
+ case InstanceKey(_, m) =>
+ if (sinkComponents.contains(m)) {
+ val from = s"${portNames(m)}"
+ sinkComponents(m).foreach(to =>
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ ((to, from))).distinct
+ )
+ )
+ }
+ }
- // Compute metadata along Source to LCA path
- source.drop(lca.size - 1).sliding(2).toList.reverse.map {
- case Seq(InstanceKey(_,pm), InstanceKey(ci,cm)) => {
- val to = s"${portNames(pm)}"
- val from = s"$ci.${portNames(cm)}"
- meta(pm) = meta(pm).copy(
- cons = (meta(pm).cons :+( (to, from) )).distinct
- )
- meta(cm) = meta(cm).copy(
- addPortOrWire = Some((portNames(cm), DecOutput))
- )
+ // Compute metadata for the Source
+ source.last match {
+ case InstanceKey(_, m) =>
+ val to = s"${portNames(m)}"
+ val from = compName
+ meta(m) = meta(m).copy(
+ cons = (meta(m).cons :+ ((to, from))).distinct
+ )
}
- // Case where the source is the LCA
- case Seq(InstanceKey(_,pm)) => {
- // Case where the sink is also the LCA. We do nothing here,
- // as we've created the connecting wire above
- if (sink.drop(lca.size).isEmpty) {
- } else {
- val InstanceKey(ci,cm) = sink.drop(lca.size).head
- val to = s"$ci.${portNames(cm)}"
- val from = s"${portNames(pm)}"
+
+ // Compute metadata along Source to LCA path
+ source.drop(lca.size - 1).sliding(2).toList.reverse.map {
+ case Seq(InstanceKey(_, pm), InstanceKey(ci, cm)) => {
+ val to = s"${portNames(pm)}"
+ val from = s"$ci.${portNames(cm)}"
meta(pm) = meta(pm).copy(
- cons = (meta(pm).cons :+( (to, from) )).distinct
+ cons = (meta(pm).cons :+ ((to, from))).distinct
)
+ meta(cm) = meta(cm).copy(
+ addPortOrWire = Some((portNames(cm), DecOutput))
+ )
+ }
+ // Case where the source is the LCA
+ case Seq(InstanceKey(_, pm)) => {
+ // Case where the sink is also the LCA. We do nothing here,
+ // as we've created the connecting wire above
+ if (sink.drop(lca.size).isEmpty) {} else {
+ val InstanceKey(ci, cm) = sink.drop(lca.size).head
+ val to = s"$ci.${portNames(cm)}"
+ val from = s"${portNames(pm)}"
+ meta(pm) = meta(pm).copy(
+ cons = (meta(pm).cons :+ ((to, from))).distinct
+ )
+ }
}
}
- }
}
(sourceComponentType, meta.toMap)
}
@@ -189,20 +199,22 @@ class Wiring(wiSeq: Seq[WiringInfo]) extends Pass {
val ports = mutable.ArrayBuffer[Port]()
l.addPortOrWire match {
case None =>
- case Some((s, dt)) => dt match {
- case DecInput => ports += Port(NoInfo, s, Input, t)
- case DecOutput => ports += Port(NoInfo, s, Output, t)
- case DecWire => defines += DefWire(NoInfo, s, t)
- }
+ case Some((s, dt)) =>
+ dt match {
+ case DecInput => ports += Port(NoInfo, s, Input, t)
+ case DecOutput => ports += Port(NoInfo, s, Output, t)
+ case DecWire => defines += DefWire(NoInfo, s, t)
+ }
}
- connects ++= (l.cons map { case ((l, r)) =>
- Connect(NoInfo, toExp(l), toExp(r))
+ connects ++= (l.cons.map {
+ case ((l, r)) =>
+ Connect(NoInfo, toExp(l), toExp(r))
})
m match {
case Module(i, n, ps, body) =>
val stmts = body match {
case Block(sx) => sx
- case s => Seq(s)
+ case s => Seq(s)
}
Module(i, n, ps ++ ports, Block(List() ++ defines ++ stmts ++ connects))
case ExtModule(i, n, ps, dn, p) => ExtModule(i, n, ps ++ ports, dn, p)
diff --git a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
index 20fb1215..d6658f16 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringTransform.scala
@@ -14,14 +14,12 @@ import firrtl.stage.Forms
case class WiringException(msg: String) extends PassException(msg)
/** A component, e.g. register etc. Must be declared only once under the TopAnnotation */
-case class SourceAnnotation(target: ComponentName, pin: String) extends
- SingleTargetAnnotation[ComponentName] {
+case class SourceAnnotation(target: ComponentName, pin: String) extends SingleTargetAnnotation[ComponentName] {
def duplicate(n: ComponentName) = this.copy(target = n)
}
/** A module, e.g. ExtModule etc., that should add the input pin */
-case class SinkAnnotation(target: Named, pin: String) extends
- SingleTargetAnnotation[Named] {
+case class SinkAnnotation(target: Named, pin: String) extends SingleTargetAnnotation[Named] {
def duplicate(n: Named) = this.copy(target = n)
}
@@ -76,8 +74,9 @@ class WiringTransform extends Transform with DependencyAPIMigration {
(sources.size, sinks.size) match {
case (0, p) => state
case (s, p) if (p > 0) =>
- val wis = sources.foldLeft(Seq[WiringInfo]()) { case (seq, (pin, source)) =>
- seq :+ WiringInfo(source, sinks(pin), pin)
+ val wis = sources.foldLeft(Seq[WiringInfo]()) {
+ case (seq, (pin, source)) =>
+ seq :+ WiringInfo(source, sinks(pin), pin)
}
val annosx = state.annotations.filterNot(annos.toSet.contains)
transforms(wis)
diff --git a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
index c220692a..5e8f8616 100644
--- a/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
+++ b/src/main/scala/firrtl/passes/wiring/WiringUtils.scala
@@ -25,54 +25,54 @@ case object DecWire extends DecKind
/** Store of pending wiring information for a Module */
case class Modifications(
addPortOrWire: Option[(String, DecKind)] = None,
- cons: Seq[(String, String)] = Seq.empty) {
+ cons: Seq[(String, String)] = Seq.empty) {
override def toString: String = serialize("")
def serialize(tab: String): String = s"""
- |$tab addPortOrWire: $addPortOrWire
- |$tab cons: $cons
- |""".stripMargin
+ |$tab addPortOrWire: $addPortOrWire
+ |$tab cons: $cons
+ |""".stripMargin
}
/** A lineage tree representing the instance hierarchy in a design
*/
@deprecated("Use DiGraph/InstanceGraph", "1.1.1")
case class Lineage(
- name: String,
- children: Seq[(String, Lineage)] = Seq.empty,
- source: Boolean = false,
- sink: Boolean = false,
- sourceParent: Boolean = false,
- sinkParent: Boolean = false,
- sharedParent: Boolean = false,
- addPort: Option[(String, DecKind)] = None,
- cons: Seq[(String, String)] = Seq.empty) {
+ name: String,
+ children: Seq[(String, Lineage)] = Seq.empty,
+ source: Boolean = false,
+ sink: Boolean = false,
+ sourceParent: Boolean = false,
+ sinkParent: Boolean = false,
+ sharedParent: Boolean = false,
+ addPort: Option[(String, DecKind)] = None,
+ cons: Seq[(String, String)] = Seq.empty) {
def map(f: Lineage => Lineage): Lineage =
- this.copy(children = children.map{ case (i, m) => (i, f(m)) })
+ this.copy(children = children.map { case (i, m) => (i, f(m)) })
override def toString: String = shortSerialize("")
def shortSerialize(tab: String): String = s"""
- |$tab name: $name,
- |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))}
- |""".stripMargin
+ |$tab name: $name,
+ |$tab children: ${children.map(c => tab + " " + c._2.shortSerialize(tab + " "))}
+ |""".stripMargin
def foldLeft[B](z: B)(op: (B, (String, Lineage)) => B): B =
this.children.foldLeft(z)(op)
def serialize(tab: String): String = s"""
- |$tab name: $name,
- |$tab source: $source,
- |$tab sink: $sink,
- |$tab sourceParent: $sourceParent,
- |$tab sinkParent: $sinkParent,
- |$tab sharedParent: $sharedParent,
- |$tab addPort: $addPort
- |$tab cons: $cons
- |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))}
- |""".stripMargin
+ |$tab name: $name,
+ |$tab source: $source,
+ |$tab sink: $sink,
+ |$tab sourceParent: $sourceParent,
+ |$tab sinkParent: $sinkParent,
+ |$tab sharedParent: $sharedParent,
+ |$tab addPort: $addPort
+ |$tab cons: $cons
+ |$tab children: ${children.map(c => tab + " " + c._2.serialize(tab + " "))}
+ |""".stripMargin
}
object WiringUtils {
@@ -87,12 +87,12 @@ object WiringUtils {
val childrenMap = new ChildrenMap()
def getChildren(mname: String)(s: Statement): Unit = s match {
case s: WDefInstance =>
- childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) )
+ childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module))
case s: DefInstance =>
- childrenMap(mname) = childrenMap(mname) :+( (s.name, s.module) )
+ childrenMap(mname) = childrenMap(mname) :+ ((s.name, s.module))
case s => s.foreach(getChildren(mname))
}
- c.modules.foreach{ m =>
+ c.modules.foreach { m =>
childrenMap(m.name) = Nil
m.foreach(getChildren(m.name))
}
@@ -103,7 +103,7 @@ object WiringUtils {
*/
@deprecated("Use DiGraph/InstanceGraph", "1.1.1")
def getLineage(childrenMap: ChildrenMap, module: String): Lineage =
- Lineage(module, childrenMap(module) map { case (i, m) => (i, getLineage(childrenMap, m)) } )
+ Lineage(module, childrenMap(module).map { case (i, m) => (i, getLineage(childrenMap, m)) })
/** Return a map of sink instances to source instances that minimizes
* distance
@@ -114,22 +114,25 @@ object WiringUtils {
* @return a map of sink instance names to source instance names
* @throws WiringException if a sink is equidistant to two sources
*/
- @deprecated("This method can lead to non-determinism in your compiler pass and exposes internal details." +
- " Please file an issue with firrtl if you have a use case!", "Firrtl 1.4")
+ @deprecated(
+ "This method can lead to non-determinism in your compiler pass and exposes internal details." +
+ " Please file an issue with firrtl if you have a use case!",
+ "Firrtl 1.4"
+ )
def sinksToSources(sinks: Seq[Named], source: String, i: InstanceGraph): Map[Seq[WDefInstance], Seq[WDefInstance]] = {
// The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap.
val owners = new mutable.LinkedHashMap[Seq[WDefInstance], Vector[Seq[WDefInstance]]]
val queue = new mutable.Queue[Seq[WDefInstance]]
val visited = new mutable.HashMap[Seq[WDefInstance], Boolean].withDefaultValue(false)
- val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v }
+ val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v }
sourcePaths.flatten.foreach { l =>
queue.enqueue(l)
owners(l) = Vector(l)
}
val sinkModuleNames = sinks.map(getModuleName).toSet
- val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v }
+ val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v }
// sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet
val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten
@@ -156,8 +159,8 @@ object WiringUtils {
// [todo] This is the critical section
edges
- .filter( e => !visited(e) && e.nonEmpty )
- .foreach{ v =>
+ .filter(e => !visited(e) && e.nonEmpty)
+ .foreach { v =>
owners(v) = owners.getOrElse(v, Vector()) ++ owners(u)
queue.enqueue(v)
}
@@ -167,8 +170,8 @@ object WiringUtils {
// this should fail is if a sink is equidistant to two sources.
sinkInsts.foreach { s =>
if (!owners.contains(s) || owners(s).size > 1) {
- throw new WiringException(
- s"Unable to determine source mapping for sink '${s.map(_.name)}'") }
+ throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'")
+ }
}
}
@@ -184,21 +187,24 @@ object WiringUtils {
* @return a map of sink instance names to source instance names
* @throws WiringException if a sink is equidistant to two sources
*/
- private[firrtl] def sinksToSourcesSeq(sinks: Seq[Named], source: String, i: InstanceKeyGraph):
- Seq[(Seq[InstanceKey], Seq[InstanceKey])] = {
+ private[firrtl] def sinksToSourcesSeq(
+ sinks: Seq[Named],
+ source: String,
+ i: InstanceKeyGraph
+ ): Seq[(Seq[InstanceKey], Seq[InstanceKey])] = {
// The order of owners influences the order of the results, it thus needs to be deterministic with a LinkedHashMap.
val owners = new mutable.LinkedHashMap[Seq[InstanceKey], Vector[Seq[InstanceKey]]]
val queue = new mutable.Queue[Seq[InstanceKey]]
val visited = new mutable.HashMap[Seq[InstanceKey], Boolean].withDefaultValue(false)
- val sourcePaths = i.fullHierarchy.collect { case (k,v) if k.module == source => v }
+ val sourcePaths = i.fullHierarchy.collect { case (k, v) if k.module == source => v }
sourcePaths.flatten.foreach { l =>
queue.enqueue(l)
owners(l) = Vector(l)
}
val sinkModuleNames = sinks.map(getModuleName).toSet
- val sinkPaths = i.fullHierarchy.collect { case (k,v) if sinkModuleNames.contains(k.module) => v }
+ val sinkPaths = i.fullHierarchy.collect { case (k, v) if sinkModuleNames.contains(k.module) => v }
// sinkInsts needs to have unique entries but is also iterated over which is why we use a LinkedHashSet
val sinkInsts = mutable.LinkedHashSet() ++ sinkPaths.flatten
@@ -225,8 +231,8 @@ object WiringUtils {
// [todo] This is the critical section
edges
- .filter( e => !visited(e) && e.nonEmpty )
- .foreach{ v =>
+ .filter(e => !visited(e) && e.nonEmpty)
+ .foreach { v =>
owners(v) = owners.getOrElse(v, Vector()) ++ owners(u)
queue.enqueue(v)
}
@@ -236,8 +242,8 @@ object WiringUtils {
// this should fail is if a sink is equidistant to two sources.
sinkInsts.foreach { s =>
if (!owners.contains(s) || owners(s).size > 1) {
- throw new WiringException(
- s"Unable to determine source mapping for sink '${s.map(_.name)}'") }
+ throw new WiringException(s"Unable to determine source mapping for sink '${s.map(_.name)}'")
+ }
}
}
@@ -249,8 +255,7 @@ object WiringUtils {
n match {
case ModuleName(m, _) => m
case ComponentName(_, ModuleName(m, _)) => m
- case _ => throw new WiringException(
- "Only Components or Modules have an associated Module name")
+ case _ => throw new WiringException("Only Components or Modules have an associated Module name")
}
}
@@ -266,9 +271,9 @@ object WiringUtils {
def getType(c: Circuit, module: String, comp: String): Type = {
def getRoot(e: Expression): String = e match {
case r: Reference => r.name
- case i: SubIndex => getRoot(i.expr)
+ case i: SubIndex => getRoot(i.expr)
case a: SubAccess => getRoot(a.expr)
- case f: SubField => getRoot(f.expr)
+ case f: SubField => getRoot(f.expr)
}
val eComp = toExp(comp)
val root = getRoot(eComp)
@@ -289,11 +294,12 @@ object WiringUtils {
case sx: DefMemory if sx.name == root =>
tpe = Some(MemPortUtils.memType(sx))
sx
- case sx => sx map getType
+ case sx => sx.map(getType)
+ }
+ val m = c.modules.find(_.name == module).getOrElse {
+ throw new WiringException(s"Must have a module named $module")
}
- val m = c.modules find (_.name == module) getOrElse {
- throw new WiringException(s"Must have a module named $module") }
- tpe = m.ports find (_.name == root) map (_.tpe)
+ tpe = m.ports.find(_.name == root).map(_.tpe)
m match {
case Module(i, n, ps, b) => getType(b)
case e: ExtModule =>
@@ -301,10 +307,10 @@ object WiringUtils {
tpe match {
case None => throw new WiringException(s"Didn't find $comp in $module!")
case Some(t) =>
- def setType(e: Expression): Expression = e map setType match {
+ def setType(e: Expression): Expression = e.map(setType) match {
case ex: Reference => ex.copy(tpe = t)
- case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
- case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
+ case ex: SubField => ex.copy(tpe = field_type(ex.expr.tpe, ex.name))
+ case ex: SubIndex => ex.copy(tpe = sub_type(ex.expr.tpe))
case ex: SubAccess => ex.copy(tpe = sub_type(ex.expr.tpe))
}
setType(eComp).tpe
diff --git a/src/main/scala/firrtl/proto/FromProto.scala b/src/main/scala/firrtl/proto/FromProto.scala
index 41a7e1de..5b9dd371 100644
--- a/src/main/scala/firrtl/proto/FromProto.scala
+++ b/src/main/scala/firrtl/proto/FromProto.scala
@@ -35,9 +35,9 @@ object FromProto {
// Convert from ProtoBuf message repeated Statements to FIRRRTL Block
private def compressStmts(stmts: scala.collection.Seq[ir.Statement]): ir.Statement = stmts match {
- case scala.collection.Seq() => ir.EmptyStmt
+ case scala.collection.Seq() => ir.EmptyStmt
case scala.collection.Seq(stmt) => stmt
- case multiple => ir.Block(multiple.toSeq)
+ case multiple => ir.Block(multiple.toSeq)
}
def convert(info: Firrtl.SourceInfo): ir.Info =
@@ -100,16 +100,16 @@ object FromProto {
def convert(expr: Firrtl.Expression): ir.Expression = {
import Firrtl.Expression._
expr.getExpressionCase.getNumber match {
- case REFERENCE_FIELD_NUMBER => ir.Reference(expr.getReference.getId, ir.UnknownType)
- case SUB_FIELD_FIELD_NUMBER => convert(expr.getSubField)
- case SUB_INDEX_FIELD_NUMBER => convert(expr.getSubIndex)
- case SUB_ACCESS_FIELD_NUMBER => convert(expr.getSubAccess)
- case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral)
- case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral)
+ case REFERENCE_FIELD_NUMBER => ir.Reference(expr.getReference.getId, ir.UnknownType)
+ case SUB_FIELD_FIELD_NUMBER => convert(expr.getSubField)
+ case SUB_INDEX_FIELD_NUMBER => convert(expr.getSubIndex)
+ case SUB_ACCESS_FIELD_NUMBER => convert(expr.getSubAccess)
+ case UINT_LITERAL_FIELD_NUMBER => convert(expr.getUintLiteral)
+ case SINT_LITERAL_FIELD_NUMBER => convert(expr.getSintLiteral)
case FIXED_LITERAL_FIELD_NUMBER => convert(expr.getFixedLiteral)
- case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp)
- case MUX_FIELD_NUMBER => convert(expr.getMux)
- case VALID_IF_FIELD_NUMBER => convert(expr.getValidIf)
+ case PRIM_OP_FIELD_NUMBER => convert(expr.getPrimOp)
+ case MUX_FIELD_NUMBER => convert(expr.getMux)
+ case VALID_IF_FIELD_NUMBER => convert(expr.getValidIf)
}
}
@@ -123,8 +123,14 @@ object FromProto {
ir.DefWire(convert(info), wire.getId, convert(wire.getType))
def convert(reg: Firrtl.Statement.Register, info: Firrtl.SourceInfo): ir.DefRegister =
- ir.DefRegister(convert(info), reg.getId, convert(reg.getType), convert(reg.getClock),
- convert(reg.getReset), convert(reg.getInit))
+ ir.DefRegister(
+ convert(info),
+ reg.getId,
+ convert(reg.getType),
+ convert(reg.getClock),
+ convert(reg.getReset),
+ convert(reg.getInit)
+ )
def convert(node: Firrtl.Statement.Node, info: Firrtl.SourceInfo): ir.DefNode =
ir.DefNode(convert(info), node.getId, convert(node.getExpression))
@@ -140,8 +146,8 @@ object FromProto {
def convert(ruw: ReadUnderWrite): ir.ReadUnderWrite.Value = ruw match {
case ReadUnderWrite.UNDEFINED => ir.ReadUnderWrite.Undefined
- case ReadUnderWrite.OLD => ir.ReadUnderWrite.Old
- case ReadUnderWrite.NEW => ir.ReadUnderWrite.New
+ case ReadUnderWrite.OLD => ir.ReadUnderWrite.Old
+ case ReadUnderWrite.NEW => ir.ReadUnderWrite.New
}
def convert(dt: Firrtl.Statement.CMemory.TypeAndDepth): (ir.Type, BigInt) =
@@ -161,9 +167,9 @@ object FromProto {
import Firrtl.Statement.MemoryPort.Direction._
def convert(mportdir: Firrtl.Statement.MemoryPort.Direction): MPortDir = mportdir match {
- case MEMORY_PORT_DIRECTION_INFER => MInfer
- case MEMORY_PORT_DIRECTION_READ => MRead
- case MEMORY_PORT_DIRECTION_WRITE => MWrite
+ case MEMORY_PORT_DIRECTION_INFER => MInfer
+ case MEMORY_PORT_DIRECTION_READ => MRead
+ case MEMORY_PORT_DIRECTION_WRITE => MWrite
case MEMORY_PORT_DIRECTION_READ_WRITE => MReadWrite
}
@@ -184,12 +190,18 @@ object FromProto {
def convert(formal: Formal): ir.Formal.Value = formal match {
case Formal.ASSERT => ir.Formal.Assert
case Formal.ASSUME => ir.Formal.Assume
- case Formal.COVER => ir.Formal.Cover
+ case Formal.COVER => ir.Formal.Cover
}
def convert(ver: Firrtl.Statement.Verification, info: Firrtl.SourceInfo): ir.Verification =
- ir.Verification(convert(ver.getOp), convert(info), convert(ver.getClk),
- convert(ver.getCond), convert(ver.getEn), ir.StringLit(ver.getMsg))
+ ir.Verification(
+ convert(ver.getOp),
+ convert(info),
+ convert(ver.getClk),
+ convert(ver.getCond),
+ convert(ver.getEn),
+ ir.StringLit(ver.getMsg)
+ )
def convert(mem: Firrtl.Statement.Memory, info: Firrtl.SourceInfo): ir.DefMemory = {
val dtype = convert(mem.getType)
@@ -198,11 +210,21 @@ object FromProto {
val rws = mem.getReadwriterIdList.asScala.toSeq
import Firrtl.Statement.Memory._
val depth = mem.getDepthCase.getNumber match {
- case UINT_DEPTH_FIELD_NUMBER => BigInt(mem.getUintDepth)
+ case UINT_DEPTH_FIELD_NUMBER => BigInt(mem.getUintDepth)
case BIGINT_DEPTH_FIELD_NUMBER => convert(mem.getBigintDepth)
}
- ir.DefMemory(convert(info), mem.getId, dtype, depth, mem.getWriteLatency, mem.getReadLatency,
- rs, ws, rws, convert(mem.getReadUnderWrite))
+ ir.DefMemory(
+ convert(info),
+ mem.getId,
+ dtype,
+ depth,
+ mem.getWriteLatency,
+ mem.getReadLatency,
+ rs,
+ ws,
+ rws,
+ convert(mem.getReadUnderWrite)
+ )
}
def convert(attach: Firrtl.Statement.Attach, info: Firrtl.SourceInfo): ir.Attach = {
@@ -214,21 +236,21 @@ object FromProto {
import Firrtl.Statement._
val info = stmt.getSourceInfo
stmt.getStatementCase.getNumber match {
- case NODE_FIELD_NUMBER => convert(stmt.getNode, info)
- case CONNECT_FIELD_NUMBER => convert(stmt.getConnect, info)
+ case NODE_FIELD_NUMBER => convert(stmt.getNode, info)
+ case CONNECT_FIELD_NUMBER => convert(stmt.getConnect, info)
case PARTIAL_CONNECT_FIELD_NUMBER => convert(stmt.getPartialConnect, info)
- case WIRE_FIELD_NUMBER => convert(stmt.getWire, info)
- case REGISTER_FIELD_NUMBER => convert(stmt.getRegister, info)
- case WHEN_FIELD_NUMBER => convert(stmt.getWhen, info)
- case INSTANCE_FIELD_NUMBER => convert(stmt.getInstance, info)
- case PRINTF_FIELD_NUMBER => convert(stmt.getPrintf, info)
- case STOP_FIELD_NUMBER => convert(stmt.getStop, info)
- case MEMORY_FIELD_NUMBER => convert(stmt.getMemory, info)
+ case WIRE_FIELD_NUMBER => convert(stmt.getWire, info)
+ case REGISTER_FIELD_NUMBER => convert(stmt.getRegister, info)
+ case WHEN_FIELD_NUMBER => convert(stmt.getWhen, info)
+ case INSTANCE_FIELD_NUMBER => convert(stmt.getInstance, info)
+ case PRINTF_FIELD_NUMBER => convert(stmt.getPrintf, info)
+ case STOP_FIELD_NUMBER => convert(stmt.getStop, info)
+ case MEMORY_FIELD_NUMBER => convert(stmt.getMemory, info)
case IS_INVALID_FIELD_NUMBER =>
ir.IsInvalid(convert(info), convert(stmt.getIsInvalid.getExpression))
- case CMEMORY_FIELD_NUMBER => convert(stmt.getCmemory, info)
+ case CMEMORY_FIELD_NUMBER => convert(stmt.getCmemory, info)
case MEMORY_PORT_FIELD_NUMBER => convert(stmt.getMemoryPort, info)
- case ATTACH_FIELD_NUMBER => convert(stmt.getAttach, info)
+ case ATTACH_FIELD_NUMBER => convert(stmt.getAttach, info)
}
}
@@ -244,7 +266,7 @@ object FromProto {
val w = if (ut.hasWidth) convert(ut.getWidth) else ir.UnknownWidth
ir.UIntType(w)
}
-
+
def convert(st: Firrtl.Type.SIntType): ir.SIntType = {
val w = if (st.hasWidth) convert(st.getWidth) else ir.UnknownWidth
ir.SIntType(w)
@@ -272,13 +294,13 @@ object FromProto {
def convert(tpe: Firrtl.Type): ir.Type = {
import Firrtl.Type._
tpe.getTypeCase.getNumber match {
- case UINT_TYPE_FIELD_NUMBER => convert(tpe.getUintType)
- case SINT_TYPE_FIELD_NUMBER => convert(tpe.getSintType)
- case FIXED_TYPE_FIELD_NUMBER => convert(tpe.getFixedType)
- case CLOCK_TYPE_FIELD_NUMBER => ir.ClockType
+ case UINT_TYPE_FIELD_NUMBER => convert(tpe.getUintType)
+ case SINT_TYPE_FIELD_NUMBER => convert(tpe.getSintType)
+ case FIXED_TYPE_FIELD_NUMBER => convert(tpe.getFixedType)
+ case CLOCK_TYPE_FIELD_NUMBER => ir.ClockType
case ASYNC_RESET_TYPE_FIELD_NUMBER => ir.AsyncResetType
- case RESET_TYPE_FIELD_NUMBER => ir.ResetType
- case ANALOG_TYPE_FIELD_NUMBER => convert(tpe.getAnalogType)
+ case RESET_TYPE_FIELD_NUMBER => ir.ResetType
+ case ANALOG_TYPE_FIELD_NUMBER => convert(tpe.getAnalogType)
case BUNDLE_TYPE_FIELD_NUMBER =>
ir.BundleType(tpe.getBundleType.getFieldList.asScala.map(convert(_)).toSeq)
case VECTOR_TYPE_FIELD_NUMBER => convert(tpe.getVectorType)
@@ -287,7 +309,7 @@ object FromProto {
def convert(dir: Firrtl.Port.Direction): ir.Direction = {
dir match {
- case Firrtl.Port.Direction.PORT_DIRECTION_IN => ir.Input
+ case Firrtl.Port.Direction.PORT_DIRECTION_IN => ir.Input
case Firrtl.Port.Direction.PORT_DIRECTION_OUT => ir.Output
}
}
@@ -302,9 +324,9 @@ object FromProto {
import Firrtl.Module.ExternalModule.Parameter._
val name = param.getId
param.getValueCase.getNumber match {
- case INTEGER_FIELD_NUMBER => ir.IntParam(name, convert(param.getInteger))
- case DOUBLE_FIELD_NUMBER => ir.DoubleParam(name, param.getDouble)
- case STRING_FIELD_NUMBER => ir.StringParam(name, ir.StringLit(param.getString))
+ case INTEGER_FIELD_NUMBER => ir.IntParam(name, convert(param.getInteger))
+ case DOUBLE_FIELD_NUMBER => ir.DoubleParam(name, param.getDouble)
+ case STRING_FIELD_NUMBER => ir.StringParam(name, ir.StringLit(param.getString))
case RAW_STRING_FIELD_NUMBER => ir.RawStringParam(name, param.getRawString)
}
}
diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala
index 47fb3cec..78b95582 100644
--- a/src/main/scala/firrtl/proto/ToProto.scala
+++ b/src/main/scala/firrtl/proto/ToProto.scala
@@ -6,7 +6,7 @@ package proto
import java.io.OutputStream
import FirrtlProtos._
-import Firrtl.Statement.{ReadUnderWrite, Formal}
+import Firrtl.Statement.{Formal, ReadUnderWrite}
import Firrtl.Expression.PrimOp.Op
import com.google.protobuf.{CodedOutputStream, WireFormat}
import firrtl.PrimOps._
@@ -15,7 +15,6 @@ import scala.collection.JavaConverters._
object ToProto {
-
/** Serialize a FIRRTL Circuit to an Output Stream as a ProtoBuf message
*
* @param ostream Output stream that will be written
@@ -38,9 +37,9 @@ object ToProto {
// Note this function is sensitive to changes to the Firrtl and Circuit protobuf message definitions
def writeToStreamFast(
ostream: OutputStream,
- info: ir.Info,
+ info: ir.Info,
modules: Seq[() => ir.DefModule],
- main: String
+ main: String
): Unit = {
val costream = CodedOutputStream.newInstance(ostream)
@@ -110,23 +109,25 @@ object ToProto {
def convert(ruw: ir.ReadUnderWrite.Value): ReadUnderWrite = ruw match {
case ir.ReadUnderWrite.Undefined => ReadUnderWrite.UNDEFINED
- case ir.ReadUnderWrite.Old => ReadUnderWrite.OLD
- case ir.ReadUnderWrite.New => ReadUnderWrite.NEW
+ case ir.ReadUnderWrite.Old => ReadUnderWrite.OLD
+ case ir.ReadUnderWrite.New => ReadUnderWrite.NEW
}
def convert(formal: ir.Formal.Value): Formal = formal match {
case ir.Formal.Assert => Formal.ASSERT
case ir.Formal.Assume => Formal.ASSUME
- case ir.Formal.Cover => Formal.COVER
+ case ir.Formal.Cover => Formal.COVER
}
def convertToIntegerLiteral(value: BigInt): Firrtl.Expression.IntegerLiteral.Builder = {
- Firrtl.Expression.IntegerLiteral.newBuilder()
+ Firrtl.Expression.IntegerLiteral
+ .newBuilder()
.setValue(value.toString)
}
def convertToBigInt(value: BigInt): Firrtl.BigInt.Builder = {
- Firrtl.BigInt.newBuilder()
+ Firrtl.BigInt
+ .newBuilder()
.setValue(com.google.protobuf.ByteString.copyFrom(value.toByteArray))
}
@@ -135,7 +136,7 @@ object ToProto {
info match {
case ir.NoInfo =>
ib.setNone(Firrtl.SourceInfo.None.newBuilder)
- case f : ir.FileInfo =>
+ case f: ir.FileInfo =>
ib.setText(f.unescaped)
// TODO properly implement MultiInfo
case ir.MultiInfo(infos) =>
@@ -148,54 +149,64 @@ object ToProto {
val eb = Firrtl.Expression.newBuilder()
expr match {
case ir.Reference(name, _, _, _) =>
- val rb = Firrtl.Expression.Reference.newBuilder()
+ val rb = Firrtl.Expression.Reference
+ .newBuilder()
.setId(name)
eb.setReference(rb)
case ir.SubField(e, name, _, _) =>
- val sb = Firrtl.Expression.SubField.newBuilder()
+ val sb = Firrtl.Expression.SubField
+ .newBuilder()
.setExpression(convert(e))
.setField(name)
eb.setSubField(sb)
case ir.SubIndex(e, value, _, _) =>
- val sb = Firrtl.Expression.SubIndex.newBuilder()
+ val sb = Firrtl.Expression.SubIndex
+ .newBuilder()
.setExpression(convert(e))
.setIndex(convertToIntegerLiteral(value))
eb.setSubIndex(sb)
case ir.SubAccess(e, index, _, _) =>
- val sb = Firrtl.Expression.SubAccess.newBuilder()
+ val sb = Firrtl.Expression.SubAccess
+ .newBuilder()
.setExpression(convert(e))
.setIndex(convert(index))
eb.setSubAccess(sb)
case ir.UIntLiteral(value, width) =>
- val ub = Firrtl.Expression.UIntLiteral.newBuilder()
+ val ub = Firrtl.Expression.UIntLiteral
+ .newBuilder()
.setValue(convertToIntegerLiteral(value))
convert(width).foreach(ub.setWidth)
eb.setUintLiteral(ub)
case ir.SIntLiteral(value, width) =>
- val sb = Firrtl.Expression.SIntLiteral.newBuilder()
+ val sb = Firrtl.Expression.SIntLiteral
+ .newBuilder()
.setValue(convertToIntegerLiteral(value))
convert(width).foreach(sb.setWidth)
eb.setSintLiteral(sb)
case ir.FixedLiteral(value, width, point) =>
- val fb = Firrtl.Expression.FixedLiteral.newBuilder()
+ val fb = Firrtl.Expression.FixedLiteral
+ .newBuilder()
.setValue(convertToBigInt(value))
convert(width).foreach(fb.setWidth)
convert(point).foreach(fb.setPoint)
eb.setFixedLiteral(fb)
case ir.DoPrim(op, args, consts, _) =>
- val db = Firrtl.Expression.PrimOp.newBuilder()
+ val db = Firrtl.Expression.PrimOp
+ .newBuilder()
.setOp(convert(op))
consts.foreach(c => db.addConst(convertToIntegerLiteral(c)))
args.foreach(a => db.addArg(convert(a)))
eb.setPrimOp(db)
case ir.Mux(cond, tval, fval, _) =>
- val mb = Firrtl.Expression.Mux.newBuilder()
+ val mb = Firrtl.Expression.Mux
+ .newBuilder()
.setCondition(convert(cond))
.setTValue(convert(tval))
.setFValue(convert(fval))
eb.setMux(mb)
case ir.ValidIf(cond, value, _) =>
- val vb = Firrtl.Expression.ValidIf.newBuilder()
+ val vb = Firrtl.Expression.ValidIf
+ .newBuilder()
.setCondition(convert(cond))
.setValue(convert(value))
eb.setValidIf(vb)
@@ -205,37 +216,41 @@ object ToProto {
def convert(dir: MPortDir): Firrtl.Statement.MemoryPort.Direction = {
import Firrtl.Statement.MemoryPort.Direction._
dir match {
- case MInfer => MEMORY_PORT_DIRECTION_INFER
- case MRead => MEMORY_PORT_DIRECTION_READ
- case MWrite => MEMORY_PORT_DIRECTION_WRITE
+ case MInfer => MEMORY_PORT_DIRECTION_INFER
+ case MRead => MEMORY_PORT_DIRECTION_READ
+ case MWrite => MEMORY_PORT_DIRECTION_WRITE
case MReadWrite => MEMORY_PORT_DIRECTION_READ_WRITE
}
}
def convert(tpe: ir.Type, depth: BigInt): Firrtl.Statement.CMemory.TypeAndDepth.Builder =
- Firrtl.Statement.CMemory.TypeAndDepth.newBuilder()
+ Firrtl.Statement.CMemory.TypeAndDepth
+ .newBuilder()
.setDataType(convert(tpe))
.setDepth(convertToBigInt(depth))
def convert(stmt: ir.Statement): Seq[Firrtl.Statement.Builder] = {
stmt match {
case ir.Block(stmts) => stmts.flatMap(convert(_))
- case ir.EmptyStmt => Seq.empty
+ case ir.EmptyStmt => Seq.empty
case other =>
val sb = Firrtl.Statement.newBuilder()
other match {
case ir.DefNode(_, name, expr) =>
- val nb = Firrtl.Statement.Node.newBuilder()
+ val nb = Firrtl.Statement.Node
+ .newBuilder()
.setId(name)
.setExpression(convert(expr))
sb.setNode(nb)
case ir.DefWire(_, name, tpe) =>
- val wb = Firrtl.Statement.Wire.newBuilder()
+ val wb = Firrtl.Statement.Wire
+ .newBuilder()
.setId(name)
.setType(convert(tpe))
sb.setWire(wb)
case ir.DefRegister(_, name, tpe, clock, reset, init) =>
- val rb = Firrtl.Statement.Register.newBuilder()
+ val rb = Firrtl.Statement.Register
+ .newBuilder()
.setId(name)
.setType(convert(tpe))
.setClock(convert(clock))
@@ -243,54 +258,63 @@ object ToProto {
.setInit(convert(init))
sb.setRegister(rb)
case ir.DefInstance(_, name, module, _) =>
- val ib = Firrtl.Statement.Instance.newBuilder()
+ val ib = Firrtl.Statement.Instance
+ .newBuilder()
.setId(name)
.setModuleId(module)
sb.setInstance(ib)
case ir.Connect(_, loc, expr) =>
- val cb = Firrtl.Statement.Connect.newBuilder()
+ val cb = Firrtl.Statement.Connect
+ .newBuilder()
.setLocation(convert(loc))
.setExpression(convert(expr))
sb.setConnect(cb)
case ir.PartialConnect(_, loc, expr) =>
- val cb = Firrtl.Statement.PartialConnect.newBuilder()
+ val cb = Firrtl.Statement.PartialConnect
+ .newBuilder()
.setLocation(convert(loc))
.setExpression(convert(expr))
sb.setPartialConnect(cb)
case ir.Conditionally(_, pred, conseq, alt) =>
val cs = convert(conseq)
val as = convert(alt)
- val wb = Firrtl.Statement.When.newBuilder()
+ val wb = Firrtl.Statement.When
+ .newBuilder()
.setPredicate(convert(pred))
cs.foreach(wb.addConsequent)
as.foreach(wb.addOtherwise)
sb.setWhen(wb)
case ir.Print(_, string, args, clk, en) =>
- val pb = Firrtl.Statement.Printf.newBuilder()
+ val pb = Firrtl.Statement.Printf
+ .newBuilder()
.setValue(string.string)
.setClk(convert(clk))
.setEn(convert(en))
args.foreach(a => pb.addArg(convert(a)))
sb.setPrintf(pb)
case ir.Stop(_, ret, clk, en) =>
- val stopb = Firrtl.Statement.Stop.newBuilder()
+ val stopb = Firrtl.Statement.Stop
+ .newBuilder()
.setReturnValue(ret)
.setClk(convert(clk))
.setEn(convert(en))
sb.setStop(stopb)
case ir.Verification(op, _, clk, cond, en, msg) =>
- val vb = Firrtl.Statement.Verification.newBuilder()
+ val vb = Firrtl.Statement.Verification
+ .newBuilder()
.setOp(convert(op))
.setClk(convert(clk))
.setCond(convert(cond))
.setEn(convert(en))
.setMsg(msg.string)
case ir.IsInvalid(_, expr) =>
- val ib = Firrtl.Statement.IsInvalid.newBuilder()
+ val ib = Firrtl.Statement.IsInvalid
+ .newBuilder()
.setExpression(convert(expr))
sb.setIsInvalid(ib)
case ir.DefMemory(_, name, dtype, depth, wlat, rlat, rs, ws, rws, ruw) =>
- val mem = Firrtl.Statement.Memory.newBuilder()
+ val mem = Firrtl.Statement.Memory
+ .newBuilder()
.setId(name)
.setType(convert(dtype))
.setBigintDepth(convertToBigInt(depth))
@@ -302,14 +326,16 @@ object ToProto {
mem.addAllReadwriterId(rws.asJava)
sb.setMemory(mem)
case CDefMemory(_, name, tpe, size, seq, ruw) =>
- val mb = Firrtl.Statement.CMemory.newBuilder()
+ val mb = Firrtl.Statement.CMemory
+ .newBuilder()
.setId(name)
.setTypeAndDepth(convert(tpe, size))
.setSyncRead(seq)
.setReadUnderWrite(convert(ruw))
sb.setCmemory(mb)
case CDefMPort(_, name, _, mem, exprs, dir) =>
- val pb = Firrtl.Statement.MemoryPort.newBuilder()
+ val pb = Firrtl.Statement.MemoryPort
+ .newBuilder()
.setId(name)
.setMemoryId(mem)
.setMemoryIndex(convert(exprs.head))
@@ -330,7 +356,8 @@ object ToProto {
}
def convert(field: ir.Field): Firrtl.Type.BundleType.Field.Builder = {
- val b = Firrtl.Type.BundleType.Field.newBuilder()
+ val b = Firrtl.Type.BundleType.Field
+ .newBuilder()
.setId(field.name)
.setIsFlipped(field.flip == ir.Flip)
.setType(convert(field.tpe))
@@ -343,12 +370,13 @@ object ToProto {
* @return Option width where None means the width field should be cleared in the parent object
*/
def convert(width: ir.Width): Option[Firrtl.Width.Builder] = width match {
- case ir.IntWidth(w) => Some(Firrtl.Width.newBuilder().setValue(w.toInt))
+ case ir.IntWidth(w) => Some(Firrtl.Width.newBuilder().setValue(w.toInt))
case ir.UnknownWidth => None
}
def convert(vtpe: ir.VectorType): Firrtl.Type.VectorType.Builder =
- Firrtl.Type.VectorType.newBuilder()
+ Firrtl.Type.VectorType
+ .newBuilder()
.setType(convert(vtpe.tpe))
.setSize(vtpe.size)
@@ -379,7 +407,7 @@ object ToProto {
tb.setResetType(rt)
case ir.AnalogType(width) =>
val at = Firrtl.Type.AnalogType.newBuilder()
- convert(width).foreach(at.setWidth)
+ convert(width).foreach(at.setWidth)
tb.setAnalogType(at)
case ir.BundleType(fields) =>
val bt = Firrtl.Type.BundleType.newBuilder()
@@ -392,12 +420,13 @@ object ToProto {
}
def convert(direction: ir.Direction): Firrtl.Port.Direction = direction match {
- case ir.Input => Firrtl.Port.Direction.PORT_DIRECTION_IN
+ case ir.Input => Firrtl.Port.Direction.PORT_DIRECTION_IN
case ir.Output => Firrtl.Port.Direction.PORT_DIRECTION_OUT
}
def convert(port: ir.Port): Firrtl.Port.Builder = {
- Firrtl.Port.newBuilder()
+ Firrtl.Port
+ .newBuilder()
.setId(port.name)
.setDirection(convert(port.direction))
.setType(convert(port.tpe))
@@ -405,7 +434,8 @@ object ToProto {
def convert(param: ir.Param): Firrtl.Module.ExternalModule.Parameter.Builder = {
import Firrtl.Module.ExternalModule._
- val pb = Parameter.newBuilder()
+ val pb = Parameter
+ .newBuilder()
.setId(param.name)
param match {
case ir.IntParam(_, value) =>
@@ -425,13 +455,15 @@ object ToProto {
module match {
case mod: ir.Module =>
val stmts = convert(mod.body)
- val mb = Firrtl.Module.UserModule.newBuilder()
+ val mb = Firrtl.Module.UserModule
+ .newBuilder()
.setId(mod.name)
ports.foreach(mb.addPort)
stmts.foreach(mb.addStatement)
b.setUserModule(mb)
case ext: ir.ExtModule =>
- val eb = Firrtl.Module.ExternalModule.newBuilder()
+ val eb = Firrtl.Module.ExternalModule
+ .newBuilder()
.setId(ext.name)
.setDefinedName(ext.defname)
ports.foreach(eb.addPort)
@@ -448,7 +480,8 @@ object ToProto {
for (m <- moduleBuilders) {
cb.addModule(m)
}
- Firrtl.newBuilder()
+ Firrtl
+ .newBuilder()
.addCircuit(cb.build())
.build()
}
diff --git a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
index d587fd8c..d37d2881 100644
--- a/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
+++ b/src/main/scala/firrtl/stage/FirrtlAnnotations.scala
@@ -35,14 +35,16 @@ sealed trait CircuitOption extends Unserializable { this: Annotation =>
case class FirrtlFileAnnotation(file: String) extends NoTargetAnnotation with CircuitOption {
def toCircuit(info: Parser.InfoMode): FirrtlCircuitAnnotation = {
- val circuit = try {
- FirrtlStageUtils.getFileExtension(file) match {
- case ProtoBufFile => proto.FromProto.fromFile(file)
- case FirrtlFile => Parser.parseFile(file, info) }
- } catch {
- case a @ (_: FileNotFoundException | _: NoSuchFileException) =>
- throw new OptionsException(s"Input file '$file' not found! (Did you misspell it?)", a)
- }
+ val circuit =
+ try {
+ FirrtlStageUtils.getFileExtension(file) match {
+ case ProtoBufFile => proto.FromProto.fromFile(file)
+ case FirrtlFile => Parser.parseFile(file, info)
+ }
+ } catch {
+ case a @ (_: FileNotFoundException | _: NoSuchFileException) =>
+ throw new OptionsException(s"Input file '$file' not found! (Did you misspell it?)", a)
+ }
FirrtlCircuitAnnotation(circuit)
}
@@ -52,11 +54,13 @@ object FirrtlFileAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[String](
- longOption = "input-file",
+ longOption = "input-file",
toAnnotationSeq = a => Seq(FirrtlFileAnnotation(a)),
- helpText = "An input FIRRTL file",
- shortOption = Some("i"),
- helpValueName = Some("<file>") ) )
+ helpText = "An input FIRRTL file",
+ shortOption = Some("i"),
+ helpValueName = Some("<file>")
+ )
+ )
}
@@ -70,11 +74,13 @@ object OutputFileAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[String](
- longOption = "output-file",
+ longOption = "output-file",
toAnnotationSeq = a => Seq(OutputFileAnnotation(a)),
- helpText = "The output FIRRTL file",
- shortOption = Some("o"),
- helpValueName = Some("<file>") ) )
+ helpText = "The output FIRRTL file",
+ shortOption = Some("o"),
+ helpValueName = Some("<file>")
+ )
+ )
}
@@ -84,8 +90,10 @@ object OutputFileAnnotation extends HasShellOptions {
* @note This cannote be directly converted to [[Parser.InfoMode]] as that depends on an optional [[FirrtlFileAnnotation]]
*/
case class InfoModeAnnotation(modeName: String = "use") extends NoTargetAnnotation with FirrtlOption {
- require(modeName match { case "use" | "ignore" | "gen" | "append" => true; case _ => false },
- s"Unknown info mode '$modeName'! (Did you misspell it?)")
+ require(
+ modeName match { case "use" | "ignore" | "gen" | "append" => true; case _ => false },
+ s"Unknown info mode '$modeName'! (Did you misspell it?)"
+ )
/** Return the [[Parser.InfoMode]] equivalent for this [[firrtl.annotations.Annotation Annotation]]
* @param infoSource the name of a file to use for "gen" or "append" info modes
@@ -93,7 +101,7 @@ case class InfoModeAnnotation(modeName: String = "use") extends NoTargetAnnotati
def toInfoMode(infoSource: Option[String] = None): Parser.InfoMode = modeName match {
case "use" => Parser.UseInfo
case "ignore" => Parser.IgnoreInfo
- case _ =>
+ case _ =>
val a = infoSource.getOrElse("unknown source")
modeName match {
case "gen" => Parser.GenInfo(a)
@@ -106,10 +114,12 @@ object InfoModeAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[String](
- longOption = "info-mode",
+ longOption = "info-mode",
toAnnotationSeq = a => Seq(InfoModeAnnotation(a)),
- helpText = s"Source file info handling mode (default: ${apply().modeName})",
- helpValueName = Some("<ignore|use|gen|append>") ) )
+ helpText = s"Source file info handling mode (default: ${apply().modeName})",
+ helpValueName = Some("<ignore|use|gen|append>")
+ )
+ )
}
@@ -128,10 +138,12 @@ object FirrtlSourceAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[String](
- longOption = "firrtl-source",
+ longOption = "firrtl-source",
toAnnotationSeq = a => Seq(FirrtlSourceAnnotation(a)),
- helpText = "An input FIRRTL circuit string",
- helpValueName = Some("<string>") ) )
+ helpText = "An input FIRRTL circuit string",
+ helpValueName = Some("<string>")
+ )
+ )
}
@@ -144,27 +156,29 @@ case class CompilerAnnotation(compiler: Compiler = new VerilogCompiler()) extend
object CompilerAnnotation extends HasShellOptions {
- private [firrtl] def apply(compilerName: String): CompilerAnnotation = {
+ private[firrtl] def apply(compilerName: String): CompilerAnnotation = {
val c = compilerName match {
- case "none" => new NoneCompiler()
- case "high" => new HighFirrtlCompiler()
- case "low" => new LowFirrtlCompiler()
- case "middle" => new MiddleFirrtlCompiler()
- case "verilog" => new VerilogCompiler()
- case "mverilog" => new MinimumVerilogCompiler()
- case "sverilog" => new SystemVerilogCompiler()
- case _ => throw new OptionsException(s"Unknown compiler name '$compilerName'! (Did you misspell it?)")
+ case "none" => new NoneCompiler()
+ case "high" => new HighFirrtlCompiler()
+ case "low" => new LowFirrtlCompiler()
+ case "middle" => new MiddleFirrtlCompiler()
+ case "verilog" => new VerilogCompiler()
+ case "mverilog" => new MinimumVerilogCompiler()
+ case "sverilog" => new SystemVerilogCompiler()
+ case _ => throw new OptionsException(s"Unknown compiler name '$compilerName'! (Did you misspell it?)")
}
CompilerAnnotation(c)
}
val options = Seq(
new ShellOption[String](
- longOption = "compiler",
+ longOption = "compiler",
toAnnotationSeq = a => Seq(CompilerAnnotation(a)),
- helpText = "The FIRRTL compiler to use (default: verilog)",
- shortOption = Some("X"),
- helpValueName = Some("<none|high|middle|low|verilog|mverilog|sverilog>") ) )
+ helpText = "The FIRRTL compiler to use (default: verilog)",
+ shortOption = Some("X"),
+ helpValueName = Some("<none|high|middle|low|verilog|mverilog|sverilog>")
+ )
+ )
}
@@ -188,21 +202,26 @@ object RunFirrtlTransformAnnotation extends HasShellOptions {
val tx = Class.forName(txName).asInstanceOf[Class[_ <: Transform]].newInstance()
RunFirrtlTransformAnnotation(tx)
} catch {
- case e: ClassNotFoundException => throw new OptionsException(
- s"Unable to locate custom transform $txName (did you misspell it?)", e)
- case e: InstantiationException => throw new OptionsException(
- s"Unable to create instance of Transform $txName (is this an anonymous class?)", e)
- case e: Throwable => throw new OptionsException(
- s"Unknown error when instantiating class $txName", e) }),
+ case e: ClassNotFoundException =>
+ throw new OptionsException(s"Unable to locate custom transform $txName (did you misspell it?)", e)
+ case e: InstantiationException =>
+ throw new OptionsException(
+ s"Unable to create instance of Transform $txName (is this an anonymous class?)",
+ e
+ )
+ case e: Throwable => throw new OptionsException(s"Unknown error when instantiating class $txName", e)
+ }
+ ),
helpText = "Run these transforms during compilation",
shortOption = Some("fct"),
- helpValueName = Some("<package>.<class>") ),
+ helpValueName = Some("<package>.<class>")
+ ),
new ShellOption[String](
longOption = "change-name-case",
toAnnotationSeq = _ match {
case "lower" => Seq(RunFirrtlTransformAnnotation(new firrtl.features.LowerCaseNames))
case "upper" => Seq(RunFirrtlTransformAnnotation(new firrtl.features.UpperCaseNames))
- case a => throw new OptionsException(s"Unknown case '$a'. Did you misspell it?")
+ case a => throw new OptionsException(s"Unknown case '$a'. Did you misspell it?")
},
helpText = "Convert all FIRRTL names to a specific case",
helpValueName = Some("<lower|upper>")
@@ -231,9 +250,9 @@ case object SuppressScalaVersionWarning extends NoTargetAnnotation with FirrtlOp
def longOption: String = "Wno-scala-version-warning"
val options = Seq(
new ShellOption[Unit](
- longOption = longOption,
+ longOption = longOption,
toAnnotationSeq = { _ => Seq(this) },
- helpText = "Suppress Scala 2.11 deprecation warning (ignored in Scala 2.12+)"
+ helpText = "Suppress Scala 2.11 deprecation warning (ignored in Scala 2.12+)"
)
)
}
diff --git a/src/main/scala/firrtl/stage/FirrtlCli.scala b/src/main/scala/firrtl/stage/FirrtlCli.scala
index 39b89bea..fb5aa09f 100644
--- a/src/main/scala/firrtl/stage/FirrtlCli.scala
+++ b/src/main/scala/firrtl/stage/FirrtlCli.scala
@@ -11,16 +11,18 @@ import firrtl.transforms.NoCircuitDedupAnnotation
*/
trait FirrtlCli { this: Shell =>
parser.note("FIRRTL Compiler Options")
- Seq( FirrtlFileAnnotation,
- OutputFileAnnotation,
- InfoModeAnnotation,
- FirrtlSourceAnnotation,
- CompilerAnnotation,
- RunFirrtlTransformAnnotation,
- firrtl.EmitCircuitAnnotation,
- firrtl.EmitAllModulesAnnotation,
- NoCircuitDedupAnnotation,
- SuppressScalaVersionWarning)
+ Seq(
+ FirrtlFileAnnotation,
+ OutputFileAnnotation,
+ InfoModeAnnotation,
+ FirrtlSourceAnnotation,
+ CompilerAnnotation,
+ RunFirrtlTransformAnnotation,
+ firrtl.EmitCircuitAnnotation,
+ firrtl.EmitAllModulesAnnotation,
+ NoCircuitDedupAnnotation,
+ SuppressScalaVersionWarning
+ )
.map(_.addOptions(parser))
phases.DriverCompatibility.TopNameAnnotation.addOptions(parser)
diff --git a/src/main/scala/firrtl/stage/FirrtlOptions.scala b/src/main/scala/firrtl/stage/FirrtlOptions.scala
index 61dec7c5..55d4cc31 100644
--- a/src/main/scala/firrtl/stage/FirrtlOptions.scala
+++ b/src/main/scala/firrtl/stage/FirrtlOptions.scala
@@ -9,19 +9,17 @@ import firrtl.ir.Circuit
* @param infoModeName the policy for generating [[firrtl.ir Info]] when processing FIRRTL (default: "append")
* @param firrtlCircuit a [[firrtl.ir Circuit]]
*/
-class FirrtlOptions private [stage] (
- val outputFileName: Option[String] = None,
- val infoModeName: String = InfoModeAnnotation().modeName,
- val firrtlCircuit: Option[Circuit] = None) {
+class FirrtlOptions private[stage] (
+ val outputFileName: Option[String] = None,
+ val infoModeName: String = InfoModeAnnotation().modeName,
+ val firrtlCircuit: Option[Circuit] = None) {
- private [stage] def copy(
- outputFileName: Option[String] = outputFileName,
- infoModeName: String = infoModeName,
- firrtlCircuit: Option[Circuit] = firrtlCircuit ): FirrtlOptions = {
+ private[stage] def copy(
+ outputFileName: Option[String] = outputFileName,
+ infoModeName: String = infoModeName,
+ firrtlCircuit: Option[Circuit] = firrtlCircuit
+ ): FirrtlOptions = {
- new FirrtlOptions(
- outputFileName = outputFileName,
- infoModeName = infoModeName,
- firrtlCircuit = firrtlCircuit )
+ new FirrtlOptions(outputFileName = outputFileName, infoModeName = infoModeName, firrtlCircuit = firrtlCircuit)
}
}
diff --git a/src/main/scala/firrtl/stage/FirrtlStage.scala b/src/main/scala/firrtl/stage/FirrtlStage.scala
index 1042f979..58d07e43 100644
--- a/src/main/scala/firrtl/stage/FirrtlStage.scala
+++ b/src/main/scala/firrtl/stage/FirrtlStage.scala
@@ -7,8 +7,7 @@ import firrtl.options.{Dependency, Phase, PhaseManager, Shell, Stage, StageMain}
import firrtl.options.phases.DeletedWrapper
import firrtl.stage.phases.CatchExceptions
-class FirrtlPhase
- extends PhaseManager(targets=Seq(Dependency[firrtl.stage.phases.Compiler])) {
+class FirrtlPhase extends PhaseManager(targets = Seq(Dependency[firrtl.stage.phases.Compiler])) {
override def invalidates(a: Phase) = false
diff --git a/src/main/scala/firrtl/stage/FirrtlStageUtils.scala b/src/main/scala/firrtl/stage/FirrtlStageUtils.scala
index e2304a92..aa9781db 100644
--- a/src/main/scala/firrtl/stage/FirrtlStageUtils.scala
+++ b/src/main/scala/firrtl/stage/FirrtlStageUtils.scala
@@ -2,14 +2,14 @@
package firrtl.stage
-private [stage] sealed trait FileExtension
-private [stage] case object FirrtlFile extends FileExtension
-private [stage] case object ProtoBufFile extends FileExtension
+private[stage] sealed trait FileExtension
+private[stage] case object FirrtlFile extends FileExtension
+private[stage] case object ProtoBufFile extends FileExtension
/** Utilities that help with processing FIRRTL options */
object FirrtlStageUtils {
- private [stage] def getFileExtension(file: String): FileExtension = file.drop(file.lastIndexOf('.')) match {
+ private[stage] def getFileExtension(file: String): FileExtension = file.drop(file.lastIndexOf('.')) match {
case ".pb" => ProtoBufFile
case _ => FirrtlFile
}
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index 636d0609..a0c5ea0c 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -17,28 +17,34 @@ object Forms {
val ChirrtlForm: Seq[TransformDependency] = Seq.empty
val MinimalHighForm: Seq[TransformDependency] = ChirrtlForm ++
- Seq( Dependency(passes.CheckChirrtl),
- Dependency(passes.CInferTypes),
- Dependency(passes.CInferMDir),
- Dependency(passes.RemoveCHIRRTL),
- Dependency[annotations.transforms.CleanupNamedTargets] )
+ Seq(
+ Dependency(passes.CheckChirrtl),
+ Dependency(passes.CInferTypes),
+ Dependency(passes.CInferMDir),
+ Dependency(passes.RemoveCHIRRTL),
+ Dependency[annotations.transforms.CleanupNamedTargets]
+ )
val WorkingIR: Seq[TransformDependency] = MinimalHighForm :+ Dependency(passes.ToWorkingIR)
val Checks: Seq[TransformDependency] =
- Seq( Dependency(passes.CheckHighForm),
- Dependency(passes.CheckTypes),
- Dependency(passes.CheckFlows),
- Dependency(passes.CheckWidths) )
+ Seq(
+ Dependency(passes.CheckHighForm),
+ Dependency(passes.CheckTypes),
+ Dependency(passes.CheckFlows),
+ Dependency(passes.CheckWidths)
+ )
val Resolved: Seq[TransformDependency] = WorkingIR ++ Checks ++
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferBinaryPoints],
- Dependency[passes.TrimIntervals],
- Dependency[passes.InferWidths],
- Dependency[firrtl.transforms.InferResets] )
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferBinaryPoints],
+ Dependency[passes.TrimIntervals],
+ Dependency[passes.InferWidths],
+ Dependency[firrtl.transforms.InferResets]
+ )
val Deduped: Seq[TransformDependency] = Resolved :+ Dependency[firrtl.transforms.DedupModules]
@@ -49,61 +55,71 @@ object Forms {
Deduped
val MidForm: Seq[TransformDependency] = HighForm ++
- Seq( Dependency(passes.PullMuxes),
- Dependency(passes.ReplaceAccesses),
- Dependency(passes.ExpandConnects),
- Dependency(passes.RemoveAccesses),
- Dependency(passes.ZeroLengthVecs),
- Dependency[passes.ExpandWhensAndCheck],
- Dependency[passes.RemoveIntervals],
- Dependency(passes.ConvertFixedToSInt),
- Dependency(passes.ZeroWidth),
- Dependency[firrtl.transforms.formal.AssertSubmoduleAssumptions] )
+ Seq(
+ Dependency(passes.PullMuxes),
+ Dependency(passes.ReplaceAccesses),
+ Dependency(passes.ExpandConnects),
+ Dependency(passes.RemoveAccesses),
+ Dependency(passes.ZeroLengthVecs),
+ Dependency[passes.ExpandWhensAndCheck],
+ Dependency[passes.RemoveIntervals],
+ Dependency(passes.ConvertFixedToSInt),
+ Dependency(passes.ZeroWidth),
+ Dependency[firrtl.transforms.formal.AssertSubmoduleAssumptions]
+ )
val LowForm: Seq[TransformDependency] = MidForm ++
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
- Dependency(firrtl.transforms.RemoveReset),
- Dependency[firrtl.transforms.CheckCombLoops],
- Dependency[checks.CheckResets],
- Dependency[firrtl.transforms.RemoveWires] )
+ Seq(
+ Dependency(passes.LowerTypes),
+ Dependency(passes.Legalize),
+ Dependency(firrtl.transforms.RemoveReset),
+ Dependency[firrtl.transforms.CheckCombLoops],
+ Dependency[checks.CheckResets],
+ Dependency[firrtl.transforms.RemoveWires]
+ )
val LowFormMinimumOptimized: Seq[TransformDependency] = LowForm ++
- Seq( Dependency(passes.RemoveValidIf),
- Dependency(passes.PadWidths),
- Dependency(passes.memlib.VerilogMemDelays),
- Dependency(passes.SplitExpressions),
- Dependency[firrtl.transforms.LegalizeAndReductionsTransform] )
+ Seq(
+ Dependency(passes.RemoveValidIf),
+ Dependency(passes.PadWidths),
+ Dependency(passes.memlib.VerilogMemDelays),
+ Dependency(passes.SplitExpressions),
+ Dependency[firrtl.transforms.LegalizeAndReductionsTransform]
+ )
val LowFormOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.ConstantPropagation],
- Dependency[firrtl.transforms.CombineCats],
- Dependency(passes.CommonSubexpressionElimination),
- Dependency[firrtl.transforms.DeadCodeElimination] )
+ Seq(
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency[firrtl.transforms.CombineCats],
+ Dependency(passes.CommonSubexpressionElimination),
+ Dependency[firrtl.transforms.DeadCodeElimination]
+ )
val VerilogMinimumOptimized: Seq[TransformDependency] = LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename],
- Dependency(passes.VerilogPrep),
- Dependency[firrtl.AddDescriptionNodes] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename],
+ Dependency(passes.VerilogPrep),
+ Dependency[firrtl.AddDescriptionNodes]
+ )
val VerilogOptimized: Seq[TransformDependency] = LowFormOptimized ++ VerilogMinimumOptimized
val AssertsRemoved: Seq[TransformDependency] =
- Seq( Dependency(firrtl.transforms.formal.ConvertAsserts),
- Dependency[firrtl.transforms.formal.RemoveVerificationStatements] )
+ Seq(
+ Dependency(firrtl.transforms.formal.ConvertAsserts),
+ Dependency[firrtl.transforms.formal.RemoveVerificationStatements]
+ )
val BackendEmitters =
- Seq( Dependency[VerilogEmitter],
- Dependency[MinimumVerilogEmitter],
- Dependency[SystemVerilogEmitter] )
+ Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter], Dependency[SystemVerilogEmitter])
val LowEmitters = Dependency[LowFirrtlEmitter] +: BackendEmitters
diff --git a/src/main/scala/firrtl/stage/TransformManager.scala b/src/main/scala/firrtl/stage/TransformManager.scala
index 1b3032be..aa96ca86 100644
--- a/src/main/scala/firrtl/stage/TransformManager.scala
+++ b/src/main/scala/firrtl/stage/TransformManager.scala
@@ -12,15 +12,17 @@ import firrtl.options.{Dependency, DependencyManager}
* @param knownObjects existing transform objects that have already been constructed
*/
class TransformManager(
- val targets: Seq[TransformManager.TransformDependency],
+ val targets: Seq[TransformManager.TransformDependency],
val currentState: Seq[TransformManager.TransformDependency] = Seq.empty,
- val knownObjects: Set[Transform] = Set.empty) extends Transform
+ val knownObjects: Set[Transform] = Set.empty)
+ extends Transform
with DependencyAPIMigration
with DependencyManager[CircuitState, Transform] {
override def execute(state: CircuitState): CircuitState = transform(state)
- override protected def copy(a: Seq[Dependency[Transform]], b: Seq[Dependency[Transform]], c: Set[Transform]) = new TransformManager(a, b, c)
+ override protected def copy(a: Seq[Dependency[Transform]], b: Seq[Dependency[Transform]], c: Set[Transform]) =
+ new TransformManager(a, b, c)
}
diff --git a/src/main/scala/firrtl/stage/package.scala b/src/main/scala/firrtl/stage/package.scala
index 123c763a..37e2d13c 100644
--- a/src/main/scala/firrtl/stage/package.scala
+++ b/src/main/scala/firrtl/stage/package.scala
@@ -25,46 +25,50 @@ package object stage {
/**
* @todo custom transforms are appended as discovered, can this be prepended safely?
*/
- def view(options: AnnotationSeq): FirrtlOptions = options
- .collect { case a: FirrtlOption => a }
- .foldLeft(new FirrtlOptions()){ (c, x) =>
+ def view(options: AnnotationSeq): FirrtlOptions = options.collect { case a: FirrtlOption => a }
+ .foldLeft(new FirrtlOptions()) { (c, x) =>
x match {
- case OutputFileAnnotation(f) => c.copy(outputFileName = Some(f))
- case InfoModeAnnotation(i) => c.copy(infoModeName = i)
- case FirrtlCircuitAnnotation(cir) => c.copy(firrtlCircuit = Some(cir))
- case a : CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a") ; c
- case SuppressScalaVersionWarning => c
+ case OutputFileAnnotation(f) => c.copy(outputFileName = Some(f))
+ case InfoModeAnnotation(i) => c.copy(infoModeName = i)
+ case FirrtlCircuitAnnotation(cir) => c.copy(firrtlCircuit = Some(cir))
+ case a: CompilerAnnotation => logger.warn(s"Use of CompilerAnnotation is deprecated. Ignoring $a"); c
+ case SuppressScalaVersionWarning => c
}
}
}
- private [firrtl] implicit object FirrtlExecutionResultView extends OptionsView[FirrtlExecutionResult] with LazyLogging {
+ private[firrtl] implicit object FirrtlExecutionResultView
+ extends OptionsView[FirrtlExecutionResult]
+ with LazyLogging {
def view(options: AnnotationSeq): FirrtlExecutionResult = {
- val emittedRes = options
- .collect{ case a: EmittedAnnotation[_] => a.value.value }
+ val emittedRes = options.collect { case a: EmittedAnnotation[_] => a.value.value }
.mkString("\n")
- val emitters = options.collect{ case RunFirrtlTransformAnnotation(e: Emitter) => e }
- if(emitters.length > 1) {
- logger.warn("More than one emitter used which cannot be accurately represented" +
- "in the deprecated FirrtlExecutionResult: " + emitters.map(_.name).mkString(", "))
+ val emitters = options.collect { case RunFirrtlTransformAnnotation(e: Emitter) => e }
+ if (emitters.length > 1) {
+ logger.warn(
+ "More than one emitter used which cannot be accurately represented" +
+ "in the deprecated FirrtlExecutionResult: " + emitters.map(_.name).mkString(", ")
+ )
}
- val compilers = options.collect{ case CompilerAnnotation(c) => c }
+ val compilers = options.collect { case CompilerAnnotation(c) => c }
val emitType = emitters.headOption.orElse(compilers.headOption).map(_.name).getOrElse("N/A")
val form = emitters.headOption.orElse(compilers.headOption).map(_.outputForm).getOrElse(UnknownForm)
- options.collectFirst{ case a: FirrtlCircuitAnnotation => a.circuit } match {
+ options.collectFirst { case a: FirrtlCircuitAnnotation => a.circuit } match {
case None => FirrtlExecutionFailure("No circuit found in AnnotationSeq!")
- case Some(a) => FirrtlExecutionSuccess(
- emitType = emitType,
- emitted = emittedRes,
- circuitState = CircuitState(
- circuit = a,
- form = form,
- annotations = options,
- renames = None
- ))
+ case Some(a) =>
+ FirrtlExecutionSuccess(
+ emitType = emitType,
+ emitted = emittedRes,
+ circuitState = CircuitState(
+ circuit = a,
+ form = form,
+ annotations = options,
+ renames = None
+ )
+ )
}
}
}
diff --git a/src/main/scala/firrtl/stage/phases/AddCircuit.scala b/src/main/scala/firrtl/stage/phases/AddCircuit.scala
index f3ff3372..c00e71b6 100644
--- a/src/main/scala/firrtl/stage/phases/AddCircuit.scala
+++ b/src/main/scala/firrtl/stage/phases/AddCircuit.scala
@@ -39,11 +39,10 @@ class AddCircuit extends Phase {
* @throws $infoModeException
*/
private def infoMode(annotations: AnnotationSeq): Parser.InfoMode = {
- val infoModeAnnotation = annotations
- .collectFirst{ case a: InfoModeAnnotation => a }
- .getOrElse { throw new PhasePrerequisiteException(
- "An InfoModeAnnotation must be present (did you forget to run AddDefaults?)") }
- val infoSource = annotations.collectFirst{
+ val infoModeAnnotation = annotations.collectFirst { case a: InfoModeAnnotation => a }.getOrElse {
+ throw new PhasePrerequisiteException("An InfoModeAnnotation must be present (did you forget to run AddDefaults?)")
+ }
+ val infoSource = annotations.collectFirst {
case FirrtlFileAnnotation(f) => f
case _: FirrtlSourceAnnotation => "anonymous source"
}.getOrElse("not defined")
@@ -58,7 +57,7 @@ class AddCircuit extends Phase {
lazy val info = infoMode(annotations)
annotations.map {
case a: CircuitOption => a.toCircuit(info)
- case a => a
+ case a => a
}
}
diff --git a/src/main/scala/firrtl/stage/phases/AddDefaults.scala b/src/main/scala/firrtl/stage/phases/AddDefaults.scala
index d4c5bab4..9f4163cc 100644
--- a/src/main/scala/firrtl/stage/phases/AddDefaults.scala
+++ b/src/main/scala/firrtl/stage/phases/AddDefaults.scala
@@ -26,21 +26,21 @@ class AddDefaults extends Phase {
var bb, c, em, im = true
annotations.foreach {
case _: BlackBoxTargetDirAnno => bb = false
- case _: CompilerAnnotation => c = false
- case _: InfoModeAnnotation => im = false
- case RunFirrtlTransformAnnotation(_ : firrtl.Emitter) => em = false
+ case _: CompilerAnnotation => c = false
+ case _: InfoModeAnnotation => im = false
+ case RunFirrtlTransformAnnotation(_: firrtl.Emitter) => em = false
case _ =>
}
val default = new FirrtlOptions()
- val targetDir = annotations
- .collectFirst { case d: TargetDirAnnotation => d }
- .getOrElse(TargetDirAnnotation()).directory
-
- (if (bb) Seq(BlackBoxTargetDirAnno(targetDir)) else Seq() ) ++
- // if there is no compiler or emitter specified, add the default emitter
- (if (c && em) Seq(RunFirrtlTransformAnnotation(DefaultEmitterTarget)) else Seq() ) ++
- (if (im) Seq(InfoModeAnnotation()) else Seq() ) ++
+ val targetDir = annotations.collectFirst { case d: TargetDirAnnotation => d }
+ .getOrElse(TargetDirAnnotation())
+ .directory
+
+ (if (bb) Seq(BlackBoxTargetDirAnno(targetDir)) else Seq()) ++
+ // if there is no compiler or emitter specified, add the default emitter
+ (if (c && em) Seq(RunFirrtlTransformAnnotation(DefaultEmitterTarget)) else Seq()) ++
+ (if (im) Seq(InfoModeAnnotation()) else Seq()) ++
annotations
}
diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala
index edf62c3a..3c0a2388 100644
--- a/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala
+++ b/src/main/scala/firrtl/stage/phases/AddImplicitEmitter.scala
@@ -18,16 +18,19 @@ class AddImplicitEmitter extends Phase {
override def invalidates(a: Phase) = false
def transform(annos: AnnotationSeq): AnnotationSeq = {
- val emit = annos.collectFirst{ case a: EmitAnnotation => a }
- val emitter = annos.collectFirst{ case RunFirrtlTransformAnnotation(e : Emitter) => e }
- val compiler = annos.collectFirst{ case CompilerAnnotation(a) => a }
+ val emit = annos.collectFirst { case a: EmitAnnotation => a }
+ val emitter = annos.collectFirst { case RunFirrtlTransformAnnotation(e: Emitter) => e }
+ val compiler = annos.collectFirst { case CompilerAnnotation(a) => a }
if (emit.isEmpty && (compiler.nonEmpty || emitter.nonEmpty)) {
- annos.flatMap{
- case a: CompilerAnnotation => Seq(a,
- RunFirrtlTransformAnnotation(compiler.get.emitter),
- EmitCircuitAnnotation(compiler.get.emitter.getClass))
- case a @ RunFirrtlTransformAnnotation(e : Emitter) => Seq(a, EmitCircuitAnnotation(e.getClass))
+ annos.flatMap {
+ case a: CompilerAnnotation =>
+ Seq(
+ a,
+ RunFirrtlTransformAnnotation(compiler.get.emitter),
+ EmitCircuitAnnotation(compiler.get.emitter.getClass)
+ )
+ case a @ RunFirrtlTransformAnnotation(e: Emitter) => Seq(a, EmitCircuitAnnotation(e.getClass))
case a => Some(a)
}
} else {
diff --git a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala
index f57e9c39..10af13d5 100644
--- a/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala
+++ b/src/main/scala/firrtl/stage/phases/AddImplicitOutputFile.scala
@@ -30,13 +30,12 @@ class AddImplicitOutputFile extends Phase {
/** Add an [[OutputFileAnnotation]] to an [[AnnotationSeq]] */
def transform(annotations: AnnotationSeq): AnnotationSeq =
- annotations
- .collectFirst { case _: OutputFileAnnotation | _: EmitAllModulesAnnotation => annotations }
- .getOrElse {
- val topName = Viewer[FirrtlOptions].view(annotations)
- .firrtlCircuit
- .map(_.main)
- .getOrElse("a")
- OutputFileAnnotation(topName) +: annotations
- }
+ annotations.collectFirst { case _: OutputFileAnnotation | _: EmitAllModulesAnnotation => annotations }.getOrElse {
+ val topName = Viewer[FirrtlOptions]
+ .view(annotations)
+ .firrtlCircuit
+ .map(_.main)
+ .getOrElse("a")
+ OutputFileAnnotation(topName) +: annotations
+ }
}
diff --git a/src/main/scala/firrtl/stage/phases/CatchExceptions.scala b/src/main/scala/firrtl/stage/phases/CatchExceptions.scala
index f65ed481..5181653b 100644
--- a/src/main/scala/firrtl/stage/phases/CatchExceptions.scala
+++ b/src/main/scala/firrtl/stage/phases/CatchExceptions.scala
@@ -4,8 +4,12 @@ package firrtl.stage.phases
import firrtl.options.{DependencyManagerException, OptionsException, Phase, PhaseException}
import firrtl.{
- AnnotationSeq, CustomTransformException, FIRRTLException,
- FirrtlInternalException, FirrtlUserException, Utils
+ AnnotationSeq,
+ CustomTransformException,
+ FIRRTLException,
+ FirrtlInternalException,
+ FirrtlUserException,
+ Utils
}
import scala.util.control.ControlThrowable
@@ -27,15 +31,15 @@ class CatchExceptions(val underlying: Phase) extends Phase {
} catch {
/* Rethrow the exceptions which are expected or due to the runtime environment (out of memory, stack overflow, etc.).
* Any UNEXPECTED exceptions should be treated as internal errors. */
- case p @ (_: ControlThrowable | _: FIRRTLException | _: OptionsException | _: FirrtlUserException
- | _: FirrtlInternalException | _: PhaseException | _: DependencyManagerException) => throw p
+ case p @ (_: ControlThrowable | _: FIRRTLException | _: OptionsException | _: FirrtlUserException |
+ _: FirrtlInternalException | _: PhaseException | _: DependencyManagerException) =>
+ throw p
case CustomTransformException(cause) => throw cause
case e: Exception => Utils.throwInternalError(exception = Some(e))
}
}
-
object CatchExceptions {
def apply(p: Phase): CatchExceptions = new CatchExceptions(p)
diff --git a/src/main/scala/firrtl/stage/phases/Checks.scala b/src/main/scala/firrtl/stage/phases/Checks.scala
index 7ecdc47e..6576d311 100644
--- a/src/main/scala/firrtl/stage/phases/Checks.scala
+++ b/src/main/scala/firrtl/stage/phases/Checks.scala
@@ -32,67 +32,70 @@ class Checks extends Phase {
*/
def transform(annos: AnnotationSeq): AnnotationSeq = {
val inF, inS, eam, ec, outF, comp, emitter, im, inC = collection.mutable.ListBuffer[Annotation]()
- annos.foreach(
- _ match {
- case a: FirrtlFileAnnotation => a +=: inF
- case a: FirrtlSourceAnnotation => a +=: inS
- case a: EmitAllModulesAnnotation => a +=: eam
- case a: EmitCircuitAnnotation => a +=: ec
- case a: OutputFileAnnotation => a +=: outF
- case a: CompilerAnnotation => a +=: comp
- case a: InfoModeAnnotation => a +=: im
- case a: FirrtlCircuitAnnotation => a +=: inC
- case a @ RunFirrtlTransformAnnotation(_ : firrtl.Emitter) => a +=: emitter
- case _ => })
+ annos.foreach(_ match {
+ case a: FirrtlFileAnnotation => a +=: inF
+ case a: FirrtlSourceAnnotation => a +=: inS
+ case a: EmitAllModulesAnnotation => a +=: eam
+ case a: EmitCircuitAnnotation => a +=: ec
+ case a: OutputFileAnnotation => a +=: outF
+ case a: CompilerAnnotation => a +=: comp
+ case a: InfoModeAnnotation => a +=: im
+ case a: FirrtlCircuitAnnotation => a +=: inC
+ case a @ RunFirrtlTransformAnnotation(_: firrtl.Emitter) => a +=: emitter
+ case _ =>
+ })
/* At this point, only a FIRRTL Circuit should exist */
if (inF.isEmpty && inS.isEmpty && inC.isEmpty) {
- throw new OptionsException(
- s"""|Unable to determine FIRRTL source to read. None of the following were found:
- | - an input file: -i, --input-file, FirrtlFileAnnotation
- | - FIRRTL source: --firrtl-source, FirrtlSourceAnnotation
- | - FIRRTL circuit: FirrtlCircuitAnnotation""".stripMargin )}
+ throw new OptionsException(s"""|Unable to determine FIRRTL source to read. None of the following were found:
+ | - an input file: -i, --input-file, FirrtlFileAnnotation
+ | - FIRRTL source: --firrtl-source, FirrtlSourceAnnotation
+ | - FIRRTL circuit: FirrtlCircuitAnnotation""".stripMargin)
+ }
/* Only one FIRRTL input can exist */
if (inF.size + inS.size + inC.size > 1) {
- throw new OptionsException(
- s"""|Multiply defined input FIRRTL sources. More than one of the following was found:
- | - an input file (${inF.size} times): -i, --input-file, FirrtlFileAnnotation
- | - FIRRTL source (${inS.size} times): --firrtl-source, FirrtlSourceAnnotation
- | - FIRRTL circuit (${inC.size} times): FirrtlCircuitAnnotation""".stripMargin )}
+ throw new OptionsException(s"""|Multiply defined input FIRRTL sources. More than one of the following was found:
+ | - an input file (${inF.size} times): -i, --input-file, FirrtlFileAnnotation
+ | - FIRRTL source (${inS.size} times): --firrtl-source, FirrtlSourceAnnotation
+ | - FIRRTL circuit (${inC.size} times): FirrtlCircuitAnnotation""".stripMargin)
+ }
/* Specifying an output file and one-file-per module conflict */
if (eam.nonEmpty && outF.nonEmpty) {
throw new OptionsException(
s"""|Output file is incompatible with emit all modules annotation, but multiples were found:
| - explicit output file (${outF.size} times): -o, --output-file, OutputFileAnnotation
- | - one file per module (${eam.size} times): -e, --emit-modules, EmitAllModulesAnnotation"""
- .stripMargin )}
+ | - one file per module (${eam.size} times): -e, --emit-modules, EmitAllModulesAnnotation""".stripMargin
+ )
+ }
/* Only one output file can be specified */
if (outF.size > 1) {
- val x = outF.map{ case OutputFileAnnotation(x) => x }
+ val x = outF.map { case OutputFileAnnotation(x) => x }
throw new OptionsException(
s"""|No more than one output file can be specified, but found '${x.mkString(", ")}' specified via:
- | - option or annotation: -o, --output-file, OutputFileAnnotation""".stripMargin) }
+ | - option or annotation: -o, --output-file, OutputFileAnnotation""".stripMargin
+ )
+ }
/* One mandatory compiler (or emitter) must be specified */
if (comp.size != 1 && emitter.isEmpty) {
- val x = comp.map{ case CompilerAnnotation(x) => x }
- val (msg, suggest) = if (comp.size == 0) { ("none found", "forget one of") }
- else { (s"""found '${x.mkString(", ")}'""", "use multiple of") }
- throw new OptionsException(
- s"""|Exactly one compiler must be specified, but $msg. Did you $suggest the following?
- | - an option or annotation: -X, --compiler, CompilerAnnotation""".stripMargin )}
+ val x = comp.map { case CompilerAnnotation(x) => x }
+ val (msg, suggest) = if (comp.size == 0) { ("none found", "forget one of") }
+ else { (s"""found '${x.mkString(", ")}'""", "use multiple of") }
+ throw new OptionsException(s"""|Exactly one compiler must be specified, but $msg. Did you $suggest the following?
+ | - an option or annotation: -X, --compiler, CompilerAnnotation""".stripMargin)
+ }
/* One mandatory info mode must be specified */
if (im.size != 1) {
- val x = im.map{ case InfoModeAnnotation(x) => x }
- val (msg, suggest) = if (im.size == 0) { ("none found", "forget one of") }
- else { (s"""found '${x.mkString(", ")}'""", "use multiple of") }
- throw new OptionsException(
- s"""|Exactly one info mode must be specified, but $msg. Did you $suggest the following?
- | - an option or annotation: --info-mode, InfoModeAnnotation""".stripMargin )}
+ val x = im.map { case InfoModeAnnotation(x) => x }
+ val (msg, suggest) = if (im.size == 0) { ("none found", "forget one of") }
+ else { (s"""found '${x.mkString(", ")}'""", "use multiple of") }
+ throw new OptionsException(s"""|Exactly one info mode must be specified, but $msg. Did you $suggest the following?
+ | - an option or annotation: --info-mode, InfoModeAnnotation""".stripMargin)
+ }
annos
}
diff --git a/src/main/scala/firrtl/stage/phases/Compiler.scala b/src/main/scala/firrtl/stage/phases/Compiler.scala
index b73e3058..0d1181a6 100644
--- a/src/main/scala/firrtl/stage/phases/Compiler.scala
+++ b/src/main/scala/firrtl/stage/phases/Compiler.scala
@@ -10,17 +10,17 @@ import firrtl.stage.TransformManager.TransformDependency
import scala.collection.mutable
/** An encoding of the information necessary to run the FIRRTL compiler once */
-private [stage] case class CompilerRun(
- stateIn: CircuitState,
- stateOut: Option[CircuitState],
+private[stage] case class CompilerRun(
+ stateIn: CircuitState,
+ stateOut: Option[CircuitState],
transforms: Seq[Transform],
- compiler: Option[FirrtlCompiler] )
+ compiler: Option[FirrtlCompiler])
/** An encoding of possible defaults for a [[CompilerRun]] */
-private [stage] case class Defaults(
+private[stage] case class Defaults(
annotations: AnnotationSeq = Seq.empty,
- transforms: Seq[Transform] = Seq.empty,
- compiler: Option[FirrtlCompiler] = None)
+ transforms: Seq[Transform] = Seq.empty,
+ compiler: Option[FirrtlCompiler] = None)
/** Runs the FIRRTL compilers on an [[AnnotationSeq]]. If the input [[AnnotationSeq]] contains more than one circuit
* (i.e., more than one [[firrtl.stage.FirrtlCircuitAnnotation FirrtlCircuitAnnotation]]), then annotations will be
@@ -45,11 +45,13 @@ private [stage] case class Defaults(
class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] {
override def prerequisites =
- Seq(Dependency[AddDefaults],
- Dependency[AddImplicitEmitter],
- Dependency[Checks],
- Dependency[AddCircuit],
- Dependency[AddImplicitOutputFile])
+ Seq(
+ Dependency[AddDefaults],
+ Dependency[AddImplicitEmitter],
+ Dependency[Checks],
+ Dependency[AddCircuit],
+ Dependency[AddImplicitOutputFile]
+ )
override def optionalPrerequisiteOf = Seq.empty
@@ -59,28 +61,30 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] {
protected def aToB(a: AnnotationSeq): Seq[CompilerRun] = {
var foundFirstCircuit = false
val c = mutable.ArrayBuffer.empty[CompilerRun]
- a.foldLeft(Defaults()){
+ a.foldLeft(Defaults()) {
case (d, FirrtlCircuitAnnotation(circuit)) =>
foundFirstCircuit = true
CompilerRun(CircuitState(circuit, ChirrtlForm, d.annotations, None), None, d.transforms, d.compiler) +=: c
d
- case (d, a) if foundFirstCircuit => a match {
- case RunFirrtlTransformAnnotation(transform) =>
- c(0) = c(0).copy(transforms = transform +: c(0).transforms)
- d
- case CompilerAnnotation(compiler) =>
- c(0) = c(0).copy(compiler = Some(compiler))
- d
- case annotation =>
- val state = c(0).stateIn
- c(0) = c(0).copy(stateIn = state.copy(annotations = annotation +: state.annotations))
- d
- }
- case (d, a) if !foundFirstCircuit => a match {
- case RunFirrtlTransformAnnotation(transform) => d.copy(transforms = transform +: d.transforms)
- case CompilerAnnotation(compiler) => d.copy(compiler = Some(compiler))
- case annotation => d.copy(annotations = annotation +: d.annotations)
- }
+ case (d, a) if foundFirstCircuit =>
+ a match {
+ case RunFirrtlTransformAnnotation(transform) =>
+ c(0) = c(0).copy(transforms = transform +: c(0).transforms)
+ d
+ case CompilerAnnotation(compiler) =>
+ c(0) = c(0).copy(compiler = Some(compiler))
+ d
+ case annotation =>
+ val state = c(0).stateIn
+ c(0) = c(0).copy(stateIn = state.copy(annotations = annotation +: state.annotations))
+ d
+ }
+ case (d, a) if !foundFirstCircuit =>
+ a match {
+ case RunFirrtlTransformAnnotation(transform) => d.copy(transforms = transform +: d.transforms)
+ case CompilerAnnotation(compiler) => d.copy(compiler = Some(compiler))
+ case annotation => d.copy(annotations = annotation +: d.annotations)
+ }
}
c.toSeq
}
@@ -89,7 +93,7 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] {
* removed ([[CompilerAnnotation]]s and [[RunFirrtlTransformAnnotation]]s).
*/
protected def bToA(b: Seq[CompilerRun]): AnnotationSeq =
- b.flatMap( bb => FirrtlCircuitAnnotation(bb.stateOut.get.circuit) +: bb.stateOut.get.annotations )
+ b.flatMap(bb => FirrtlCircuitAnnotation(bb.stateOut.get.circuit) +: bb.stateOut.get.annotations)
/** Run the FIRRTL compiler some number of times. If more than one run is specified, a parallel collection will be
* used.
@@ -98,9 +102,9 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] {
def f(c: CompilerRun): CompilerRun = {
val targets = c.compiler match {
case Some(d) => c.transforms.reverse.map(Dependency.fromTransform(_)) ++ compilerToTransforms(d)
- case None =>
+ case None =>
val hasEmitter = c.transforms.collectFirst { case _: firrtl.Emitter => true }.isDefined
- if(!hasEmitter) {
+ if (!hasEmitter) {
throw new PhasePrerequisiteException("No compiler specified!")
} else {
c.transforms.reverse.map(Dependency.fromTransform)
@@ -118,18 +122,19 @@ class Compiler extends Phase with Translator[AnnotationSeq, Seq[CompilerRun]] {
c.copy(stateOut = Some(annotationsOut))
}
- if (b.size <= 1) { b.map(f) } else {
- collection.parallel.immutable.ParVector(b :_*).par.map(f).seq
+ if (b.size <= 1) { b.map(f) }
+ else {
+ collection.parallel.immutable.ParVector(b: _*).par.map(f).seq
}
}
private def compilerToTransforms(a: FirrtlCompiler): Seq[TransformDependency] = a match {
- case _: firrtl.NoneCompiler => Forms.ChirrtlForm
- case _: firrtl.HighFirrtlCompiler => Forms.MinimalHighForm
- case _: firrtl.MiddleFirrtlCompiler => Forms.MidForm
- case _: firrtl.LowFirrtlCompiler => Forms.LowForm
+ case _: firrtl.NoneCompiler => Forms.ChirrtlForm
+ case _: firrtl.HighFirrtlCompiler => Forms.MinimalHighForm
+ case _: firrtl.MiddleFirrtlCompiler => Forms.MidForm
+ case _: firrtl.LowFirrtlCompiler => Forms.LowForm
case _: firrtl.VerilogCompiler | _: firrtl.SystemVerilogCompiler => Forms.LowFormOptimized
- case _: firrtl.MinimumVerilogCompiler => Forms.LowFormMinimumOptimized
+ case _: firrtl.MinimumVerilogCompiler => Forms.LowFormMinimumOptimized
}
}
diff --git a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala
index b149a791..0b558cc0 100644
--- a/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala
+++ b/src/main/scala/firrtl/stage/phases/DriverCompatibility.scala
@@ -47,7 +47,10 @@ object DriverCompatibility {
/** Holds the name of the top (main) module in an input circuit
* @param value top module name
*/
- @deprecated(""""top-name" is deprecated as part of the Stage/Phase refactor. Use explicit input/output files.""", "1.2")
+ @deprecated(
+ """"top-name" is deprecated as part of the Stage/Phase refactor. Use explicit input/output files.""",
+ "1.2"
+ )
case class TopNameAnnotation(topName: String) extends NoTargetAnnotation
object TopNameAnnotation {
@@ -57,7 +60,7 @@ object DriverCompatibility {
.abbr("tn")
.hidden
.unbounded
- .action( (_, _) => throw new OptionsException(optionRemoved("--top-name/-tn")) )
+ .action((_, _) => throw new OptionsException(optionRemoved("--top-name/-tn")))
}
/** Indicates that the implicit emitter, derived from a [[CompilerAnnotation]] should be an [[EmitAllModulesAnnotation]]
@@ -70,7 +73,7 @@ object DriverCompatibility {
.abbr("fsm")
.hidden
.unbounded
- .action( (_, _) => throw new OptionsException(optionRemoved("--split-modules/-fsm")) )
+ .action((_, _) => throw new OptionsException(optionRemoved("--split-modules/-fsm")))
}
@@ -84,13 +87,16 @@ object DriverCompatibility {
* @return the top module ''if it can be determined''
*/
private def topName(annotations: AnnotationSeq): Option[String] =
- annotations.collectFirst{ case TopNameAnnotation(n) => n }.orElse(
- annotations.collectFirst{ case FirrtlCircuitAnnotation(c) => c.main }.orElse(
- annotations.collectFirst{ case FirrtlSourceAnnotation(s) => Parser.parse(s).main }.orElse(
- annotations.collectFirst{ case FirrtlFileAnnotation(f) =>
- FirrtlStageUtils.getFileExtension(f) match {
- case ProtoBufFile => FromProto.fromFile(f).main
- case FirrtlFile => Parser.parse(FileUtils.getText(f)).main } } )))
+ annotations.collectFirst { case TopNameAnnotation(n) => n }
+ .orElse(annotations.collectFirst { case FirrtlCircuitAnnotation(c) => c.main }.orElse(annotations.collectFirst {
+ case FirrtlSourceAnnotation(s) => Parser.parse(s).main
+ }.orElse(annotations.collectFirst {
+ case FirrtlFileAnnotation(f) =>
+ FirrtlStageUtils.getFileExtension(f) match {
+ case ProtoBufFile => FromProto.fromFile(f).main
+ case FirrtlFile => Parser.parse(FileUtils.getText(f)).main
+ }
+ })))
/** Determine the target directory with the following precedence (highest to lowest):
* - Explicitly from the user-specified [[firrtl.options.TargetDirAnnotation TargetDirAnnotation]]
@@ -131,22 +137,27 @@ object DriverCompatibility {
override def invalidates(a: Phase) = false
/** Try to add an [[firrtl.options.InputAnnotationFileAnnotation InputAnnotationFileAnnotation]] implicitly specified by
- * an [[AnnotationSeq]]. */
- def transform(annotations: AnnotationSeq): AnnotationSeq = annotations
- .collectFirst{ case a: InputAnnotationFileAnnotation => a } match {
- case Some(_) => annotations
- case None => topName(annotations) match {
+ * an [[AnnotationSeq]].
+ */
+ def transform(annotations: AnnotationSeq): AnnotationSeq = annotations.collectFirst {
+ case a: InputAnnotationFileAnnotation => a
+ } match {
+ case Some(_) => annotations
+ case None =>
+ topName(annotations) match {
case Some(n) =>
val filename = targetDir(annotations) + "/" + n + ".anno"
if (new File(filename).exists) {
StageUtils.dramaticWarning(
- s"Implicit reading of the annotation file is deprecated! Use an explict --annotation-file argument.")
+ s"Implicit reading of the annotation file is deprecated! Use an explict --annotation-file argument."
+ )
annotations :+ InputAnnotationFileAnnotation(filename)
} else {
annotations
}
case None => annotations
- } }
+ }
+ }
}
@@ -180,7 +191,8 @@ object DriverCompatibility {
annotations
} else if (main.nonEmpty) {
StageUtils.dramaticWarning(
- s"Implicit reading of the input file is deprecated! Use an explict --input-file argument.")
+ s"Implicit reading of the input file is deprecated! Use an explict --input-file argument."
+ )
FirrtlFileAnnotation(Viewer[StageOptions].view(annotations).getBuildFileName(s"${main.get}.fir")) +: annotations
} else {
annotations
@@ -194,8 +206,10 @@ object DriverCompatibility {
* this adds an [[EmitCircuitAnnotation]]. This replicates old behavior where specifying a compiler automatically
* meant that an emitter would also run.
*/
- @deprecated("""AddImplicitEmitter should only be used to build Driver compatibility wrappers. Switch to Stage.""",
- "1.2")
+ @deprecated(
+ """AddImplicitEmitter should only be used to build Driver compatibility wrappers. Switch to Stage.""",
+ "1.2"
+ )
class AddImplicitEmitter extends Phase {
override def prerequisites = Seq.empty
@@ -206,13 +220,13 @@ object DriverCompatibility {
/** Add one [[EmitAnnotation]] foreach [[CompilerAnnotation]]. */
def transform(annotations: AnnotationSeq): AnnotationSeq = {
- val splitModules = annotations.collectFirst{ case a: EmitOneFilePerModuleAnnotation.type => a }.isDefined
+ val splitModules = annotations.collectFirst { case a: EmitOneFilePerModuleAnnotation.type => a }.isDefined
annotations.flatMap {
case a @ CompilerAnnotation(c) =>
val b = RunFirrtlTransformAnnotation(a.compiler.emitter)
if (splitModules) { Seq(a, b, EmitAllModulesAnnotation(c.emitter.getClass)) }
- else { Seq(a, b, EmitCircuitAnnotation (c.emitter.getClass)) }
+ else { Seq(a, b, EmitCircuitAnnotation(c.emitter.getClass)) }
case a => Seq(a)
}
}
@@ -222,8 +236,10 @@ object DriverCompatibility {
/** Adds an [[OutputFileAnnotation]] derived from a [[TopNameAnnotation]] if no [[OutputFileAnnotation]] already
* exists. If no [[TopNameAnnotation]] exists, then no [[OutputFileAnnotation]] is added.
*/
- @deprecated("""AddImplicitOutputFile should only be used to build Driver compatibility wrappers. Switch to Stage.""",
- "1.2")
+ @deprecated(
+ """AddImplicitOutputFile should only be used to build Driver compatibility wrappers. Switch to Stage.""",
+ "1.2"
+ )
class AddImplicitOutputFile extends Phase {
override def prerequisites = Seq(Dependency[AddImplicitFirrtlFile])
@@ -234,9 +250,9 @@ object DriverCompatibility {
/** Add an [[OutputFileAnnotation]] derived from a [[TopNameAnnotation]] if needed. */
def transform(annotations: AnnotationSeq): AnnotationSeq = {
- val hasOutputFile = annotations
- .collectFirst{ case a @(_: EmitOneFilePerModuleAnnotation.type | _: OutputFileAnnotation) => a }
- .isDefined
+ val hasOutputFile = annotations.collectFirst {
+ case a @ (_: EmitOneFilePerModuleAnnotation.type | _: OutputFileAnnotation) => a
+ }.isDefined
val top = topName(annotations)
if (!hasOutputFile && top.isDefined) {
diff --git a/src/main/scala/firrtl/stage/phases/WriteEmitted.scala b/src/main/scala/firrtl/stage/phases/WriteEmitted.scala
index e2db2a94..614ce62f 100644
--- a/src/main/scala/firrtl/stage/phases/WriteEmitted.scala
+++ b/src/main/scala/firrtl/stage/phases/WriteEmitted.scala
@@ -2,7 +2,7 @@
package firrtl.stage.phases
-import firrtl.{AnnotationSeq, EmittedModuleAnnotation, EmittedCircuitAnnotation}
+import firrtl.{AnnotationSeq, EmittedCircuitAnnotation, EmittedModuleAnnotation}
import firrtl.options.{Phase, StageOptions, Viewer}
import firrtl.stage.FirrtlOptions
@@ -24,8 +24,11 @@ import java.io.PrintWriter
*
* Any annotations written to files will be deleted.
*/
-@deprecated("Annotations that mixin the CustomFileEmission trait are automatically serialized by stages." +
- "This will be removed in FIRRTL 1.5", "FIRRTL 1.4.0")
+@deprecated(
+ "Annotations that mixin the CustomFileEmission trait are automatically serialized by stages." +
+ "This will be removed in FIRRTL 1.5",
+ "FIRRTL 1.4.0"
+)
class WriteEmitted extends Phase {
override def prerequisites = Seq.empty
@@ -47,7 +50,8 @@ class WriteEmitted extends Phase {
None
case a: EmittedCircuitAnnotation[_] =>
val pw = new PrintWriter(
- sopts.getBuildFileName(fopts.outputFileName.getOrElse(a.value.name), Some(a.value.outputSuffix)))
+ sopts.getBuildFileName(fopts.outputFileName.getOrElse(a.value.name), Some(a.value.outputSuffix))
+ )
pw.write(a.value.value)
pw.close()
None
diff --git a/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala b/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala
index ebcd7cfb..742d2b7e 100644
--- a/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala
+++ b/src/main/scala/firrtl/stage/transforms/CatchCustomTransformExceptions.scala
@@ -9,7 +9,8 @@ class CatchCustomTransformExceptions(val underlying: Transform) extends Transfor
override def execute(c: CircuitState): CircuitState = try {
underlying.transform(c)
} catch {
- case e: Exception if CatchCustomTransformExceptions.isCustomTransform(trueUnderlying) => throw CustomTransformException(e)
+ case e: Exception if CatchCustomTransformExceptions.isCustomTransform(trueUnderlying) =>
+ throw CustomTransformException(e)
}
}
diff --git a/src/main/scala/firrtl/stage/transforms/Compiler.scala b/src/main/scala/firrtl/stage/transforms/Compiler.scala
index 9988e443..251f4387 100644
--- a/src/main/scala/firrtl/stage/transforms/Compiler.scala
+++ b/src/main/scala/firrtl/stage/transforms/Compiler.scala
@@ -7,12 +7,12 @@ import firrtl.stage.TransformManager
import firrtl.{Transform, VerilogEmitter}
/** A [[firrtl.stage.TransformManager TransformManager]] of
- *
*/
class Compiler(
- targets: Seq[TransformManager.TransformDependency],
+ targets: Seq[TransformManager.TransformDependency],
currentState: Seq[TransformManager.TransformDependency] = Seq.empty,
- knownObjects: Set[Transform] = Set.empty) extends TransformManager(targets, currentState, knownObjects) {
+ knownObjects: Set[Transform] = Set.empty)
+ extends TransformManager(targets, currentState, knownObjects) {
override val wrappers = Seq(
(a: Transform) => ExpandPrepares(a),
@@ -21,9 +21,10 @@ class Compiler(
)
override def customPrintHandling(
- tab: String,
+ tab: String,
charSet: CharSet,
- size: Int): Option[PartialFunction[(Transform, Int), Seq[String]]] = {
+ size: Int
+ ): Option[PartialFunction[(Transform, Int), Seq[String]]] = {
val (l, n, c) = (charSet.lastNode, charSet.notLastNode, charSet.continuation)
val last = size - 1
diff --git a/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala b/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala
index 7a0621e4..d0514f15 100644
--- a/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala
+++ b/src/main/scala/firrtl/stage/transforms/ExpandPrepares.scala
@@ -8,8 +8,10 @@ class ExpandPrepares(val underlying: Transform) extends Transform with WrappedTr
/* Assert that this is not wrapping other transforms. */
underlying match {
- case _: WrappedTransform => throw new Exception(
- s"'ExpandPrepares' must not wrap other 'WrappedTransforms', but wraps '${underlying.getClass.getName}'")
+ case _: WrappedTransform =>
+ throw new Exception(
+ s"'ExpandPrepares' must not wrap other 'WrappedTransforms', but wraps '${underlying.getClass.getName}'"
+ )
case _ =>
}
diff --git a/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala
index 913ab5d2..c268332a 100644
--- a/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala
+++ b/src/main/scala/firrtl/stage/transforms/TrackTransforms.scala
@@ -8,8 +8,10 @@ import firrtl.options.{Dependency, DependencyManagerException}
case class TransformHistoryAnnotation(history: Seq[Transform], state: Set[Transform]) extends NoTargetAnnotation {
- def add(transform: Transform,
- invalidates: (Transform) => Boolean = (a: Transform) => false): TransformHistoryAnnotation =
+ def add(
+ transform: Transform,
+ invalidates: (Transform) => Boolean = (a: Transform) => false
+ ): TransformHistoryAnnotation =
this.copy(
history = transform +: this.history,
state = (this.state + transform).filterNot(invalidates)
@@ -44,8 +46,7 @@ class TrackTransforms(val underlying: Transform) extends Transform with WrappedT
}
override def execute(c: CircuitState): CircuitState = {
- val state = c.annotations
- .collectFirst{ case TransformHistoryAnnotation(_, state) => state }
+ val state = c.annotations.collectFirst { case TransformHistoryAnnotation(_, state) => state }
.getOrElse(Set.empty[Transform])
.map(Dependency.fromTransform(_))
@@ -53,7 +54,8 @@ class TrackTransforms(val underlying: Transform) extends Transform with WrappedT
throw new DependencyManagerException(
s"""|Tried to execute Transform '$trueUnderlying' for which run-time prerequisites were not satisfied:
| state: ${state.mkString("\n -", "\n -", "")}
- | prerequisites: ${trueUnderlying.prerequisites.mkString("\n -", "\n -", "")}""".stripMargin)
+ | prerequisites: ${trueUnderlying.prerequisites.mkString("\n -", "\n -", "")}""".stripMargin
+ )
}
val out = underlying.transform(c)
diff --git a/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala
index cc0fbc6f..e36eef9b 100644
--- a/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala
+++ b/src/main/scala/firrtl/stage/transforms/UpdateAnnotations.scala
@@ -5,7 +5,9 @@ package firrtl.stage.transforms
import firrtl.{CircuitState, Transform}
import firrtl.options.Translator
-class UpdateAnnotations(val underlying: Transform) extends Transform with WrappedTransform
+class UpdateAnnotations(val underlying: Transform)
+ extends Transform
+ with WrappedTransform
with Translator[CircuitState, (CircuitState, CircuitState)] {
override def execute(c: CircuitState): CircuitState = underlying.transform(c)
diff --git a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala
index a57973d5..5000e07a 100644
--- a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala
+++ b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala
@@ -2,7 +2,7 @@
package firrtl.transforms
-import java.io.{File, FileNotFoundException, FileInputStream, FileOutputStream, PrintWriter}
+import java.io.{File, FileInputStream, FileNotFoundException, FileOutputStream, PrintWriter}
import firrtl._
import firrtl.annotations._
@@ -11,31 +11,32 @@ import scala.collection.immutable.ListSet
sealed trait BlackBoxHelperAnno extends Annotation
-case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno
- with NoTargetAnnotation {
+case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno with NoTargetAnnotation {
override def serialize: String = s"targetDir\n$targetDir"
}
-case class BlackBoxResourceAnno(target: ModuleName, resourceId: String) extends BlackBoxHelperAnno
+case class BlackBoxResourceAnno(target: ModuleName, resourceId: String)
+ extends BlackBoxHelperAnno
with SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(target = n)
override def serialize: String = s"resource\n$resourceId"
}
-case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String) extends BlackBoxHelperAnno
+case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String)
+ extends BlackBoxHelperAnno
with SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(target = n)
override def serialize: String = s"inline\n$name\n$text"
}
-case class BlackBoxPathAnno(target: ModuleName, path: String) extends BlackBoxHelperAnno
+case class BlackBoxPathAnno(target: ModuleName, path: String)
+ extends BlackBoxHelperAnno
with SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(target = n)
override def serialize: String = s"path\n$path"
}
-case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno
- with NoTargetAnnotation {
+case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno with NoTargetAnnotation {
override def serialize: String = s"resourceFileName\n$resourceFileName"
}
@@ -43,8 +44,10 @@ case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackB
* @param fileName the name of the BlackBox file (only used for error message generation)
* @param e an underlying exception that generated this
*/
-class BlackBoxNotFoundException(fileName: String, message: String) extends FirrtlUserException(
- s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message")
+class BlackBoxNotFoundException(fileName: String, message: String)
+ extends FirrtlUserException(
+ s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message"
+ )
/** Handle source for Verilog ExtModules (BlackBoxes)
*
@@ -72,15 +75,16 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration {
*/
def collectAnnos(annos: Seq[Annotation]): (ListSet[BlackBoxHelperAnno], File, File) =
annos.foldLeft((ListSet.empty[BlackBoxHelperAnno], DefaultTargetDir, new File(defaultFileListName))) {
- case ((acc, tdir, flistName), anno) => anno match {
- case BlackBoxTargetDirAnno(dir) =>
- val targetDir = new File(dir)
- if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) }
- (acc, targetDir, flistName)
- case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName))
- case a: BlackBoxHelperAnno => (acc + a, tdir, flistName)
- case _ => (acc, tdir, flistName)
- }
+ case ((acc, tdir, flistName), anno) =>
+ anno match {
+ case BlackBoxTargetDirAnno(dir) =>
+ val targetDir = new File(dir)
+ if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) }
+ (acc, targetDir, flistName)
+ case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName))
+ case a: BlackBoxHelperAnno => (acc + a, tdir, flistName)
+ case _ => (acc, tdir, flistName)
+ }
}
/**
@@ -112,14 +116,15 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration {
case BlackBoxInlineAnno(_, name, text) =>
val outFile = new File(targetDir, name)
(text, outFile)
- }.map { case (text, file) =>
- writeTextToFile(text, file)
- file
+ }.map {
+ case (text, file) =>
+ writeTextToFile(text, file)
+ file
}
// Issue #917 - We don't want to list Verilog header files ("*.vh") in our file list - they will automatically be included by reference.
def isHeader(name: String) = name.endsWith(".h") || name.endsWith(".vh") || name.endsWith(".svh")
- val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot{ f => isHeader(f.getName()) }
+ val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot { f => isHeader(f.getName()) }
val filelistFile = if (flistName.isAbsolute()) flistName else new File(targetDir, flistName.getName())
// We need the canonical path here, so verilator will create a path to the file that works from the targetDir,
@@ -137,12 +142,14 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration {
}
object BlackBoxSourceHelper {
+
/** Safely access a file converting [[FileNotFoundException]]s and [[NullPointerException]]s into
* [[BlackBoxNotFoundException]]s
* @param fileName the name of the file to be accessed (only used for error message generation)
* @param code some code to run
*/
- private def safeFile[A](fileName: String)(code: => A) = try { code } catch {
+ private def safeFile[A](fileName: String)(code: => A) = try { code }
+ catch {
case e @ (_: FileNotFoundException | _: NullPointerException) =>
throw new BlackBoxNotFoundException(fileName, e.getMessage)
}
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index 6403be23..ee4c1d0b 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -24,6 +24,7 @@ import firrtl.options.{Dependency, RegisteredTransform, ShellOption}
case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)
object LogicNode {
+
/**
* Construct a LogicNode from a *Low FIRRTL* reference or subfield that refers to a component.
* Since aggregate types appear in Low FIRRTL only as the full types of instances or memories,
@@ -39,11 +40,11 @@ object LogicNode {
case s: WSubField =>
s.expr match {
case modref: WRef =>
- LogicNode(s.name,Some(modref.name))
+ LogicNode(s.name, Some(modref.name))
case memport: WSubField =>
memport.expr match {
case memref: WRef =>
- LogicNode(s.name,Some(memref.name),Some(memport.name))
+ LogicNode(s.name, Some(memref.name), Some(memport.name))
case _ => throwInternalError(s"LogicNode: unrecognized subsubfield expression - $memport")
}
case _ => throwInternalError(s"LogicNode: unrecognized subfield expression - $s")
@@ -56,9 +57,8 @@ object CheckCombLoops {
type ConnMap = DiGraph[LogicNode] with EdgeData[LogicNode, Info]
type MutableConnMap = MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info]
-
- class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException(
- s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n"))
+ class CombLoopException(info: Info, mname: String, cycle: Seq[String])
+ extends PassException(s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n"))
}
case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation
@@ -73,7 +73,7 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge
override def update(renames: RenameMap): Seq[Annotation] = {
val sources = renames.get(source).getOrElse(Seq(source))
val sinks = renames.get(sink).getOrElse(Seq(sink))
- val paths = sources flatMap { s => sinks.map((s, _)) }
+ val paths = sources.flatMap { s => sinks.map((s, _)) }
paths.collect {
case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink)
}
@@ -82,8 +82,8 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge
case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget]) extends Annotation {
override def update(renames: RenameMap): Seq[Annotation] = {
- val newSources = sources.flatMap { s => renames(s) }.collect {case x: ReferenceTarget if x.isLocal => x}
- val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x}
+ val newSources = sources.flatMap { s => renames(s) }.collect { case x: ReferenceTarget if x.isLocal => x }
+ val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x }
newSinks.map(snk => CombinationalPath(snk, newSources))
}
}
@@ -98,14 +98,10 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget
* @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules
* @note The pass will throw exceptions on "false paths"
*/
-class CheckCombLoops extends Transform
- with RegisteredTransform
- with DependencyAPIMigration {
+class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
- Dependency(firrtl.transforms.RemoveReset) )
+ Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset))
override def optionalPrerequisites = Seq.empty
@@ -119,17 +115,21 @@ class CheckCombLoops extends Transform
new ShellOption[Unit](
longOption = "no-check-comb-loops",
toAnnotationSeq = (_: Unit) => Seq(DontCheckCombLoopsAnnotation),
- helpText = "Disable combinational loop checking" ) )
+ helpText = "Disable combinational loop checking"
+ )
+ )
private def getExprDeps(deps: MutableConnMap, v: LogicNode, info: Info)(e: Expression): Unit = e match {
- case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info)
+ case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info)
case s: WSubField => deps.addEdgeIfValid(v, LogicNode(s), info)
case _ => e.foreach(getExprDeps(deps, v, info))
}
private def getStmtDeps(
simplifiedModules: mutable.Map[String, AbstractConnMap],
- deps: MutableConnMap)(s: Statement): Unit = s match {
+ deps: MutableConnMap
+ )(s: Statement
+ ): Unit = s match {
case Connect(info, loc, expr) =>
val lhs = LogicNode(loc)
if (deps.contains(lhs)) {
@@ -152,9 +152,9 @@ class CheckCombLoops extends Transform
case i: WDefInstance =>
val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
iGraph.getVertices.foreach(deps.addVertex(_))
- iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
+ iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v, _) } })
case _ =>
- s.foreach(getStmtDeps(simplifiedModules,deps))
+ s.foreach(getStmtDeps(simplifiedModules, deps))
}
// Pretty-print a LogicNode with a prepended hierarchical path
@@ -169,24 +169,26 @@ class CheckCombLoops extends Transform
* recovered.
*/
private def expandInstancePaths(
- m: String,
+ m: String,
moduleGraphs: mutable.Map[String, ConnMap],
- moduleDeps: Map[String, Map[String, String]],
- hierPrefix: Seq[String],
- path: Seq[LogicNode]): Seq[String] = {
+ moduleDeps: Map[String, Map[String, String]],
+ hierPrefix: Seq[String],
+ path: Seq[LogicNode]
+ ): Seq[String] = {
// Recover info from edge data, add to error string
def info(u: LogicNode, v: LogicNode): String =
moduleGraphs(m).getEdgeData(u, v).map(_.toString).mkString("\t", "", "")
// lhs comes after rhs
- val pathNodes = (path zip path.tail) map { case (rhs, lhs) =>
- if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) {
- val child = moduleDeps(m)(lhs.inst.get)
- val newHierPrefix = hierPrefix :+ lhs.inst.get
- val subpath = moduleGraphs(child).path(lhs.copy(inst=None),rhs.copy(inst=None)).reverse
- expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath)
- } else {
- Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs))
- }
+ val pathNodes = (path.zip(path.tail)).map {
+ case (rhs, lhs) =>
+ if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) {
+ val child = moduleDeps(m)(lhs.inst.get)
+ val newHierPrefix = hierPrefix :+ lhs.inst.get
+ val subpath = moduleGraphs(child).path(lhs.copy(inst = None), rhs.copy(inst = None)).reverse
+ expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath)
+ } else {
+ Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs))
+ }
}
pathNodes.flatten
}
@@ -238,12 +240,13 @@ class CheckCombLoops extends Transform
val errors = new Errors()
val extModulePaths = state.annotations.groupBy {
case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module)
- case ann: Annotation => CircuitTarget(c.main)
+ case ann: Annotation => CircuitTarget(c.main)
}
- val moduleMap = c.modules.map({m => (m.name,m) }).toMap
+ val moduleMap = c.modules.map({ m => (m.name, m) }).toMap
val iGraph = InstanceKeyGraph(c).graph
- val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
- val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
+ val moduleDeps =
+ iGraph.getEdgeMap.map({ case (k, v) => (k.module, (v.map { i => (i.name, i.module) }).toMap) }).toMap
+ val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse.map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String, ConnMap]
val simplifiedModuleGraphs = new mutable.HashMap[String, AbstractConnMap]
topoSortedModules.foreach {
@@ -252,7 +255,8 @@ class CheckCombLoops extends Transform
val extModuleDeps = new MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info]
portSet.foreach(extModuleDeps.addVertex(_))
extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect {
- case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
+ case a: ExtModulePathAnnotation =>
+ extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
}
moduleGraphs(em.name) = extModuleDeps
simplifiedModuleGraphs(em.name) = extModuleDeps.simplify(portSet)
@@ -270,7 +274,7 @@ class CheckCombLoops extends Transform
for (scc <- internalDeps.findSCCs.filter(_.length > 1)) {
val sccSubgraph = internalDeps.subgraph(scc.toSet)
val cycle = findCycleInSCC(sccSubgraph)
- (cycle zip cycle.tail).foreach({ case (a,b) => require(internalDeps.getEdges(a).contains(b)) })
+ (cycle.zip(cycle.tail)).foreach({ case (a, b) => require(internalDeps.getEdges(a).contains(b)) })
// Reverse to make sure LHS comes after RHS, print repeated vertex at start for legibility
val intuitiveCycle = cycle.reverse
val repeatedInitial = prettyPrintAbsoluteRef(Seq(m.name), intuitiveCycle.head)
@@ -280,10 +284,11 @@ class CheckCombLoops extends Transform
case m => throwInternalError(s"Module ${m.name} has unrecognized type")
}
val mt = ModuleTarget(c.main, c.main)
- val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty =>
- val sink = mt.ref(from.name)
- val sources = tos.map(to => mt.ref(to.name))
- CombinationalPath(sink, sources.toSeq)
+ val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect {
+ case (from, tos) if tos.nonEmpty =>
+ val sink = mt.ref(from.name)
+ val sources = tos.map(to => mt.ref(to.name))
+ CombinationalPath(sink, sources.toSeq)
}
(state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs, moduleGraphs)
}
@@ -291,7 +296,7 @@ class CheckCombLoops extends Transform
/**
* Returns a Map from Module name to port connectivity
*/
- def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = {
+ def analyze(state: CircuitState): collection.Map[String, DiGraph[String]] = {
val (result, errors, connectivity, _) = run(state)
connectivity.map {
case (k, v) => (k, v.transformNodes(ln => ln.name))
@@ -301,7 +306,7 @@ class CheckCombLoops extends Transform
/**
* Returns a Map from Module name to complete netlist connectivity
*/
- def analyzeFull(state: CircuitState): collection.Map[String,DiGraph[LogicNode]] = {
+ def analyzeFull(state: CircuitState): collection.Map[String, DiGraph[LogicNode]] = {
run(state)._4
}
diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala
index 7fa01e46..3014d0e3 100644
--- a/src/main/scala/firrtl/transforms/CombineCats.scala
+++ b/src/main/scala/firrtl/transforms/CombineCats.scala
@@ -1,4 +1,3 @@
-
package firrtl
package transforms
@@ -14,26 +13,30 @@ import scala.collection.mutable
case class MaxCatLenAnnotation(maxCatLen: Int) extends NoTargetAnnotation
object CombineCats {
+
/** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them paired with their Cat length */
type Netlist = mutable.HashMap[WrappedExpression, (Int, Expression)]
def expandCatArgs(maxCatLen: Int, netlist: Netlist)(expr: Expression): (Int, Expression) = expr match {
- case cat@DoPrim(Cat, args, _, _) =>
+ case cat @ DoPrim(Cat, args, _, _) =>
val (a0Len, a0Expanded) = expandCatArgs(maxCatLen - 1, netlist)(args.head)
val (a1Len, a1Expanded) = expandCatArgs(maxCatLen - a0Len, netlist)(args(1))
(a0Len + a1Len, cat.copy(args = Seq(a0Expanded, a1Expanded)).asInstanceOf[Expression])
case other =>
- netlist.get(we(expr)).collect {
- case (len, cat@DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat)
- }.getOrElse((1, other))
+ netlist
+ .get(we(expr))
+ .collect {
+ case (len, cat @ DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat)
+ }
+ .getOrElse((1, other))
}
def onStmt(maxCatLen: Int, netlist: Netlist)(stmt: Statement): Statement = {
stmt.map(onStmt(maxCatLen, netlist)) match {
- case node@DefNode(_, name, value) =>
+ case node @ DefNode(_, name, value) =>
val catLenAndVal = value match {
- case cat@DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat)
- case other => (1, other)
+ case cat @ DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat)
+ case other => (1, other)
}
netlist(we(WRef(name))) = catLenAndVal
node.copy(value = catLenAndVal._2)
@@ -55,16 +58,16 @@ object CombineCats {
class CombineCats extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions) )
+ Seq(
+ Dependency(passes.RemoveValidIf),
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions)
+ )
override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Seq(
- Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = false
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index ce36dd72..dc9b2bbe 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -28,7 +28,7 @@ object ConstantPropagation {
/** Pads e to the width of t */
def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
- case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
+ case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
}
@@ -44,38 +44,40 @@ object ConstantPropagation {
case lit: Literal =>
require(hi >= lo)
UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
- case t: UIntType => x
- case _ => asUInt(x, e.tpe)
- }
+ case x if bitWidth(e.tpe) == bitWidth(x.tpe) =>
+ x.tpe match {
+ case t: UIntType => x
+ case _ => asUInt(x, e.tpe)
+ }
case _ => e
}
}
def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
- case x => e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
- // take sign bit if shift amount is larger than arg width
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
- case _ => e
- }
+ case x =>
+ e.args.head match {
+ // TODO when amount >= x.width, return a zero-width wire
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1)))
+ // take sign bit if shift amount is larger than arg width
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1)))
+ case _ => e
+ }
}
-
- /**********************************************
- * REGISTER CONSTANT PROPAGATION HELPER TYPES *
- **********************************************/
+ /** ********************************************
+ * REGISTER CONSTANT PROPAGATION HELPER TYPES *
+ * ********************************************
+ */
// A utility class that is somewhat like an Option but with two variants containing Nothing.
// for register constant propagation (register or literal).
private abstract class ConstPropBinding[+T] {
def resolve[V >: T](that: ConstPropBinding[V]): ConstPropBinding[V] = (this, that) match {
- case (x, y) if (x == y) => x
+ case (x, y) if (x == y) => x
case (x, UnboundConstant) => x
case (UnboundConstant, y) => y
- case _ => NonConstant
+ case _ => NonConstant
}
}
@@ -103,21 +105,23 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
override def prerequisites =
((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
+ ++ firrtl.stage.Forms.LowForm
+ - Dependency(firrtl.passes.Legalize)
+ + Dependency(firrtl.passes.RemoveValidIf)).toSeq
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf =
- Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[SystemVerilogEmitter],
+ Dependency[VerilogEmitter]
+ )
override def invalidates(a: Transform): Boolean = a match {
case firrtl.passes.Legalize => true
- case _ => false
+ case _ => false
}
override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation])
@@ -130,7 +134,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
sealed trait FoldCommutativeOp extends SimplifyBinaryOp {
- def fold(c1: Literal, c2: Literal): Expression
+ def fold(c1: Literal, c2: Literal): Expression
def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match {
@@ -138,7 +142,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe)
case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe)
case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs)
- case _ => e
+ case _ => e
}
}
@@ -177,20 +181,20 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
*/
def apply(prim: DoPrim): Expression = prim.args.head match {
case a: Literal => simplifyLiteral(a)
- case _ => prim
+ case _ => prim
}
}
object FoldADD extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
- case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
- case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1))
+ case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1))
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, w) if v == BigInt(0) => rhs
case SIntLiteral(v, w) if v == BigInt(0) => rhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = e
}
@@ -209,77 +213,81 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
object FoldAND extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked(c1.value & c2.value, width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
- case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
+ case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
+ case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked((c1.value | c2.value), width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, _) if v == BigInt(0) => rhs
- case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
+ case UIntLiteral(v, _) if v == BigInt(0) => rhs
+ case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldXOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked((c1.value ^ c2.value), width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe))
}
object FoldEqual extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) =>
+ DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1)
}
object FoldNotEqual extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) =>
+ DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0)
}
private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match {
- case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
+ case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) =>
+ UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
case _ => e
}
private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
- case x => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
- case _ => e
- }
+ case x =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
+ case _ => e
+ }
}
private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match {
@@ -296,53 +304,55 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case _ => e
}
-
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
def isUInt(e: Expression): Boolean = e.tpe match {
case UIntType(_) => true
- case _ => false
+ case _ => false
}
def isZero(e: Expression) = e match {
- case UIntLiteral(value, _) => value == BigInt(0)
- case SIntLiteral(value, _) => value == BigInt(0)
- case _ => false
- }
+ case UIntLiteral(value, _) => value == BigInt(0)
+ case SIntLiteral(value, _) => value == BigInt(0)
+ case _ => false
+ }
x match {
- case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero
- case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one
- case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero
- case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one
- case ex => ex
+ case DoPrim(Lt, Seq(a, b), _, _) if isUInt(a) && isZero(b) => zero
+ case DoPrim(Leq, Seq(a, b), _, _) if isZero(a) && isUInt(b) => one
+ case DoPrim(Gt, Seq(a, b), _, _) if isZero(a) && isUInt(b) => zero
+ case DoPrim(Geq, Seq(a, b), _, _) if isUInt(a) && isZero(b) => one
+ case ex => ex
}
}
def foldIfOutsideRange(x: Expression): Expression = {
//Note, only abides by a partial ordering
case class Range(min: BigInt, max: BigInt) {
- def === (that: Range) =
+ def ===(that: Range) =
Seq(this.min, this.max, that.min, that.max)
- .sliding(2,1)
+ .sliding(2, 1)
.map(x => x.head == x(1))
.reduce(_ && _)
- def > (that: Range) = this.min > that.max
- def >= (that: Range) = this.min >= that.max
- def < (that: Range) = this.max < that.min
- def <= (that: Range) = this.max <= that.min
+ def >(that: Range) = this.min > that.max
+ def >=(that: Range) = this.min >= that.max
+ def <(that: Range) = this.max < that.min
+ def <=(that: Range) = this.max <= that.min
}
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
- case _ => e.tpe match {
- case SIntType(IntWidth(width)) => Range(
- min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
- max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
- )
- case UIntType(IntWidth(width)) => Range(
- min = BigInt(0),
- max = BigInt(2).pow(width.toInt) - BigInt(1)
- )
- }
+ case _ =>
+ e.tpe match {
+ case SIntType(IntWidth(width)) =>
+ Range(
+ min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
+ max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
+ )
+ case UIntType(IntWidth(width)) =>
+ Range(
+ min = BigInt(0),
+ max = BigInt(2).pow(width.toInt) - BigInt(1)
+ )
+ }
}
// Calculates an expression's range of values
x match {
@@ -351,27 +361,28 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
def r1 = range(ex.args(1))
ex.op match {
// Always true
- case Lt if r0 < r1 => one
+ case Lt if r0 < r1 => one
case Leq if r0 <= r1 => one
- case Gt if r0 > r1 => one
+ case Gt if r0 > r1 => one
case Geq if r0 >= r1 => one
// Always false
- case Lt if r0 >= r1 => zero
+ case Lt if r0 >= r1 => zero
case Leq if r0 > r1 => zero
- case Gt if r0 <= r1 => zero
+ case Gt if r0 <= r1 => zero
case Geq if r0 < r1 => zero
- case _ => ex
+ case _ => ex
}
case ex => ex
}
}
def foldIfMatchingArgs(x: Expression) = x match {
- case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match {
- case (Lt | Gt) => zero
- case (Leq | Geq) => one
- case _ => x
- }
+ case DoPrim(op, Seq(a, b), _, _) if (a == b) =>
+ op match {
+ case (Lt | Gt) => zero
+ case (Leq | Geq) => one
+ case _ => x
+ }
case _ => x
}
foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e)))
@@ -393,43 +404,47 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
private def constPropPrim(e: DoPrim): Expression = e.op match {
- case Shl => foldShiftLeft(e)
- case Dshl => foldDynamicShiftLeft(e)
- case Shr => foldShiftRight(e)
- case Dshr => foldDynamicShiftRight(e)
- case Cat => foldConcat(e)
- case Add => FoldADD(e)
- case Sub => SimplifySUB(e)
- case Div => SimplifyDIV(e)
- case Rem => SimplifyREM(e)
- case And => FoldAND(e)
- case Or => FoldOR(e)
- case Xor => FoldXOR(e)
- case Eq => FoldEqual(e)
- case Neq => FoldNotEqual(e)
- case Andr => FoldANDR(e)
- case Orr => FoldORR(e)
- case Xorr => FoldXORR(e)
+ case Shl => foldShiftLeft(e)
+ case Dshl => foldDynamicShiftLeft(e)
+ case Shr => foldShiftRight(e)
+ case Dshr => foldDynamicShiftRight(e)
+ case Cat => foldConcat(e)
+ case Add => FoldADD(e)
+ case Sub => SimplifySUB(e)
+ case Div => SimplifyDIV(e)
+ case Rem => SimplifyREM(e)
+ case And => FoldAND(e)
+ case Or => FoldOR(e)
+ case Xor => FoldXOR(e)
+ case Eq => FoldEqual(e)
+ case Neq => FoldNotEqual(e)
+ case Andr => FoldANDR(e)
+ case Orr => FoldORR(e)
+ case Xorr => FoldXORR(e)
case (Lt | Leq | Gt | Geq) => foldComparison(e)
- case Not => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
- case _ => e
- }
+ case Not =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
+ case _ => e
+ }
case AsUInt =>
e.args.head match {
case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case arg => arg.tpe match {
- case _: UIntType => arg
- case _ => e
- }
+ case arg =>
+ arg.tpe match {
+ case _: UIntType => arg
+ case _ => e
+ }
}
- case AsSInt => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w))
- case arg => arg.tpe match {
- case _: SIntType => arg
- case _ => e
+ case AsSInt =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w))
+ case arg =>
+ arg.tpe match {
+ case _: SIntType => arg
+ case _ => e
+ }
}
- }
case AsClock =>
val arg = e.args.head
arg.tpe match {
@@ -442,25 +457,27 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case AsyncResetType => arg
case _ => e
}
- case Pad => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w))
- case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
- case _ => e
- }
+ case Pad =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
+ case _ => e
+ }
case (Bits | Head | Tail) => constPropBitExtract(e)
- case _ => e
+ case _ => e
}
private def constPropMuxCond(m: Mux) = m.cond match {
case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe)
- case _ => m
+ case _ => m
}
private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {
case _ if m.tval == m.fval => m.tval
case (t: UIntLiteral, f: UIntLiteral)
- if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => m.cond
+ if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) =>
+ m.cond
case (t: UIntLiteral, _) if t.value == BigInt(1) && bitWidth(m.tpe) == BigInt(1) =>
DoPrim(Or, Seq(m.cond, m.fval), Nil, m.tpe)
case (_, f: UIntLiteral) if f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) =>
@@ -479,15 +496,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Is "a" a "better name" than "b"?
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
- def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
- def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
-
- private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = {
- val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
+ def optimize(e: Expression): Expression =
+ constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ def optimize(e: Expression, nodeMap: NodeMap): Expression =
+ constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+
+ private def constPropExpression(
+ nodeMap: NodeMap,
+ instMap: collection.Map[Instance, OfModule],
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ )(e: Expression
+ ): Expression = {
+ val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs))
val propagated = old match {
case p: DoPrim => constPropPrim(p)
- case m: Mux => constPropMux(m)
- case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
+ case m: Mux => constPropMux(m)
+ case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) =>
val module = instMap(inst.Instance)
@@ -506,17 +530,17 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
* @todo generalize source locator propagation across Expressions and delete this method
* @todo is the `orElse` the way we want to do propagation here?
*/
- private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])
- (stmt: Statement): Statement = stmt match {
- // We check rname because inlining it would cause the original declaration to go away
- case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
- val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
- node.copy(info = InfoExpr.orElse(info1, info0))
- case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
- val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
- con.copy(info = InfoExpr.orElse(info1, info0))
- case other => other
- }
+ private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])(stmt: Statement): Statement =
+ stmt match {
+ // We check rname because inlining it would cause the original declaration to go away
+ case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ node.copy(info = InfoExpr.orElse(info1, info0))
+ case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ con.copy(info = InfoExpr.orElse(info1, info0))
+ case other => other
+ }
/* Constant propagate a Module
*
@@ -538,12 +562,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
*/
@tailrec
private def constPropModule(
- m: Module,
- dontTouches: Set[String],
- instMap: collection.Map[Instance, OfModule],
- constInputs: Map[String, Literal],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
- ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
+ m: Module,
+ dontTouches: Set[String],
+ instMap: collection.Map[Instance, OfModule],
+ constInputs: Map[String, Literal],
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
val nodeMap = new NodeMap()
@@ -571,13 +595,13 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// to constant wires, we don't need to worry about propagating primops or muxes since we'll do
// that on the next iteration if necessary
def backPropExpr(expr: Expression): Expression = {
- val old = expr map backPropExpr
+ val old = expr.map(backPropExpr)
val propagated = old match {
// When swapping, we swap both rhs and lhs
- case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) =>
+ case ref @ WRef(rname, _, _, _) if swapMap.contains(rname) =>
ref.copy(name = swapMap(rname))
// Only const prop on the rhs
- case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
+ case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case x => x
}
@@ -590,27 +614,29 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
def backPropStmt(stmt: Statement): Statement = stmt match {
case reg: DefRegister if (WrappedExpression.weq(reg.init, WRef(reg))) =>
// Self-init reset is an idiom for "no reset," and must be handled separately
- swapMap.get(reg.name)
- .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName)))
- .getOrElse(reg)
- case s => s map backPropExpr match {
- case decl: IsDeclaration if swapMap.contains(decl.name) =>
- val newName = swapMap(decl.name)
- nPropagated += 1
- decl match {
- case node: DefNode => node.copy(name = newName)
- case wire: DefWire => wire.copy(name = newName)
- case reg: DefRegister => reg.copy(name = newName)
- case other => throwInternalError()
- }
- case other => other map backPropStmt
- }
+ swapMap
+ .get(reg.name)
+ .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName)))
+ .getOrElse(reg)
+ case s =>
+ s.map(backPropExpr) match {
+ case decl: IsDeclaration if swapMap.contains(decl.name) =>
+ val newName = swapMap(decl.name)
+ nPropagated += 1
+ decl match {
+ case node: DefNode => node.copy(name = newName)
+ case wire: DefWire => wire.copy(name = newName)
+ case reg: DefRegister => reg.copy(name = newName)
+ case other => throwInternalError()
+ }
+ case other => other.map(backPropStmt)
+ }
}
// When propagating a reference, check if we want to keep the name that would be deleted
def propagateRef(lname: String, value: Expression, info: Info): Unit = {
value match {
- case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind =>
+ case WRef(rname, _, kind, _) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind =>
assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a
// node declaration or the single connection to a wire or register
swapMap += (lname -> rname, rname -> lname)
@@ -639,25 +665,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) =>
-
- /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
- * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of
- * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can
- * be resolved: each side is allowed one candidate register and one candidate literal to
- * appear in their source trees, referring to the potential constant propagation case that
- * they could allow. If the two are compatible (no different bound sources of either of
- * the two types), they can be resolved by combining sources. Otherwise, they propagate
- * NonConstant values. When encountering a node reference, it expands the node by to its
- * RHS assignment and recurses.
- *
- * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This
- * can stifle register constant propagations, which looks at drivers through value-preserving
- * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into
- * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole
- * de-optimization that does not actually appear in the output code.
- *
- * @return a RegCPEntry describing the constant prop-compatible sources driving this expression
- */
+ /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
+ * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of
+ * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can
+ * be resolved: each side is allowed one candidate register and one candidate literal to
+ * appear in their source trees, referring to the potential constant propagation case that
+ * they could allow. If the two are compatible (no different bound sources of either of
+ * the two types), they can be resolved by combining sources. Otherwise, they propagate
+ * NonConstant values. When encountering a node reference, it expands the node by to its
+ * RHS assignment and recurses.
+ *
+ * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This
+ * can stifle register constant propagations, which looks at drivers through value-preserving
+ * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into
+ * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole
+ * de-optimization that does not actually appear in the output code.
+ *
+ * @return a RegCPEntry describing the constant prop-compatible sources driving this expression
+ */
val unbound = RegCPEntry(UnboundConstant, UnboundConstant)
val selfBound = RegCPEntry(BoundConstant(lname), UnboundConstant)
@@ -684,11 +709,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Updates nodeMap after analyzing the returned value from regConstant
def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e, selfBound) match {
- case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero)
+ case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero)
case RegCPEntry(BoundConstant(_), UnboundConstant) => nodeMap(lname) = padCPExp(zero)
- case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
+ case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
case RegCPEntry(BoundConstant(_), BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
- case _ =>
+ case _ =>
}
def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe))
@@ -733,11 +758,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Unify two maps using f to combine values of duplicate keys
private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] =
- b.foldLeft(a) { case (acc, (k, v)) =>
- acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
+ b.foldLeft(a) {
+ case (acc, (k, v)) =>
+ acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
-
private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
val iGraph = InstanceKeyGraph(c)
val moduleDeps = iGraph.getChildInstanceMap
@@ -754,9 +779,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// are driven with the same constant value. Then, if we find a Module input where each instance
// is driven with the same constant (and not seen in a previous iteration), we iterate again
@tailrec
- def iterate(toVisit: Set[OfModule],
- modules: Map[OfModule, Module],
- constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = {
+ def iterate(
+ toVisit: Set[OfModule],
+ modules: Map[OfModule, Module],
+ constInputs: Map[OfModule, Map[String, Literal]]
+ ): Map[OfModule, DefModule] = {
if (toVisit.isEmpty) modules
else {
// Order from leaf modules to root so that any module driving an output
@@ -767,31 +794,36 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Aggreagte Module outputs that are driven constant for use by instaniating Modules
// Aggregate submodule inputs driven constant for checking later
val (modulesx, _, constInputsx) =
- order.foldLeft((modules,
- Map[OfModule, Map[String, Literal]](),
- Map[OfModule, Map[String, Seq[Literal]]]())) {
+ order.foldLeft((modules, Map[OfModule, Map[String, Literal]](), Map[OfModule, Map[String, Seq[Literal]]]())) {
case ((mmap, constOutputs, constInputsAcc), mname) =>
val dontTouches = dontTouchMap.getOrElse(mname, Set.empty)
- val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname),
- constInputs.getOrElse(mname, Map.empty), constOutputs)
+ val (mx, mco, mci) = constPropModule(
+ modules(mname),
+ dontTouches,
+ moduleDeps(mname),
+ constInputs.getOrElse(mname, Map.empty),
+ constOutputs
+ )
// Accumulate all Literals used to drive a particular Module port
val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d))
(mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx)
}
// Determine which module inputs have all of the same, new constants driving them
- val newProppedInputs = constInputsx.flatMap { case (mname, ports) =>
- val portsx = ports.flatMap { case (pname, lits) =>
- val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false)
- val isModule = modules.contains(mname) // ExtModules are not contained in modules
- val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1
- if (isModule && newPort && allSameConst) Some(pname -> lits.head)
- else None
- }
- if (portsx.nonEmpty) Some(mname -> portsx) else None
+ val newProppedInputs = constInputsx.flatMap {
+ case (mname, ports) =>
+ val portsx = ports.flatMap {
+ case (pname, lits) =>
+ val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false)
+ val isModule = modules.contains(mname) // ExtModules are not contained in modules
+ val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1
+ if (isModule && newPort && allSameConst) Some(pname -> lits.head)
+ else None
+ }
+ if (portsx.nonEmpty) Some(mname -> portsx) else None
}
val modsWithConstInputs = newProppedInputs.keySet
val newToVisit = modsWithConstInputs ++
- modsWithConstInputs.flatMap(parentGraph.reachableFrom)
+ modsWithConstInputs.flatMap(parentGraph.reachableFrom)
// Combine const inputs (there can't be duplicate values in the inner maps)
val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b)
iterate(newToVisit.toSet, modulesx, nextConstInputs)
@@ -805,7 +837,6 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
c.modules.map(m => mmap.getOrElse(m.OfModule, m))
}
-
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index c883bdfb..fb1bd1f6 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -1,4 +1,3 @@
-
package firrtl.transforms
import firrtl._
@@ -8,7 +7,7 @@ import firrtl.annotations._
import firrtl.graph._
import firrtl.analyses.InstanceKeyGraph
import firrtl.Mappers._
-import firrtl.Utils.{throwInternalError, kind}
+import firrtl.Utils.{kind, throwInternalError}
import firrtl.MemoizedHash._
import firrtl.options.{Dependency, RegisteredTransform, ShellOption}
@@ -29,29 +28,34 @@ import collection.mutable
* circumstances of their instantiation in their parent module, they will still not be removed. To
* remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication.
*/
-class DeadCodeElimination extends Transform
+class DeadCodeElimination
+ extends Transform
with ResolvedAnnotationPaths
with RegisteredTransform
with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats],
- Dependency(passes.CommonSubexpressionElimination) )
+ Seq(
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[firrtl.transforms.CombineCats],
+ Dependency(passes.CommonSubexpressionElimination)
+ )
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf =
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename],
- Dependency(passes.VerilogPrep),
- Dependency[firrtl.AddDescriptionNodes] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename],
+ Dependency(passes.VerilogPrep),
+ Dependency[firrtl.AddDescriptionNodes]
+ )
override def invalidates(a: Transform) = false
@@ -59,7 +63,9 @@ class DeadCodeElimination extends Transform
new ShellOption[Unit](
longOption = "no-dce",
toAnnotationSeq = (_: Unit) => Seq(NoDCEAnnotation),
- helpText = "Disable dead code elimination" ) )
+ helpText = "Disable dead code elimination"
+ )
+ )
/** Based on LogicNode ins CheckCombLoops, currently kind of faking it */
private type LogicNode = MemoizedHash[WrappedExpression]
@@ -72,6 +78,7 @@ class DeadCodeElimination extends Transform
val loweredName = LowerTypes.loweredName(component.name.split('.'))
apply(component.module.name, WRef(loweredName))
}
+
/** External Modules are representated as a single node driven by all inputs and driving all
* outputs
*/
@@ -87,7 +94,7 @@ class DeadCodeElimination extends Transform
def rec(e: Expression): Expression = {
e match {
case ref @ (_: WRef | _: WSubField) => refs += ref
- case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.map(rec)
case ignore @ (_: Literal) => // Do nothing
case unexpected => throwInternalError()
}
@@ -98,9 +105,7 @@ class DeadCodeElimination extends Transform
}
// Gets all dependencies and constructs LogicNodes from them
- private def getDepsImpl(mname: String,
- instMap: collection.Map[String, String])
- (expr: Expression): Seq[LogicNode] =
+ private def getDepsImpl(mname: String, instMap: collection.Map[String, String])(expr: Expression): Seq[LogicNode] =
extractRefs(expr).map { e =>
if (kind(e) == InstanceKind) {
val (inst, tail) = Utils.splitRef(e)
@@ -110,11 +115,12 @@ class DeadCodeElimination extends Transform
}
}
-
/** Construct the dependency graph within this module */
- private def setupDepGraph(depGraph: MutableDiGraph[LogicNode],
- instMap: collection.Map[String, String])
- (mod: Module): Unit = {
+ private def setupDepGraph(
+ depGraph: MutableDiGraph[LogicNode],
+ instMap: collection.Map[String, String]
+ )(mod: Module
+ ): Unit = {
def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
def onStmt(stmt: Statement): Unit = stmt match {
@@ -150,7 +156,7 @@ class DeadCodeElimination extends Transform
val node = getDeps(loc) match { case Seq(elt) => elt }
getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref))
// Simulation constructs are treated as top-level outputs
- case Stop(_,_, clk, en) =>
+ case Stop(_, _, clk, en) =>
Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
case Print(_, _, args, clk, en) =>
(args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
@@ -172,12 +178,14 @@ class DeadCodeElimination extends Transform
}
// TODO Make immutable?
- private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]],
- doTouchExtMods: Set[String],
- c: Circuit): MutableDiGraph[LogicNode] = {
+ private def createDependencyGraph(
+ instMaps: collection.Map[String, collection.Map[String, String]],
+ doTouchExtMods: Set[String],
+ c: Circuit
+ ): MutableDiGraph[LogicNode] = {
val depGraph = new MutableDiGraph[LogicNode]
c.modules.foreach {
- case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
+ case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
case ext: ExtModule =>
// Connect all inputs to all outputs
val node = LogicNode(ext)
@@ -205,23 +213,25 @@ class DeadCodeElimination extends Transform
depGraph
}
- private def deleteDeadCode(instMap: collection.Map[String, String],
- deadNodes: collection.Set[LogicNode],
- moduleMap: collection.Map[String, DefModule],
- renames: RenameMap,
- topName: String,
- doTouchExtMods: Set[String])
- (mod: DefModule): Option[DefModule] = {
+ private def deleteDeadCode(
+ instMap: collection.Map[String, String],
+ deadNodes: collection.Set[LogicNode],
+ moduleMap: collection.Map[String, DefModule],
+ renames: RenameMap,
+ topName: String,
+ doTouchExtMods: Set[String]
+ )(mod: DefModule
+ ): Option[DefModule] = {
// For log-level debug
def deleteMsg(decl: IsDeclaration): String = {
val tpe = decl match {
- case _: DefNode => "node"
+ case _: DefNode => "node"
case _: DefRegister => "reg"
- case _: DefWire => "wire"
- case _: Port => "port"
- case _: DefMemory => "mem"
+ case _: DefWire => "wire"
+ case _: Port => "port"
+ case _: DefMemory => "mem"
case (_: DefInstance | _: WDefInstance) => "inst"
- case _: Module => "module"
+ case _: Module => "module"
case _: ExtModule => "extmodule"
}
val ref = decl match {
@@ -237,7 +247,7 @@ class DeadCodeElimination extends Transform
def deleteIfNotEnabled(stmt: Statement, en: Expression): Statement = en match {
case UIntLiteral(v, _) if v == BigInt(0) => EmptyStmt
- case _ => stmt
+ case _ => stmt
}
def onStmt(stmt: Statement): Statement = {
@@ -256,12 +266,11 @@ class DeadCodeElimination extends Transform
logger.debug(deleteMsg(decl))
renames.delete(decl.name)
EmptyStmt
- }
- else decl
- case print: Print => deleteIfNotEnabled(print, print.en)
- case stop: Stop => deleteIfNotEnabled(stop, stop.en)
+ } else decl
+ case print: Print => deleteIfNotEnabled(print, print.en)
+ case stop: Stop => deleteIfNotEnabled(stop, stop.en)
case formal: Verification => deleteIfNotEnabled(formal, formal.en)
- case con: Connect =>
+ case con: Connect =>
val node = getDeps(con.loc) match { case Seq(elt) => elt }
if (deadNodes.contains(node)) EmptyStmt else con
case Attach(info, exprs) => // If any exprs are dead then all are
@@ -270,7 +279,7 @@ class DeadCodeElimination extends Transform
case IsInvalid(info, expr) =>
val node = getDeps(expr) match { case Seq(elt) => elt }
if (deadNodes.contains(node)) EmptyStmt else IsInvalid(info, expr)
- case block: Block => block map onStmt
+ case block: Block => block.map(onStmt)
case other => other
}
stmtx match { // Check if module empty
@@ -300,8 +309,7 @@ class DeadCodeElimination extends Transform
if (portsx.isEmpty && doTouchExtMods.contains(ext.name)) {
logger.debug(deleteMsg(mod))
None
- }
- else {
+ } else {
if (ext.ports != portsx) throwInternalError() // Sanity check
Some(ext.copy(ports = portsx))
}
@@ -309,14 +317,13 @@ class DeadCodeElimination extends Transform
}
- def run(state: CircuitState,
- dontTouches: Seq[LogicNode],
- doTouchExtMods: Set[String]): CircuitState = {
+ def run(state: CircuitState, dontTouches: Seq[LogicNode], doTouchExtMods: Set[String]): CircuitState = {
val c = state.circuit
val moduleMap = c.modules.map(m => m.name -> m).toMap
val iGraph = InstanceKeyGraph(c)
- val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) =>
- k.module -> v.map(i => i.name -> i.module).toMap
+ val moduleDeps = iGraph.graph.getEdgeMap.map({
+ case (k, v) =>
+ k.module -> v.map(i => i.name -> i.module).toMap
})
val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
@@ -347,11 +354,12 @@ class DeadCodeElimination extends Transform
// themselves. We iterate over the modules in a topological order from leaves to the top. The
// current status of the modulesxMap is used to either delete instances or update their types
val modulesxMap = mutable.HashMap.empty[String, DefModule]
- topoSortedModules.foreach { case mod =>
- deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match {
- case Some(m) => modulesxMap += m.name -> m
- case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
- }
+ topoSortedModules.foreach {
+ case mod =>
+ deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match {
+ case Some(m) => modulesxMap += m.name -> m
+ case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
+ }
}
// Preserve original module order
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala
index 627af11f..18e32cbc 100644
--- a/src/main/scala/firrtl/transforms/Dedup.scala
+++ b/src/main/scala/firrtl/transforms/Dedup.scala
@@ -20,7 +20,6 @@ import scala.annotation.tailrec
// Datastructures
import scala.collection.mutable
-
/** A component, e.g. register etc. Must be declared only once under the TopAnnotation */
case class NoDedupAnnotation(target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] {
def duplicate(n: ModuleTarget): NoDedupAnnotation = NoDedupAnnotation(n)
@@ -36,7 +35,9 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt
new ShellOption[Unit](
longOption = "no-dedup",
toAnnotationSeq = _ => Seq(NoCircuitDedupAnnotation),
- helpText = "Do NOT dedup modules" ) )
+ helpText = "Do NOT dedup modules"
+ )
+ )
}
@@ -46,12 +47,13 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt
* @param original Original module
* @param index the normalized position of the original module in the original module list, fraction between 0 and 1
*/
-case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) extends MultiTargetAnnotation {
+case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double)
+ extends MultiTargetAnnotation {
override val targets: Seq[Seq[Target]] = Seq(Seq(original), duplicate.toList)
override def duplicate(n: Seq[Seq[Target]]): Annotation = {
n.toList match {
case Seq(_, List(dup: IsModule)) => DedupedResult(original, Some(dup), index)
- case _ => DedupedResult(original, None, -1)
+ case _ => DedupedResult(original, None, -1)
}
}
}
@@ -96,7 +98,7 @@ class DedupModules extends Transform with DependencyAPIMigration {
val noDedups = state.circuit.main +: state.annotations.collect { case NoDedupAnnotation(ModuleTarget(_, m)) => m }
val (remainingAnnotations, dupResults) = state.annotations.partition {
case _: DupedResult => false
- case _ => true
+ case _ => true
}
val previouslyDupedMap = dupResults.flatMap {
case DupedResult(newModules, original) =>
@@ -114,9 +116,11 @@ class DedupModules extends Transform with DependencyAPIMigration {
* @param noDedups Modules not to dedup
* @return Deduped Circuit and corresponding RenameMap
*/
- def run(c: Circuit,
- noDedups: Seq[String],
- previouslyDupedMap: Map[String, String]): (Circuit, RenameMap, AnnotationSeq) = {
+ def run(
+ c: Circuit,
+ noDedups: Seq[String],
+ previouslyDupedMap: Map[String, String]
+ ): (Circuit, RenameMap, AnnotationSeq) = {
// RenameMap
val componentRenameMap = RenameMap()
@@ -124,13 +128,16 @@ class DedupModules extends Transform with DependencyAPIMigration {
// Maps module name to corresponding dedup module
val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, previouslyDupedMap, componentRenameMap)
- val dedupCliques = dedupMap.foldLeft(Map.empty[String, Set[String]]) {
- case (dedupCliqueMap, (orig: String, dupMod: DefModule)) =>
- val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig
- dedupCliqueMap + (dupMod.name -> set)
- }.flatMap { case (dedupName, set) =>
- set.map { _ -> set }
- }
+ val dedupCliques = dedupMap
+ .foldLeft(Map.empty[String, Set[String]]) {
+ case (dedupCliqueMap, (orig: String, dupMod: DefModule)) =>
+ val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig
+ dedupCliqueMap + (dupMod.name -> set)
+ }
+ .flatMap {
+ case (dedupName, set) =>
+ set.map { _ -> set }
+ }
// Use old module list to preserve ordering
// Lookup what a module deduped to, if its a duplicate, remove it
@@ -149,9 +156,10 @@ class DedupModules extends Transform with DependencyAPIMigration {
val ct = CircuitTarget(c.main)
- val map = dedupMap.map { case (from, to) =>
- logger.debug(s"[Dedup] $from -> ${to.name}")
- ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name))
+ val map = dedupMap.map {
+ case (from, to) =>
+ logger.debug(s"[Dedup] $from -> ${to.name}")
+ ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name))
}
val moduleRenameMap = RenameMap()
moduleRenameMap.recordAll(map)
@@ -159,15 +167,19 @@ class DedupModules extends Transform with DependencyAPIMigration {
// Build instanceify renaming map
val instanceGraph = InstanceKeyGraph(c)
val instanceify = RenameMap()
- val moduleName2Index = c.modules.map(_.name).zipWithIndex.map { case (n, i) =>
- {
- c.modules.size match {
- case 0 => (n, 0.0)
- case 1 => (n, 1.0)
- case d => (n, i.toDouble / (d - 1))
+ val moduleName2Index = c.modules
+ .map(_.name)
+ .zipWithIndex
+ .map {
+ case (n, i) => {
+ c.modules.size match {
+ case 0 => (n, 0.0)
+ case 1 => (n, 1.0)
+ case d => (n, i.toDouble / (d - 1))
+ }
}
}
- }.toMap
+ .toMap
// get the ordered set of instances a module, includes new Deduped modules
val getChildrenInstances = {
@@ -182,56 +194,62 @@ class DedupModules extends Transform with DependencyAPIMigration {
}
val instanceNameMap: Map[OfModule, Map[Instance, Instance]] = {
- dedupMap.map { case (oldName, dedupedMod) =>
- val key = OfModule(oldName)
- val value = getChildrenInstances(oldName).zip(getChildrenInstances(dedupedMod.name)).map {
- case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name)
- }.toMap
- key -> value
+ dedupMap.map {
+ case (oldName, dedupedMod) =>
+ val key = OfModule(oldName)
+ val value = getChildrenInstances(oldName)
+ .zip(getChildrenInstances(dedupedMod.name))
+ .map {
+ case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name)
+ }
+ .toMap
+ key -> value
}.toMap
}
- val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { case mt@ModuleTarget(c, m) if dedupCliques(m).size > 1 =>
- dedupMap.get(m) match {
- case None => Nil
- case Some(module: DefModule) =>
- val paths = instanceGraph.findInstancesInHierarchy(m)
- // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option
- val newTargets = paths.map { path =>
- val root: IsModule = ct.module(c)
- path.foldLeft(root -> root) { case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) =>
- if(mod == c) {
- val mod = CircuitTarget(c).module(c)
- mod -> mod
- } else {
- val enclosingMod = oldRelPath match {
- case i: InstanceTarget => i.ofModule
- case m: ModuleTarget => m.module
- }
- val instMap = instanceNameMap(OfModule(enclosingMod))
- val newInstName = instMap(Instance(name)).value
- val old = oldRelPath.instOf(name, mod)
- old -> newRelPath.instOf(newInstName, mod)
+ val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap {
+ case mt @ ModuleTarget(c, m) if dedupCliques(m).size > 1 =>
+ dedupMap.get(m) match {
+ case None => Nil
+ case Some(module: DefModule) =>
+ val paths = instanceGraph.findInstancesInHierarchy(m)
+ // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option
+ val newTargets = paths.map { path =>
+ val root: IsModule = ct.module(c)
+ path.foldLeft(root -> root) {
+ case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) =>
+ if (mod == c) {
+ val mod = CircuitTarget(c).module(c)
+ mod -> mod
+ } else {
+ val enclosingMod = oldRelPath match {
+ case i: InstanceTarget => i.ofModule
+ case m: ModuleTarget => m.module
+ }
+ val instMap = instanceNameMap(OfModule(enclosingMod))
+ val newInstName = instMap(Instance(name)).value
+ val old = oldRelPath.instOf(name, mod)
+ old -> newRelPath.instOf(newInstName, mod)
+ }
}
}
- }
- // Add all relative paths to referredModule to map to new instances
- def addRecord(from: IsMember, to: IsMember): Unit = from match {
- case x: ModuleTarget =>
- instanceify.record(x, to)
- case x: IsComponent =>
- instanceify.record(x, to)
- addRecord(x.stripHierarchy(1), to)
- }
- // Instanceify deduped Modules!
- if (dedupCliques(module.name).size > 1) {
- newTargets.foreach { case (from, to) => addRecord(from, to) }
- }
- // Return Deduped Results
- if (newTargets.size == 1) {
- Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m)))
- } else Nil
- }
+ // Add all relative paths to referredModule to map to new instances
+ def addRecord(from: IsMember, to: IsMember): Unit = from match {
+ case x: ModuleTarget =>
+ instanceify.record(x, to)
+ case x: IsComponent =>
+ instanceify.record(x, to)
+ addRecord(x.stripHierarchy(1), to)
+ }
+ // Instanceify deduped Modules!
+ if (dedupCliques(module.name).size > 1) {
+ newTargets.foreach { case (from, to) => addRecord(from, to) }
+ }
+ // Return Deduped Results
+ if (newTargets.size == 1) {
+ Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m)))
+ } else Nil
+ }
case noDedups => Nil
}
@@ -242,6 +260,7 @@ class DedupModules extends Transform with DependencyAPIMigration {
/** Utility functions for [[DedupModules]] */
object DedupModules extends LazyLogging {
+
/** Change's a module's internal signal names, types, infos, and modules.
* @param rename Function to rename a signal. Called on declaration and references.
* @param retype Function to retype a signal. Called on declaration, references, and subfields
@@ -250,14 +269,16 @@ object DedupModules extends LazyLogging {
* @param module Module to change internals
* @return Changed Module
*/
- def changeInternals(rename: String=>String,
- retype: String=>Type=>Type,
- reinfo: Info=>Info,
- renameOfModule: (String, String)=>String,
- renameExps: Boolean = true
- )(module: DefModule): DefModule = {
+ def changeInternals(
+ rename: String => String,
+ retype: String => Type => Type,
+ reinfo: Info => Info,
+ renameOfModule: (String, String) => String,
+ renameExps: Boolean = true
+ )(module: DefModule
+ ): DefModule = {
def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe))
- def onExp(e: Expression): Expression = e match {
+ def onExp(e: Expression): Expression = e match {
case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g)
case WSubField(expr, n, tpe, kind) =>
val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n)
@@ -266,12 +287,12 @@ object DedupModules extends LazyLogging {
val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind)
//TODO: renameMap.rename(e.serialize, finalExpr.serialize)
finalExpr
- case other => other map onExp
+ case other => other.map(onExp)
}
def onStmt(s: Statement): Statement = s match {
case DefNode(info, name, value) =>
retype(name)(value.tpe)
- if(renameExps) DefNode(reinfo(info), rename(name), onExp(value))
+ if (renameExps) DefNode(reinfo(info), rename(name), onExp(value))
else DefNode(reinfo(info), rename(name), value)
case WDefInstance(i, n, m, t) =>
val newmod = renameOfModule(n, m)
@@ -283,12 +304,18 @@ object DedupModules extends LazyLogging {
val oldType = MemPortUtils.memType(d)
val newType = retype(d.name)(oldType)
val index = oldType
- .asInstanceOf[BundleType].fields.headOption
- .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere(
- {
- case Field("data" | "wdata" | "rdata", _, _) => true
- case _ => false
- }))
+ .asInstanceOf[BundleType]
+ .fields
+ .headOption
+ .map(
+ _.tpe
+ .asInstanceOf[BundleType]
+ .fields
+ .indexWhere({
+ case Field("data" | "wdata" | "rdata", _, _) => true
+ case _ => false
+ })
+ )
val newDataType = index match {
case Some(i) =>
//If index nonempty, then there exists a port
@@ -299,15 +326,15 @@ object DedupModules extends LazyLogging {
// associate it with the type of the memory (as the memory type is different than the datatype)
retype(d.name + ";&*^$")(d.dataType)
}
- d.copy(dataType = newDataType) map rename map reinfo
+ d.copy(dataType = newDataType).map(rename).map(reinfo)
case h: IsDeclaration =>
- val temp = h map rename map retype(h.name) map reinfo
- if(renameExps) temp map onExp else temp
+ val temp = h.map(rename).map(retype(h.name)).map(reinfo)
+ if (renameExps) temp.map(onExp) else temp
case other =>
- val temp = other map reinfo map onStmt
- if(renameExps) temp map onExp else temp
+ val temp = other.map(reinfo).map(onStmt)
+ if (renameExps) temp.map(onExp) else temp
}
- module map onPort map onStmt
+ module.map(onPort).map(onStmt)
}
/** Dedup a module's instances based on dedup map
@@ -321,11 +348,13 @@ object DedupModules extends LazyLogging {
* @param renameMap Will be modified to keep track of renames in this function
* @return fixed up module deduped instances
*/
- def dedupInstances(top: CircuitTarget,
- originalModule: String,
- moduleMap: Map[String, DefModule],
- name2name: Map[String, String],
- renameMap: RenameMap): DefModule = {
+ def dedupInstances(
+ top: CircuitTarget,
+ originalModule: String,
+ moduleMap: Map[String, DefModule],
+ name2name: Map[String, String],
+ renameMap: RenameMap
+ ): DefModule = {
val module = moduleMap(originalModule)
// If black box, return it (it has no instances)
@@ -340,7 +369,8 @@ object DedupModules extends LazyLogging {
}
val typeMap = mutable.HashMap[String, Type]()
def retype(name: String)(tpe: Type): Type = {
- if (typeMap.contains(name)) typeMap(name) else {
+ if (typeMap.contains(name)) typeMap(name)
+ else {
if (instanceModuleMap.contains(name)) {
val newType = Utils.module_type(getNewModule(instanceModuleMap(name)))
typeMap(name) = newType
@@ -360,7 +390,7 @@ object DedupModules extends LazyLogging {
def renameOfModule(instance: String, ofModule: String): String = {
name2name(ofModule)
}
- changeInternals({n => n}, retype, {i => i}, renameOfModule)(module)
+ changeInternals({ n => n }, retype, { i => i }, renameOfModule)(module)
}
@tailrec
@@ -415,10 +445,11 @@ object DedupModules extends LazyLogging {
* @return A map from tag to names of modules with the same structure and
* a RenameMap which maps Module names to their Tag.
*/
- def buildRTLTags(top: CircuitTarget,
- moduleLinearization: Seq[DefModule],
- noDedups: Set[String]
- ): (collection.Map[String, collection.Set[String]], RenameMap) = {
+ def buildRTLTags(
+ top: CircuitTarget,
+ moduleLinearization: Seq[DefModule],
+ noDedups: Set[String]
+ ): (collection.Map[String, collection.Set[String]], RenameMap) = {
// maps hash code to human readable tag
val hashToTag = mutable.HashMap[ir.HashCode, String]()
@@ -449,9 +480,9 @@ object DedupModules extends LazyLogging {
moduleNameToTag(originalModule.name) = hashToTag(hash)
}
- val tag2all = hashToNames.map{ case (hash, names) => hashToTag(hash) -> names.toSet }
+ val tag2all = hashToNames.map { case (hash, names) => hashToTag(hash) -> names.toSet }
val tagMap = RenameMap()
- moduleNameToTag.foreach{ case (name, tag) => tagMap.record(top.module(name), top.module(tag)) }
+ moduleNameToTag.foreach { case (name, tag) => tagMap.record(top.module(name), top.module(tag)) }
(tag2all, tagMap)
}
@@ -461,10 +492,12 @@ object DedupModules extends LazyLogging {
* @param renameMap rename map to populate when deduping
* @return Map of original Module name -> Deduped Module
*/
- def deduplicate(circuit: Circuit,
- noDedups: Set[String],
- previousDupResults: Map[String, String],
- renameMap: RenameMap): Map[String, DefModule] = {
+ def deduplicate(
+ circuit: Circuit,
+ noDedups: Set[String],
+ previousDupResults: Map[String, String],
+ renameMap: RenameMap
+ ): Map[String, DefModule] = {
val (moduleMap, moduleLinearization) = {
val iGraph = InstanceKeyGraph(circuit)
@@ -479,13 +512,14 @@ object DedupModules extends LazyLogging {
val (tag2all, tagMap) = buildRTLTags(top, moduleLinearization, noDedups)
// Set tag2name to be the best dedup module name
- val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap
+ val moduleIndex = circuit.modules.zipWithIndex.map { case (m, i) => m.name -> i }.toMap
// returns the module matching the circuit name or the module with lower index otherwise
def order(l: String, r: String): String = {
if (l == main) l
else if (r == main) r
- else if (moduleIndex(l) < moduleIndex(r)) l else r
+ else if (moduleIndex(l) < moduleIndex(r)) l
+ else r
}
// Maps a module's tag to its deduplicated module
@@ -499,7 +533,7 @@ object DedupModules extends LazyLogging {
tag2name(tag) = dedupName
val dedupModule = moduleMap(dedupWithoutOldName) match {
case e: ExtModule => e.copy(name = dedupName)
- case e: Module => e.copy(name = dedupName)
+ case e: Module => e.copy(name = dedupName)
}
dedupName -> dedupModule
}.toMap
@@ -508,32 +542,32 @@ object DedupModules extends LazyLogging {
val name2name = moduleMap.keysIterator.map { originalModule =>
tagMap.get(top.module(originalModule)) match {
case Some(Seq(Target(_, Some(tag), Nil))) => originalModule -> tag2name(tag)
- case None => originalModule -> originalModule
- case other => throwInternalError(other.toString)
+ case None => originalModule -> originalModule
+ case other => throwInternalError(other.toString)
}
}.toMap
// Build Remap for modules with deduped module references
val dedupedName2module = tag2name.map {
- case (tag, name) => name -> DedupModules.dedupInstances(
- top, name, moduleMapWithOldNames, name2name, renameMap)
+ case (tag, name) => name -> DedupModules.dedupInstances(top, name, moduleMapWithOldNames, name2name, renameMap)
}
// Build map from original name to corresponding deduped module
// It is important to flatMap before looking up the DefModules so that they aren't hashed
val name2module: Map[String, DefModule] =
tag2all.flatMap { case (tag, names) => names.map(_ -> tag) }
- .mapValues(tag => dedupedName2module(tag2name(tag)))
- .toMap
+ .mapValues(tag => dedupedName2module(tag2name(tag)))
+ .toMap
// Build renameMap
val indexedTargets = mutable.HashMap[String, IndexedSeq[ReferenceTarget]]()
- name2module.foreach { case (originalName, depModule) =>
- if(originalName != depModule.name) {
- val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule))
- val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName))
- computeRenameMap(fromSeq, toSeq, renameMap)
- }
+ name2module.foreach {
+ case (originalName, depModule) =>
+ if (originalName != depModule.name) {
+ val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule))
+ val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName))
+ computeRenameMap(fromSeq, toSeq, renameMap)
+ }
}
name2module
@@ -549,18 +583,21 @@ object DedupModules extends LazyLogging {
tpe
}
- changeInternals(rename, retype, {i => i}, {(x, y) => x}, renameExps = false)(m)
+ changeInternals(rename, retype, { i => i }, { (x, y) => x }, renameExps = false)(m)
refs.toIndexedSeq
}
- def computeRenameMap(originalNames: IndexedSeq[ReferenceTarget],
- dedupedNames: IndexedSeq[ReferenceTarget],
- renameMap: RenameMap): Unit = {
+ def computeRenameMap(
+ originalNames: IndexedSeq[ReferenceTarget],
+ dedupedNames: IndexedSeq[ReferenceTarget],
+ renameMap: RenameMap
+ ): Unit = {
originalNames.zip(dedupedNames).foreach {
- case (o, d) => if (o.component != d.component || o.ref != d.ref) {
- renameMap.record(o, d.copy(module = o.module))
- }
+ case (o, d) =>
+ if (o.component != d.component || o.ref != d.ref) {
+ renameMap.record(o, d.copy(module = o.module))
+ }
}
}
diff --git a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala
index a1e49d62..bfab31bf 100644
--- a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala
+++ b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala
@@ -33,7 +33,7 @@ object FixAddingNegativeLiterals {
*/
def fixupModule(m: DefModule): DefModule = {
val namespace = Namespace(m)
- m map fixupStatement(namespace)
+ m.map(fixupStatement(namespace))
}
/** Returns a statement with fixed additions of negative literals
@@ -43,8 +43,8 @@ object FixAddingNegativeLiterals {
*/
def fixupStatement(namespace: Namespace)(s: Statement): Statement = {
val stmtBuffer = mutable.ListBuffer[Statement]()
- val ret = s map fixupStatement(namespace) map fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer)
- if(stmtBuffer.isEmpty) {
+ val ret = s.map(fixupStatement(namespace)).map(fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer))
+ if (stmtBuffer.isEmpty) {
ret
} else {
stmtBuffer += ret
@@ -58,8 +58,7 @@ object FixAddingNegativeLiterals {
* @param e expression to fixup
* @return generated statements and the fixed expression
*/
- def fixupExpression(info: Info, namespace: Namespace)
- (e: Expression): (Seq[Statement], Expression) = {
+ def fixupExpression(info: Info, namespace: Namespace)(e: Expression): (Seq[Statement], Expression) = {
val stmtBuffer = mutable.ListBuffer[Statement]()
val retExpr = fixupOnExpr(info, namespace, stmtBuffer)(e)
(stmtBuffer.toList, retExpr)
@@ -72,12 +71,16 @@ object FixAddingNegativeLiterals {
* @param e expression to fixup
* @return fixed expression
*/
- private def fixupOnExpr(info: Info, namespace: Namespace, stmtBuffer: mutable.ListBuffer[Statement])
- (e: Expression): Expression = {
+ private def fixupOnExpr(
+ info: Info,
+ namespace: Namespace,
+ stmtBuffer: mutable.ListBuffer[Statement]
+ )(e: Expression
+ ): Expression = {
// Helper function to create the subtraction expression
def fixupAdd(expr: Expression, litValue: BigInt, litWidth: BigInt): DoPrim = {
- if(litValue == minNegValue(litWidth)) {
+ if (litValue == minNegValue(litWidth)) {
val posLiteral = SIntLiteral(-litValue)
assert(posLiteral.width.asInstanceOf[IntWidth].width - 1 == litWidth)
val sub = DefNode(info, namespace.newTemp, setType(DoPrim(Sub, Seq(expr, posLiteral), Nil, UnknownType)))
@@ -91,10 +94,10 @@ object FixAddingNegativeLiterals {
}
}
- e map fixupOnExpr(info, namespace, stmtBuffer) match {
- case DoPrim(Add, Seq(arg, lit@SIntLiteral(value, w@IntWidth(width))), Nil, t: SIntType) if value < 0 =>
+ e.map(fixupOnExpr(info, namespace, stmtBuffer)) match {
+ case DoPrim(Add, Seq(arg, lit @ SIntLiteral(value, w @ IntWidth(width))), Nil, t: SIntType) if value < 0 =>
fixupAdd(arg, value, width)
- case DoPrim(Add, Seq(lit@SIntLiteral(value, w@IntWidth(width)), arg), Nil, t: SIntType) if value < 0 =>
+ case DoPrim(Add, Seq(lit @ SIntLiteral(value, w @ IntWidth(width)), arg), Nil, t: SIntType) if value < 0 =>
fixupAdd(arg, value, width)
case other => other
}
diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala
index cc5b3504..36e71470 100644
--- a/src/main/scala/firrtl/transforms/Flatten.scala
+++ b/src/main/scala/firrtl/transforms/Flatten.scala
@@ -7,7 +7,7 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import scala.collection.mutable
-import firrtl.passes.{InlineInstances,PassException}
+import firrtl.passes.{InlineInstances, PassException}
import firrtl.stage.Forms
/** Tags an annotation to be consumed by this transform */
@@ -25,101 +25,114 @@ case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named
*/
class Flatten extends Transform with DependencyAPIMigration {
- override def prerequisites = Forms.LowForm
- override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Forms.LowEmitters
+ override def prerequisites = Forms.LowForm
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Forms.LowEmitters
override def invalidates(a: Transform) = false
- val inlineTransform = new InlineInstances
-
- private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
- anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) {
- case ((modNames, instNames), ann) => ann match {
- case FlattenAnnotation(CircuitName(c)) =>
- (circuit.modules.collect {
- case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
- }.toSet, instNames)
- case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
- case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
- case _ => throw new PassException("Annotation must be a FlattenAnnotation")
- }
- }
-
- /**
+ val inlineTransform = new InlineInstances
+
+ private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
+ anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) {
+ case ((modNames, instNames), ann) =>
+ ann match {
+ case FlattenAnnotation(CircuitName(c)) =>
+ (
+ circuit.modules.collect {
+ case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
+ }.toSet,
+ instNames
+ )
+ case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
+ case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
+ case _ => throw new PassException("Annotation must be a FlattenAnnotation")
+ }
+ }
+
+ /**
* Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
* by rewriting the original circuit to refer to the new modules that will be inlined later.
* @return modified circuit and ModuleNames to inline
*/
- def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = {
- val modMap = c.modules.map(m => m.name->m).toMap
- val seedMods = mutable.Map.empty[String, String]
- val newModDefs = mutable.Set.empty[DefModule]
- val nsp = Namespace(c)
-
- /**
+ def duplicateSubCircuitsFromAnno(
+ c: Circuit,
+ mods: Set[ModuleName],
+ insts: Set[ComponentName]
+ ): (Circuit, Set[ModuleName]) = {
+ val modMap = c.modules.map(m => m.name -> m).toMap
+ val seedMods = mutable.Map.empty[String, String]
+ val newModDefs = mutable.Set.empty[DefModule]
+ val nsp = Namespace(c)
+
+ /**
* We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later.
* It populates seedMods where we capture the mapping between the original module name of the instances came from annotation
* to a new module name that we will create as a replica of the original one.
* Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged.
*/
- def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
- case _: Block => x map rewriteMod(parent)
- case WDefInstance(info, instName, moduleName, instTpe) =>
- if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main))))
- || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) {
- val newModName = if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
- seedMods += moduleName -> newModName
- WDefInstance(info, instName, newModName, instTpe)
- } else x
- case _ => x
- }
-
- val modifMods = c.modules map { m => m map rewriteMod(m) }
-
- /**
+ def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
+ case _: Block => x.map(rewriteMod(parent))
+ case WDefInstance(info, instName, moduleName, instTpe) =>
+ if (
+ insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main))))
+ || mods.contains(ModuleName(parent.name, CircuitName(c.main)))
+ ) {
+ val newModName =
+ if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN")
+ seedMods += moduleName -> newModName
+ WDefInstance(info, instName, newModName, instTpe)
+ } else x
+ case _ => x
+ }
+
+ val modifMods = c.modules.map { m => m.map(rewriteMod(m)) }
+
+ /**
* Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations).
* Populates newModDefs, which are replicated modules used in the subcircuit that we create
* by recursively traversing modules captured inside seedMods and replicating them
*/
- def recDupMods(mods: Map[String, String]): Unit = {
- val replMods = mutable.Map.empty[String, String]
-
- def dupMod(x: Statement): Statement = x match {
- case _: Block => x map dupMod
- case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match {
- case m: Module =>
- val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
- replMods += moduleName -> newModName
- WDefInstance(info, instName, newModName, instTpe)
- case _ => x // Ignore extmodules
- }
- case _ => x
- }
-
- def dupName(name: String): String = mods(name)
- val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName }
-
- newModDefs ++= newMods
-
- if(replMods.size > 0) recDupMods(replMods.toMap)
-
- }
- recDupMods(seedMods.toMap)
-
- //convert newly created modules to ModuleName for inlining next (outside this function)
- val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) }
- (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
- }
-
- override def execute(state: CircuitState): CircuitState = {
- val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
- annos match {
- case Nil => state
- case myAnnotations =>
- val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
- // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
- val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
- inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations)
- }
- }
+ def recDupMods(mods: Map[String, String]): Unit = {
+ val replMods = mutable.Map.empty[String, String]
+
+ def dupMod(x: Statement): Statement = x match {
+ case _: Block => x.map(dupMod)
+ case WDefInstance(info, instName, moduleName, instTpe) =>
+ modMap(moduleName) match {
+ case m: Module =>
+ val newModName =
+ if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN")
+ replMods += moduleName -> newModName
+ WDefInstance(info, instName, newModName, instTpe)
+ case _ => x // Ignore extmodules
+ }
+ case _ => x
+ }
+
+ def dupName(name: String): String = mods(name)
+ val newMods = mods.map { case (origName, newName) => modMap(origName).map(dupMod).map(dupName) }
+
+ newModDefs ++= newMods
+
+ if (replMods.size > 0) recDupMods(replMods.toMap)
+
+ }
+ recDupMods(seedMods.toMap)
+
+ //convert newly created modules to ModuleName for inlining next (outside this function)
+ val modsToInline = newModDefs.map { m => ModuleName(m.name, CircuitName(c.main)) }
+ (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
+ }
+
+ override def execute(state: CircuitState): CircuitState = {
+ val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
+ annos match {
+ case Nil => state
+ case myAnnotations =>
+ val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
+ // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
+ val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
+ inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations)
+ }
+ }
}
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index a2399b5a..b582fe2a 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -119,7 +119,7 @@ object FlattenRegUpdate {
def rec(e: Expression): (Info, Expression) = {
val (info, expr) = kind(e) match {
case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e))
- case _ => unwrap(e)
+ case _ => unwrap(e)
}
expr match {
case Mux(cond, tval, fval, tpe) =>
@@ -128,16 +128,18 @@ object FlattenRegUpdate {
val infox = combineInfos(info, tinfo, finfo)
(infox, Mux(cond, tvalx, fvalx, tpe))
// Return the original expression to end flattening
- case _ => unwrap(e)
+ case _ => unwrap(e)
}
}
rec(start)
}
def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
- case reg @ DefRegister(_, rname, _,_, resetCond, _) =>
- assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
- "Synchronous reset should have already been made explicit!")
+ case reg @ DefRegister(_, rname, _, _, resetCond, _) =>
+ assert(
+ resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
+ "Synchronous reset should have already been made explicit!"
+ )
val ref = WRef(reg)
val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref))
val update = Connect(info, ref, rhs)
@@ -145,7 +147,7 @@ object FlattenRegUpdate {
reg
// Remove connections to Registers so we preserve LowFirrtl single-connection semantics
case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt
- case other => other
+ case other => other
}
val bodyx = onStmt(mod.body)
@@ -163,12 +165,14 @@ object FlattenRegUpdate {
class FlattenRegUpdate extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[InlineCastsTransform],
- Dependency[LegalizeClocksTransform] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[InlineCastsTransform],
+ Dependency[LegalizeClocksTransform]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -181,7 +185,7 @@ class FlattenRegUpdate extends Transform with DependencyAPIMigration {
def execute(state: CircuitState): CircuitState = {
val modulesx = state.circuit.modules.map {
- case mod: Module => FlattenRegUpdate.flattenReg(mod)
+ case mod: Module => FlattenRegUpdate.flattenReg(mod)
case ext: ExtModule => ext
}
state.copy(circuit = state.circuit.copy(modules = modulesx))
diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala
index 166feba0..0db67f1e 100644
--- a/src/main/scala/firrtl/transforms/GroupComponents.scala
+++ b/src/main/scala/firrtl/transforms/GroupComponents.scala
@@ -10,7 +10,6 @@ import firrtl.stage.Forms
import scala.collection.mutable
-
/**
* Specifies a group of components, within a module, to pull out into their own module
* Components that are only connected to a group's components will also be included
@@ -21,8 +20,14 @@ import scala.collection.mutable
* @param outputSuffix suggested suffix of any output ports of the new module
* @param inputSuffix suggested suffix of any input ports of the new module
*/
-case class GroupAnnotation(components: Seq[ComponentName], newModule: String, newInstance: String, outputSuffix: Option[String] = None, inputSuffix: Option[String] = None) extends Annotation {
- if(components.nonEmpty) {
+case class GroupAnnotation(
+ components: Seq[ComponentName],
+ newModule: String,
+ newInstance: String,
+ outputSuffix: Option[String] = None,
+ inputSuffix: Option[String] = None)
+ extends Annotation {
+ if (components.nonEmpty) {
require(components.forall(_.module == components.head.module), "All components must be in the same module.")
require(components.forall(!_.name.contains('.')), "No components can be a subcomponent.")
}
@@ -35,7 +40,7 @@ case class GroupAnnotation(components: Seq[ComponentName], newModule: String, ne
/* Only keeps components renamed to components */
def update(renames: RenameMap): Seq[Annotation] = {
- val newComponents = components.flatMap{c => renames.get(c).getOrElse(Seq(c))}.collect {
+ val newComponents = components.flatMap { c => renames.get(c).getOrElse(Seq(c)) }.collect {
case c: ComponentName => c
}
Seq(GroupAnnotation(newComponents, newModule, newInstance, outputSuffix, inputSuffix))
@@ -58,7 +63,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
}
override def execute(state: CircuitState): CircuitState = {
- val groups = state.annotations.collect {case g: GroupAnnotation => g}
+ val groups = state.annotations.collect { case g: GroupAnnotation => g }
val module2group = groups.groupBy(_.currentModule)
val mnamespace = Namespace(state.circuit)
val newModules = state.circuit.modules.flatMap {
@@ -74,13 +79,12 @@ class GroupComponents extends Transform with DependencyAPIMigration {
val namespace = Namespace(m)
val groupRoots = groups.map(_.components.map(_.name))
val totalSum = groupRoots.map(_.size).sum
- val union = groupRoots.foldLeft(Set.empty[String]){(all, set) => all.union(set.toSet)}
+ val union = groupRoots.foldLeft(Set.empty[String]) { (all, set) => all.union(set.toSet) }
- require(groupRoots.forall{_.forall{namespace.contains}}, "All names should be in this module")
+ require(groupRoots.forall { _.forall { namespace.contains } }, "All names should be in this module")
require(totalSum == union.size, "No name can be in more than one group")
require(groupRoots.forall(_.nonEmpty), "All groupRoots must by non-empty")
-
// Order of groups, according to their label. The label is the first root in the group
val labelOrder = groups.collect({ case g: GroupAnnotation => g.components.head.name })
@@ -90,8 +94,8 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Group roots, by label
// The label "" indicates the original module, and components belonging to that group will remain
// in the original module (not get moved into a new module)
- val label2group: Map[String, MSet[String]] = groups.collect{
- case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name):_*)
+ val label2group: Map[String, MSet[String]] = groups.collect {
+ case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name): _*)
}.toMap + ("" -> mutable.Set(""))
// Name of new module containing each group, by label
@@ -105,7 +109,6 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Build set of components not in set
val notSet = label2group.map { case (key, value) => key -> union.diff(value) }
-
// Get all dependencies between components
val deps = getComponentConnectivity(m)
@@ -114,13 +117,14 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// For each group (by label), add connectivity between nodes in set
// Populate reachableNodes with reachability, where blacklist is their notSet
- label2group.foreach { case (label, set) =>
- set.foreach { x =>
- deps.addPairWithEdge(label, x)
- }
- deps.reachableFrom(label, notSet(label)) foreach { node =>
- reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label
- }
+ label2group.foreach {
+ case (label, set) =>
+ set.foreach { x =>
+ deps.addPairWithEdge(label, x)
+ }
+ deps.reachableFrom(label, notSet(label)).foreach { node =>
+ reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label
+ }
}
// Unused nodes are not reachable from any group nor the root--add them to root group
@@ -129,12 +133,13 @@ class GroupComponents extends Transform with DependencyAPIMigration {
}
// Add nodes who are reached by a single group, to that group
- reachableNodes.foreach { case (node, membership) =>
- if(membership.size == 1) {
- label2group(membership.head) += node
- } else {
- label2group("") += node
- }
+ reachableNodes.foreach {
+ case (node, membership) =>
+ if (membership.size == 1) {
+ label2group(membership.head) += node
+ } else {
+ label2group("") += node
+ }
}
applyGrouping(m, labelOrder, label2group, label2module, label2instance, label2annotation)
@@ -150,19 +155,21 @@ class GroupComponents extends Transform with DependencyAPIMigration {
* @param label2annotation annotation specifying the group, by label
* @return new modules, including each group's module and the new split module
*/
- def applyGrouping( m: Module,
- labelOrder: Seq[String],
- label2group: Map[String, MSet[String]],
- label2module: Map[String, String],
- label2instance: Map[String, String],
- label2annotation: Map[String, GroupAnnotation]
- ): Seq[Module] = {
+ def applyGrouping(
+ m: Module,
+ labelOrder: Seq[String],
+ label2group: Map[String, MSet[String]],
+ label2module: Map[String, String],
+ label2instance: Map[String, String],
+ label2annotation: Map[String, GroupAnnotation]
+ ): Seq[Module] = {
// Maps node to group
val byNode = mutable.HashMap[String, String]()
- label2group.foreach { case (group, nodes) =>
- nodes.foreach { node =>
- byNode(node) = group
- }
+ label2group.foreach {
+ case (group, nodes) =>
+ nodes.foreach { node =>
+ byNode(node) = group
+ }
}
val groupNamespace = label2group.map { case (head, set) => head -> Namespace(set.toSeq) }
@@ -180,7 +187,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
val portNames = groupPortNames(group)
val suffix = d match {
case Output => label2annotation(group).outputSuffix.getOrElse("")
- case Input => label2annotation(group).inputSuffix.getOrElse("")
+ case Input => label2annotation(group).inputSuffix.getOrElse("")
}
val newName = groupNamespace(group).newName(source + suffix)
val portName = portNames.getOrElseUpdate(source, newName)
@@ -192,7 +199,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
val portName = addPort(group, exp, Output)
val connectStatement = exp.tpe match {
case AnalogType(_) => Attach(NoInfo, Seq(WRef(portName), exp))
- case _ => Connect(NoInfo, WRef(portName), exp)
+ case _ => Connect(NoInfo, WRef(portName), exp)
}
groupStatements(group) += connectStatement
portName
@@ -201,7 +208,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Given the sink is in a group, tidy up source references
def inGroupFixExps(group: String, added: mutable.ArrayBuffer[Statement])(e: Expression): Expression = e match {
case _: Literal => e
- case _: DoPrim | _: Mux | _: ValidIf => e map inGroupFixExps(group, added)
+ case _: DoPrim | _: Mux | _: ValidIf => e.map(inGroupFixExps(group, added))
case otherExp: Expression =>
val wref = getWRef(otherExp)
val source = wref.name
@@ -238,10 +245,10 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Given the sink is in the parent module, tidy up source references belonging to groups
def inTopFixExps(e: Expression): Expression = e match {
- case _: DoPrim | _: Mux | _: ValidIf => e map inTopFixExps
+ case _: DoPrim | _: Mux | _: ValidIf => e.map(inTopFixExps)
case otherExp: Expression =>
val wref = getWRef(otherExp)
- if(byNode(wref.name) != "") {
+ if (byNode(wref.name) != "") {
// Get the name of source's group
val otherGroup = byNode(wref.name)
@@ -260,7 +267,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
case r: IsDeclaration if byNode(r.name) != "" =>
val topStmts = mutable.ArrayBuffer[Statement]()
val group = byNode(r.name)
- groupStatements(group) += r mapExpr inGroupFixExps(group, topStmts)
+ groupStatements(group) += r.mapExpr(inGroupFixExps(group, topStmts))
Block(topStmts.toSeq)
case c: Connect if byNode(getWRef(c.loc).name) != "" =>
// Sink is in a group
@@ -276,20 +283,26 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// TODO Attach if all are in a group?
case _: IsDeclaration | _: Connect | _: Attach =>
// Sink is in Top
- val ret = s mapExpr inTopFixExps
+ val ret = s.mapExpr(inTopFixExps)
ret
- case other => other map onStmt
+ case other => other.map(onStmt)
}
}
-
// Build datastructures
- val newTopBody = Block(labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body)))
+ val newTopBody = Block(
+ labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body))
+ )
val finalTopBody = Block(Utils.squashEmpty(newTopBody).asInstanceOf[Block].stmts.distinct)
// For all group labels (not including the original module label), return a new Module.
- val newModules = labelOrder.filter(_ != "") map { group =>
- Module(NoInfo, label2module(group), groupPorts(group).distinct.toSeq, Block(groupStatements(group).distinct.toSeq))
+ val newModules = labelOrder.filter(_ != "").map { group =>
+ Module(
+ NoInfo,
+ label2module(group),
+ groupPorts(group).distinct.toSeq,
+ Block(groupStatements(group).distinct.toSeq)
+ )
}
Seq(m.copy(body = finalTopBody)) ++ newModules
}
@@ -298,7 +311,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
case w: WRef => w
case other =>
var w = WRef("")
- other mapExpr { e => w = getWRef(e); e}
+ other.mapExpr { e => w = getWRef(e); e }
w
}
@@ -317,25 +330,25 @@ class GroupComponents extends Transform with DependencyAPIMigration {
bidirGraph.addPairWithEdge(sink.name, name)
bidirGraph.addPairWithEdge(name, sink.name)
w
- case other => other map onExpr(sink)
+ case other => other.map(onExpr(sink))
}
def onStmt(stmt: Statement): Unit = stmt match {
case w: WDefInstance =>
case h: IsDeclaration =>
bidirGraph.addVertex(h.name)
- h map onExpr(WRef(h.name))
+ h.map(onExpr(WRef(h.name)))
case Attach(_, exprs) => // Add edge between each expression
- exprs.tail map onExpr(getWRef(exprs.head))
+ exprs.tail.map(onExpr(getWRef(exprs.head)))
case Connect(_, loc, expr) =>
onExpr(getWRef(loc))(expr)
- case q @ Stop(_,_, clk, en) =>
+ case q @ Stop(_, _, clk, en) =>
val simName = simNamespace.newTemp
simulations(simName) = q
- Seq(clk, en) map onExpr(WRef(simName))
+ Seq(clk, en).map(onExpr(WRef(simName)))
case q @ Print(_, _, args, clk, en) =>
val simName = simNamespace.newTemp
simulations(simName) = q
- (args :+ clk :+ en) map onExpr(WRef(simName))
+ (args :+ clk :+ en).map(onExpr(WRef(simName)))
case Block(stmts) => stmts.foreach(onStmt)
case ignore @ (_: IsInvalid | EmptyStmt) => // do nothing
case other => throw new Exception(s"Unexpected Statement $other")
@@ -358,7 +371,7 @@ class GroupAndDedup extends GroupComponents {
override def invalidates(a: Transform): Boolean = a match {
case _: DedupModules => true
- case _ => super.invalidates(a)
+ case _ => super.invalidates(a)
}
}
diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala
index dd073001..376382cc 100644
--- a/src/main/scala/firrtl/transforms/InferResets.scala
+++ b/src/main/scala/firrtl/transforms/InferResets.scala
@@ -7,9 +7,9 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.annotations.{ReferenceTarget, TargetToken}
-import firrtl.Utils.{toTarget, throwInternalError}
+import firrtl.Utils.{throwInternalError, toTarget}
import firrtl.options.Dependency
-import firrtl.passes.{Pass, PassException, InferTypes}
+import firrtl.passes.{InferTypes, Pass, PassException}
import firrtl.graph.MutableDiGraph
import scala.collection.mutable
@@ -83,14 +83,13 @@ object InferResets {
// Vectors must all have the same type, so we only process Index 0
// If the subtype is an aggregate, there can be multiple of each index
val ts = tokens.collect { case (TargetToken.Index(0) +: tail, tpe) => (tail, tpe) }
- VectorTree(fromTokens(ts:_*))
+ VectorTree(fromTokens(ts: _*))
// BundleTree
case (TargetToken.Field(_) +: _, _) +: _ =>
val fields =
- tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }
- .mapValues { ts =>
- fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }:_*)
- }.toMap
+ tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }.mapValues { ts =>
+ fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }: _*)
+ }.toMap
BundleTree(fields)
}
}
@@ -113,14 +112,16 @@ object InferResets {
class InferResets extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferWidths] ) ++ stage.Forms.WorkingIR
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferWidths]
+ ) ++ stage.Forms.WorkingIR
override def invalidates(a: Transform): Boolean = a match {
case _: checks.CheckResets | passes.CheckTypes => true
- case _ => false
+ case _ => false
}
import InferResets._
@@ -138,7 +139,7 @@ class InferResets extends Transform with DependencyAPIMigration {
val mod = instMap(target.ref)
val port = target.component.head match {
case TargetToken.Field(name) => name
- case bad => Utils.throwInternalError(s"Unexpected token $bad")
+ case bad => Utils.throwInternalError(s"Unexpected token $bad")
}
target.copy(module = mod, ref = port, component = target.component.tail)
case _ => target
@@ -148,17 +149,18 @@ class InferResets extends Transform with DependencyAPIMigration {
// Mark driver of a ResetType leaf
def markResetDriver(lhs: Expression, rhs: Expression): Unit = {
val con = Utils.flow(lhs) match {
- case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs))
+ case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs))
case SourceFlow if rhs.tpe == ResetType => Some((rhs, lhs))
// If sink is not ResetType, do nothing
- case _ => None
+ case _ => None
}
- con.foreach { case (loc, exp) =>
- val driver = exp.tpe match {
- case ResetType => TargetDriver(makeTarget(exp))
- case tpe => TypeDriver(tpe, () => makeTarget(exp))
- }
- map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver
+ con.foreach {
+ case (loc, exp) =>
+ val driver = exp.tpe match {
+ case ResetType => TargetDriver(makeTarget(exp))
+ case tpe => TypeDriver(tpe, () => makeTarget(exp))
+ }
+ map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver
}
}
stmt match {
@@ -227,7 +229,7 @@ class InferResets extends Transform with DependencyAPIMigration {
private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = {
val graph = new MutableDiGraph[Node]
val asyncNode = Typ(AsyncResetType)
- val syncNode = Typ(Utils.BoolType)
+ val syncNode = Typ(Utils.BoolType)
for ((target, drivers) <- map) {
val v = Var(target)
drivers.foreach {
@@ -247,7 +249,7 @@ class InferResets extends Transform with DependencyAPIMigration {
// do the actual inference, the check is simply if syncNode is reachable from asyncNode
graph.addPairWithEdge(v, u)
case InvalidDriver =>
- graph.addVertex(v) // Must be in the graph or won't be inferred
+ graph.addVertex(v) // Must be in the graph or won't be inferred
}
}
val async = graph.reachableFrom(asyncNode)
@@ -257,7 +259,7 @@ class InferResets extends Transform with DependencyAPIMigration {
case (a, _) if a.contains(syncNode) => throw InferResetsException(graph.path(asyncNode, syncNode))
case (a, s) =>
(a.view.collect { case Var(t) => t -> asyncNode.tpe } ++
- s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap
+ s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap
}
}
}
@@ -265,34 +267,40 @@ class InferResets extends Transform with DependencyAPIMigration {
private def fixupType(tpe: Type, tree: TypeTree): Type = (tpe, tree) match {
case (BundleType(fields), BundleTree(map)) =>
val fieldsx =
- fields.map(f => map.get(f.name) match {
- case Some(t) => f.copy(tpe = fixupType(f.tpe, t))
- case None => f
- })
+ fields.map(f =>
+ map.get(f.name) match {
+ case Some(t) => f.copy(tpe = fixupType(f.tpe, t))
+ case None => f
+ }
+ )
BundleType(fieldsx)
case (VectorType(vtpe, size), VectorTree(t)) =>
VectorType(fixupType(vtpe, t), size)
case (_, GroundTree(t)) => t
- case x => throw new Exception(s"Error! Unexpected pair $x")
+ case x => throw new Exception(s"Error! Unexpected pair $x")
}
// Assumes all ReferenceTargets are in the same module
private def makeDeclMap(map: Map[ReferenceTarget, Type]): Map[String, TypeTree] =
- map.groupBy(_._1.ref).mapValues { ts =>
- TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }:_*)
- }.toMap
+ map
+ .groupBy(_._1.ref)
+ .mapValues { ts =>
+ TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }: _*)
+ }
+ .toMap
private def implPort(map: Map[String, TypeTree])(port: Port): Port =
- map.get(port.name)
- .map(tree => port.copy(tpe = fixupType(port.tpe, tree)))
- .getOrElse(port)
+ map
+ .get(port.name)
+ .map(tree => port.copy(tpe = fixupType(port.tpe, tree)))
+ .getOrElse(port)
private def implStmt(map: Map[String, TypeTree])(stmt: Statement): Statement =
stmt.map(implStmt(map)) match {
case decl: IsDeclaration if map.contains(decl.name) =>
val tree = map(decl.name)
decl match {
- case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree))
- case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree))
+ case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree))
+ case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree))
// TODO Can this really happen?
case mem: DefMemory => mem.copy(dataType = fixupType(mem.dataType, tree))
case other => other
@@ -303,10 +311,13 @@ class InferResets extends Transform with DependencyAPIMigration {
private def implement(c: Circuit, map: Map[ReferenceTarget, Type]): Circuit = {
val modMaps = map.groupBy(_._1.module)
def onMod(mod: DefModule): DefModule = {
- modMaps.get(mod.name).map { tmap =>
- val declMap = makeDeclMap(tmap)
- mod.map(implPort(declMap)).map(implStmt(declMap))
- }.getOrElse(mod)
+ modMaps
+ .get(mod.name)
+ .map { tmap =>
+ val declMap = makeDeclMap(tmap)
+ mod.map(implPort(declMap)).map(implStmt(declMap))
+ }
+ .getOrElse(mod)
}
c.map(onMod)
}
diff --git a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
index 515bf407..100b598f 100644
--- a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
+++ b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
@@ -6,7 +6,7 @@ package transforms
import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.Dependency
-import firrtl.PrimOps.{Bits, Head, Tail, Shr}
+import firrtl.PrimOps.{Bits, Head, Shr, Tail}
import firrtl.Utils.{isBitExtract, isTemp}
import firrtl.WrappedExpression._
@@ -19,8 +19,8 @@ object InlineBitExtractionsTransform {
// Note that this can have false negatives but MUST NOT have false positives.
private def isSimpleExpr(expr: Expression): Boolean = expr match {
case _: WRef | _: Literal | _: WSubField => true
- case DoPrim(op, args, _,_) if isBitExtract(op) => args.forall(isSimpleExpr)
- case _ => false
+ case DoPrim(op, args, _, _) if isBitExtract(op) => args.forall(isSimpleExpr)
+ case _ => false
}
// replace Head/Tail/Shr with Bits for easier back-to-back Bits Extractions
@@ -28,12 +28,12 @@ object InlineBitExtractionsTransform {
case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) =>
val msb = bitWidth(rhs.head.tpe) - 1
val lsb = bitWidth(rhs.head.tpe) - c.head
- DoPrim(Bits, rhs, Seq(msb,lsb), tpe)
+ DoPrim(Bits, rhs, Seq(msb, lsb), tpe)
case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) =>
val msb = bitWidth(rhs.head.tpe) - c.head - 1
- DoPrim(Bits, rhs, Seq(msb,0), tpe)
+ DoPrim(Bits, rhs, Seq(msb, 0), tpe)
case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) =>
- DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe)-1, c.head), tpe)
+ DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe) - 1, c.head), tpe)
case _ => expr // Not a candidate
}
@@ -49,26 +49,28 @@ object InlineBitExtractionsTransform {
*/
def onExpr(netlist: Netlist)(expr: Expression): Expression = {
expr.map(onExpr(netlist)) match {
- case e @ WRef(name, _,_,_) =>
- netlist.get(we(e))
- .filter(isBitExtract)
- .getOrElse(e)
+ case e @ WRef(name, _, _, _) =>
+ netlist
+ .get(we(e))
+ .filter(isBitExtract)
+ .getOrElse(e)
// replace back-to-back Bits Extractions
case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) =>
ival.head match {
case of @ DoPrim(rop, rhs, rc, rtpe) if isSimpleExpr(of) =>
(lop, rop) match {
- case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head min rc.head), ltpe)
+ case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head.min(rc.head)), ltpe)
case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe)
- case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe)
- case (_,_) => (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match {
- case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) =>
- DoPrim(Bits, rhs, Seq(lmsb+rlsb,llsb+rlsb), ltpe)
- case (_,_) => lhs // Not a candidate
- }
+ case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe)
+ case (_, _) =>
+ (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match {
+ case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) =>
+ DoPrim(Bits, rhs, Seq(lmsb + rlsb, llsb + rlsb), ltpe)
+ case (_, _) => lhs // Not a candidate
+ }
}
- case _ => lhs // Not a candidate
- }
+ case _ => lhs // Not a candidate
+ }
case other => other // Not a candidate
}
}
@@ -97,9 +99,11 @@ object InlineBitExtractionsTransform {
class InlineBitExtractionsTransform extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala
index 3dac938e..0efc0727 100644
--- a/src/main/scala/firrtl/transforms/InlineCasts.scala
+++ b/src/main/scala/firrtl/transforms/InlineCasts.scala
@@ -8,7 +8,7 @@ import firrtl.Mappers._
import firrtl.PrimOps.Pad
import firrtl.options.Dependency
-import firrtl.Utils.{isCast, isBitExtract, NodeMap}
+import firrtl.Utils.{isBitExtract, isCast, NodeMap}
object InlineCastsTransform {
@@ -17,8 +17,8 @@ object InlineCastsTransform {
// Note that this can have false negatives but MUST NOT have false positives
private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match {
case _: WRef | _: Literal | _: WSubField => castSeen
- case DoPrim(op, args, _,_) if isCast(op) => args.forall(isSimpleCast(true))
- case _ => false
+ case DoPrim(op, args, _, _) if isCast(op) => args.forall(isSimpleCast(true))
+ case _ => false
}
/** Recursively replace [[WRef]]s with new [[firrtl.ir.Expression Expression]]s
@@ -31,17 +31,20 @@ object InlineCastsTransform {
def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match {
// Anything that may generate a part-select should not be inlined!
case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
- case e => e.map(onExpr(replace)) match {
- case e @ WRef(name, _,_,_) =>
- replace.get(name)
- .filter(isSimpleCast(castSeen=false))
- .getOrElse(e)
- case e @ DoPrim(op, Seq(WRef(name, _,_,_)), _,_) if isCast(op) =>
- replace.get(name)
- .map(value => e.copy(args = Seq(value)))
- .getOrElse(e)
- case other => other // Not a candidate
- }
+ case e =>
+ e.map(onExpr(replace)) match {
+ case e @ WRef(name, _, _, _) =>
+ replace
+ .get(name)
+ .filter(isSimpleCast(castSeen = false))
+ .getOrElse(e)
+ case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
+ replace
+ .get(name)
+ .map(value => e.copy(args = Seq(value)))
+ .getOrElse(e)
+ case other => other // Not a candidate
+ }
}
/** Inline casts in a Statement
@@ -69,11 +72,13 @@ object InlineCastsTransform {
class InlineCastsTransform extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[PropagatePresetAnnotations] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[PropagatePresetAnnotations]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/LegalizeClocks.scala b/src/main/scala/firrtl/transforms/LegalizeClocks.scala
index f439fdc9..248775d9 100644
--- a/src/main/scala/firrtl/transforms/LegalizeClocks.scala
+++ b/src/main/scala/firrtl/transforms/LegalizeClocks.scala
@@ -18,8 +18,8 @@ object LegalizeClocksTransform {
// Currently only looks for literals nested within casts
private def illegalClockExpr(expr: Expression): Boolean = expr match {
case _: Literal => true
- case DoPrim(op, args, _,_) if isCast(op) => args.exists(illegalClockExpr)
- case _ => false
+ case DoPrim(op, args, _, _) if isCast(op) => args.exists(illegalClockExpr)
+ case _ => false
}
/** Legalize Clocks in a Statement
@@ -66,11 +66,13 @@ object LegalizeClocksTransform {
class LegalizeClocksTransform extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[InlineCastsTransform] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[InlineCastsTransform]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/LegalizeReductions.scala b/src/main/scala/firrtl/transforms/LegalizeReductions.scala
index 2e60aae7..33a10349 100644
--- a/src/main/scala/firrtl/transforms/LegalizeReductions.scala
+++ b/src/main/scala/firrtl/transforms/LegalizeReductions.scala
@@ -6,17 +6,16 @@ import firrtl.Mappers._
import firrtl.options.Dependency
import firrtl.Utils.BoolType
-
object LegalizeAndReductionsTransform {
private def allOnesOfType(tpe: Type): Literal = tpe match {
case UIntType(width @ IntWidth(x)) => UIntLiteral((BigInt(1) << x.toInt) - 1, width)
- case SIntType(width) => SIntLiteral(-1, width)
+ case SIntType(width) => SIntLiteral(-1, width)
}
def onExpr(expr: Expression): Expression = expr.map(onExpr) match {
- case DoPrim(PrimOps.Andr, Seq(arg), _,_) if bitWidth(arg.tpe) > 64 =>
+ case DoPrim(PrimOps.Andr, Seq(arg), _, _) if bitWidth(arg.tpe) > 64 =>
DoPrim(PrimOps.Eq, Seq(arg, allOnesOfType(arg.tpe)), Seq(), BoolType)
case other => other
}
@@ -35,8 +34,7 @@ class LegalizeAndReductionsTransform extends Transform with DependencyAPIMigrati
override def prerequisites =
firrtl.stage.Forms.WorkingIR ++
- Seq( Dependency(passes.CheckTypes),
- Dependency(passes.CheckWidths))
+ Seq(Dependency(passes.CheckTypes), Dependency(passes.CheckWidths))
override def optionalPrerequisites = Nil
diff --git a/src/main/scala/firrtl/transforms/ManipulateNames.scala b/src/main/scala/firrtl/transforms/ManipulateNames.scala
index f15b546f..d0b12e66 100644
--- a/src/main/scala/firrtl/transforms/ManipulateNames.scala
+++ b/src/main/scala/firrtl/transforms/ManipulateNames.scala
@@ -57,8 +57,9 @@ sealed trait ManipulateNamesListAnnotation[A <: ManipulateNames[_]] extends Mult
* @note $noteLocalTargets
*/
case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]](
- targets: Seq[Seq[Target]],
- transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] {
+ targets: Seq[Seq[Target]],
+ transform: Dependency[A])
+ extends ManipulateNamesListAnnotation[A] {
override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a)
@@ -77,8 +78,9 @@ case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]](
* @note $noteLocalTargets
*/
case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]](
- targets: Seq[Seq[Target]],
- transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] {
+ targets: Seq[Seq[Target]],
+ transform: Dependency[A])
+ extends ManipulateNamesListAnnotation[A] {
override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a)
@@ -94,19 +96,21 @@ case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]](
* @param oldTargets the old targets
*/
case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]](
- targets: Seq[Seq[Target]],
- transform: Dependency[A],
- oldTargets: Seq[Seq[Target]]) extends MultiTargetAnnotation {
+ targets: Seq[Seq[Target]],
+ transform: Dependency[A],
+ oldTargets: Seq[Seq[Target]])
+ extends MultiTargetAnnotation {
override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a)
override def update(renames: RenameMap) = {
val (targetsx, oldTargetsx) = targets.zip(oldTargets).foldLeft((Seq.empty[Seq[Target]], Seq.empty[Seq[Target]])) {
- case ((accT, accO), (t, o)) => t.flatMap(renames(_)) match {
- /* If the target was deleted, delete the old target */
- case tx if tx.isEmpty => (accT, accO)
- case tx => (Seq(tx) ++ accT, Seq(o) ++ accO)
- }
+ case ((accT, accO), (t, o)) =>
+ t.flatMap(renames(_)) match {
+ /* If the target was deleted, delete the old target */
+ case tx if tx.isEmpty => (accT, accO)
+ case tx => (Seq(tx) ++ accT, Seq(o) ++ accO)
+ }
}
targetsx match {
/* If all targets were deleted, delete the annotation */
@@ -117,9 +121,13 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]](
/** Return [[firrtl.RenameMap RenameMap]] from old targets to new targets */
def toRenameMap: RenameMap = {
- val m = oldTargets.zip(targets).flatMap {
- case (a, b) => a.map(_ -> b)
- }.toMap.asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]]
+ val m = oldTargets
+ .zip(targets)
+ .flatMap {
+ case (a, b) => a.map(_ -> b)
+ }
+ .toMap
+ .asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]]
RenameMap.create(m)
}
@@ -132,25 +140,28 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]](
* @param allow a function that returns true if a [[firrtl.annotations.Target Target]] should be renamed
*/
private class RenameDataStructure(
- circuit: ir.Circuit,
+ circuit: ir.Circuit,
val renames: RenameMap,
- val block: Target => Boolean,
- val allow: Target => Boolean) {
+ val block: Target => Boolean,
+ val allow: Target => Boolean) {
/** A mapping of targets to associated namespaces */
val namespaces: mutable.HashMap[CompleteTarget, Namespace] =
mutable.HashMap(CircuitTarget(circuit.main) -> Namespace(circuit))
- /** Wraps a HashMap to provide better error messages when accessing a non-existing element */
+ /** Wraps a HashMap to provide better error messages when accessing a non-existing element */
class InstanceHashMap {
type Key = ReferenceTarget
type Value = Either[ReferenceTarget, InstanceTarget]
private val m = mutable.HashMap[Key, Value]()
- def apply(key: ReferenceTarget): Value = m.getOrElse(key, {
- throw new FirrtlUserException(
- s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems.
- | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin)
- })
+ def apply(key: ReferenceTarget): Value = m.getOrElse(
+ key, {
+ throw new FirrtlUserException(
+ s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems.
+ | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin
+ )
+ }
+ )
def update(key: Key, value: Value): Unit = m.update(key, value)
}
@@ -165,17 +176,17 @@ private class RenameDataStructure(
/** Transform for manipulate all the names in a FIRRTL circuit.
* @tparam A the type of the child transform
*/
-abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Transform with DependencyAPIMigration {
+abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transform with DependencyAPIMigration {
/** A function used to manipulate a name in a FIRRTL circuit */
def manipulate: (String, Namespace) => Option[String]
- override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes))
- override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty
+ override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes))
+ override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisiteOf: Seq[TransformDependency] = Forms.LowEmitters
override def invalidates(a: Transform) = a match {
case _: analyses.GetNamespace => true
- case _ => false
+ case _ => false
}
/** Compute a new name for some target and record the rename if the new name differs. If the top module or the circuit
@@ -192,27 +203,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
case a if r.skip(a) =>
(name, None)
/* Circuit renaming */
- case a@ CircuitTarget(b) => manipulate(b, r.namespaces(a)) match {
- case Some(str) => (str, Some(a.copy(circuit = str)))
- case None => (b, None)
- }
+ case a @ CircuitTarget(b) =>
+ manipulate(b, r.namespaces(a)) match {
+ case Some(str) => (str, Some(a.copy(circuit = str)))
+ case None => (b, None)
+ }
/* Module renaming for non-top modules */
- case a@ ModuleTarget(_, b) => manipulate(b, r.namespaces(a.circuitTarget)) match {
- case Some(str) => (str, Some(a.copy(module = str)))
- case None => (b, None)
- }
+ case a @ ModuleTarget(_, b) =>
+ manipulate(b, r.namespaces(a.circuitTarget)) match {
+ case Some(str) => (str, Some(a.copy(module = str)))
+ case None => (b, None)
+ }
/* Instance renaming */
- case a@ InstanceTarget(_, _, Nil, b, c) => manipulate(b, r.namespaces(a.moduleTarget)) match {
- case Some(str) => (str, Some(a.copy(instance = str)))
- case None => (b, None)
- }
+ case a @ InstanceTarget(_, _, Nil, b, c) =>
+ manipulate(b, r.namespaces(a.moduleTarget)) match {
+ case Some(str) => (str, Some(a.copy(instance = str)))
+ case None => (b, None)
+ }
/* Rename either a module component or a memory */
- case a@ ReferenceTarget(_, _, _, b, Nil) => manipulate(b, r.namespaces(a.moduleTarget)) match {
- case Some(str) => (str, Some(a.copy(ref = str)))
- case None => (b, None)
- }
+ case a @ ReferenceTarget(_, _, _, b, Nil) =>
+ manipulate(b, r.namespaces(a.moduleTarget)) match {
+ case Some(str) => (str, Some(a.copy(ref = str)))
+ case None => (b, None)
+ }
/* Rename an instance port or a memory reader/writer/readwriter */
- case a@ ReferenceTarget(_, _, _, b, (token@ TargetToken.Field(c)) :: Nil) =>
+ case a @ ReferenceTarget(_, _, _, b, (token @ TargetToken.Field(c)) :: Nil) =>
val ref = r.instanceMap(a.moduleTarget.ref(b)) match {
case Right(inst) => inst.ofModuleTarget
case Left(mem) => mem
@@ -224,8 +239,8 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
}
/* Record the optional rename. If the circuit was renamed, also rename the top module. If the top module was
* renamed, also rename the circuit. */
- ax.foreach(
- axx => target match {
+ ax.foreach(axx =>
+ target match {
case c: CircuitTarget =>
r.renames.rename(target, r.renames(axx))
r.renames.rename(c.module(c.circuit), CircuitTarget(namex).module(namex))
@@ -252,21 +267,26 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
r.renames.underlying.get(t) match {
case Some(ax) if ax.size == 1 =>
ax match {
- case Seq(foo: CircuitTarget) => foo.name
- case Seq(foo: ModuleTarget) => foo.module
- case Seq(foo: InstanceTarget) => foo.instance
- case Seq(foo: ReferenceTarget) => foo.tokens.last match {
- case TargetToken.Ref(value) => value
- case TargetToken.Field(value) => value
- case _ => Utils.throwInternalError(
- s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field'
+ case Seq(foo: CircuitTarget) => foo.name
+ case Seq(foo: ModuleTarget) => foo.module
+ case Seq(foo: InstanceTarget) => foo.instance
+ case Seq(foo: ReferenceTarget) =>
+ foo.tokens.last match {
+ case TargetToken.Ref(value) => value
+ case TargetToken.Field(value) => value
+ case _ =>
+ Utils.throwInternalError(
+ s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field'
| This is indicative of a circuit that has not been run through LowerTypes.""",
- Some(new MatchError(foo.serialize)))
- }
+ Some(new MatchError(foo.serialize))
+ )
+ }
}
- case s@ Some(ax) => Utils.throwInternalError(
- s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""",
- Some(new MatchError(s)))
+ case s @ Some(ax) =>
+ Utils.throwInternalError(
+ s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""",
+ Some(new MatchError(s))
+ )
case None => name
}
@@ -280,27 +300,34 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
/* A reference to something inside this module */
case w: WRef => w.copy(name = maybeRename(w.name, r, Target.asTarget(t)(w)))
/* This is either the subfield of an instance or a subfield of a memory reader/writer/readwriter */
- case w@ WSubField(expr, ref, _, _) => expr match {
- /* This is an instance */
- case we@ WRef(inst, _, _, _) =>
- val tx = Target.asTarget(t)(we)
- val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match {
- case Right(a) => (a.ofModuleTarget.ref(ref), a)
- case a@ Left(ref) => throw new FirrtlUserException(
- s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'.
- | This is indicative of a circuit that has not been run through LowerTypes.""", new MatchError(a))
- }
- w.copy(we.copy(name=maybeRename(inst, r, iTarget)), name=maybeRename(ref, r, rTarget))
- /* This is a reader/writer/readwriter */
- case ws@ WSubField(expr, port, _, _) => expr match {
- /* This is the memory. */
- case wr@ WRef(mem, _, _, _) =>
- w.copy(
- expr=ws.copy(
- expr=wr.copy(name=maybeRename(mem, r, t.ref(mem))),
- name=maybeRename(port, r, t.ref(mem).field(port))))
+ case w @ WSubField(expr, ref, _, _) =>
+ expr match {
+ /* This is an instance */
+ case we @ WRef(inst, _, _, _) =>
+ val tx = Target.asTarget(t)(we)
+ val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match {
+ case Right(a) => (a.ofModuleTarget.ref(ref), a)
+ case a @ Left(ref) =>
+ throw new FirrtlUserException(
+ s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'.
+ | This is indicative of a circuit that has not been run through LowerTypes.""",
+ new MatchError(a)
+ )
+ }
+ w.copy(we.copy(name = maybeRename(inst, r, iTarget)), name = maybeRename(ref, r, rTarget))
+ /* This is a reader/writer/readwriter */
+ case ws @ WSubField(expr, port, _, _) =>
+ expr match {
+ /* This is the memory. */
+ case wr @ WRef(mem, _, _, _) =>
+ w.copy(
+ expr = ws.copy(
+ expr = wr.copy(name = maybeRename(mem, r, t.ref(mem))),
+ name = maybeRename(port, r, t.ref(mem).field(port))
+ )
+ )
+ }
}
- }
case e => e.map(onExpression(_: ir.Expression, r, t))
}
@@ -310,30 +337,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
* and readwriters.
*/
private def onStatement(s: ir.Statement, r: RenameDataStructure, t: ModuleTarget): ir.Statement = s match {
- case decl: ir.IsDeclaration => decl match {
- case decl@ WDefInstance(_, inst, mod, _) =>
- val modx = maybeRename(mod, r, t.circuitTarget.module(mod))
- val instx = doRename(inst, r, t.instOf(inst, mod))
- r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod))
- decl.copy(name = instx, module = modx)
- case decl: ir.DefMemory =>
- val namex = doRename(decl.name, r, t.ref(decl.name))
- val tx = t.ref(decl.name)
- r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters)
- r.instanceMap(tx) = Left(tx)
- decl
- .copy(
- name = namex,
- readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))),
- writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))),
- readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw)))
- )
- .map(onExpression(_: ir.Expression, r, t))
- case decl =>
- decl
- .map(doRename(_: String, r, t.ref(decl.name)))
- .map(onExpression(_: ir.Expression, r, t))
- }
+ case decl: ir.IsDeclaration =>
+ decl match {
+ case decl @ WDefInstance(_, inst, mod, _) =>
+ val modx = maybeRename(mod, r, t.circuitTarget.module(mod))
+ val instx = doRename(inst, r, t.instOf(inst, mod))
+ r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod))
+ decl.copy(name = instx, module = modx)
+ case decl: ir.DefMemory =>
+ val namex = doRename(decl.name, r, t.ref(decl.name))
+ val tx = t.ref(decl.name)
+ r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters)
+ r.instanceMap(tx) = Left(tx)
+ decl
+ .copy(
+ name = namex,
+ readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))),
+ writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))),
+ readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw)))
+ )
+ .map(onExpression(_: ir.Expression, r, t))
+ case decl =>
+ decl
+ .map(doRename(_: String, r, t.ref(decl.name)))
+ .map(onExpression(_: ir.Expression, r, t))
+ }
case s =>
s
.map(onStatement(_: ir.Statement, r, t))
@@ -362,7 +390,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
*/
val onName: String => String = t.circuit match {
case `main` => maybeRename(_, r, moduleTarget)
- case _ => doRename(_, r, moduleTarget)
+ case _ => doRename(_, r, moduleTarget)
}
m
@@ -380,11 +408,11 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
* @return the circuit with manipulated names
*/
def run(
- c: ir.Circuit,
+ c: ir.Circuit,
renames: RenameMap,
- block: Target => Boolean,
- allow: Target => Boolean)
- : ir.Circuit = {
+ block: Target => Boolean,
+ allow: Target => Boolean
+ ): ir.Circuit = {
val t = CircuitTarget(c.main)
/* If the circuit is a skip, return the original circuit. Otherwise, walk all the modules and rename them. Rename the
@@ -427,8 +455,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
.toMap
/* Replace the old modules making sure that they are still in the same order */
- c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))),
- main = mainx)
+ c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))), main = mainx)
}
}
@@ -436,18 +463,20 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
def execute(state: CircuitState): CircuitState = {
val block = state.annotations.collect {
- case ManipulateNamesBlocklistAnnotation(targetSeq, t) => t.getObject match {
- case _: A => targetSeq
- case _ => Nil
- }
+ case ManipulateNamesBlocklistAnnotation(targetSeq, t) =>
+ t.getObject match {
+ case _: A => targetSeq
+ case _ => Nil
+ }
}.flatten.flatten.toSet
val allow = {
val allowx = state.annotations.collect {
- case ManipulateNamesAllowlistAnnotation(targetSeq, t) => t.getObject match {
- case _: A => targetSeq
- case _ => Nil
- }
+ case ManipulateNamesAllowlistAnnotation(targetSeq, t) =>
+ t.getObject match {
+ case _: A => targetSeq
+ case _ => Nil
+ }
}.flatten.flatten
allowx match {
@@ -461,17 +490,19 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
val annotationsx = state.annotations.flatMap {
/* Consume blocklist annotations */
- case foo@ ManipulateNamesBlocklistAnnotation(_, t) => t.getObject match {
- case _: A => None
- case _ => Some(foo)
- }
+ case foo @ ManipulateNamesBlocklistAnnotation(_, t) =>
+ t.getObject match {
+ case _: A => None
+ case _ => Some(foo)
+ }
/* Convert allowlist annotations to result annotations */
- case foo@ ManipulateNamesAllowlistAnnotation(a, t) =>
+ case foo @ ManipulateNamesAllowlistAnnotation(a, t) =>
t.getObject match {
- case _: A => (a, a.map(_.map(renames(_)).flatten)) match {
- case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a))
- }
- case _ => Some(foo)
+ case _: A =>
+ (a, a.map(_.map(renames(_)).flatten)) match {
+ case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a))
+ }
+ case _ => Some(foo)
}
case a => Some(a)
}
diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
index ff44afec..5532d0f0 100644
--- a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
+++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
@@ -1,4 +1,3 @@
-
package firrtl
package transforms
@@ -34,17 +33,19 @@ trait DontTouchAllTargets extends HasDontTouches { self: Annotation =>
* DCE treats the component as a top-level sink of the circuit
*/
case class DontTouchAnnotation(target: ReferenceTarget)
- extends SingleTargetAnnotation[ReferenceTarget] with DontTouchAllTargets {
+ extends SingleTargetAnnotation[ReferenceTarget]
+ with DontTouchAllTargets {
def targets = Seq(target)
def duplicate(n: ReferenceTarget) = this.copy(n)
}
object DontTouchAnnotation {
- class DontTouchNotFoundException(module: String, component: String) extends PassException(
- s"""|Target marked dontTouch ($module.$component) not found!
- |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then
- |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin
- )
+ class DontTouchNotFoundException(module: String, component: String)
+ extends PassException(
+ s"""|Target marked dontTouch ($module.$component) not found!
+ |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then
+ |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin
+ )
def errorNotFound(module: String, component: String) =
throw new DontTouchNotFoundException(module, component)
@@ -58,7 +59,6 @@ object DontTouchAnnotation {
*
* @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted
*/
-case class OptimizableExtModuleAnnotation(target: ModuleName) extends
- SingleTargetAnnotation[ModuleName] {
+case class OptimizableExtModuleAnnotation(target: ModuleName) extends SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(n)
}
diff --git a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala
index da803837..97db0219 100644
--- a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala
+++ b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala
@@ -11,8 +11,10 @@ import firrtl.options.Dependency
import scala.collection.mutable
object PropagatePresetAnnotations {
- val advice = "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset."
- case class TreeCleanUpOrphanException(message: String) extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice")
+ val advice =
+ "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset."
+ case class TreeCleanUpOrphanException(message: String)
+ extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice")
}
/** Propagate PresetAnnotations to all children of targeted AsyncResets
@@ -39,9 +41,11 @@ object PropagatePresetAnnotations {
class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic])
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -52,7 +56,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
import PropagatePresetAnnotations._
private type TargetSet = mutable.HashSet[ReferenceTarget]
- private type TargetMap = mutable.HashMap[ReferenceTarget,String]
+ private type TargetMap = mutable.HashMap[ReferenceTarget, String]
private type TargetSetMap = mutable.HashMap[ReferenceTarget, TargetSet]
private val toCleanUp = new TargetSet()
@@ -71,7 +75,11 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
* @param presetAnnos all the annotations
* @return updated annotations
*/
- private def propagate(cs: CircuitState, presetAnnos: Seq[PresetAnnotation], otherAnnos: Seq[Annotation]): AnnotationSeq = {
+ private def propagate(
+ cs: CircuitState,
+ presetAnnos: Seq[PresetAnnotation],
+ otherAnnos: Seq[Annotation]
+ ): AnnotationSeq = {
val presets = presetAnnos.groupBy(_.target)
// store all annotated asyncreset references
val asyncToAnnotate = new TargetSet()
@@ -85,34 +93,34 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
val circuitTarget = CircuitTarget(cs.circuit.main)
/*
- * WALK I PHASE 1 FUNCTIONS
- */
+ * WALK I PHASE 1 FUNCTIONS
+ */
/* Walk current module
- * - process ports
- * - store connections & entry points for PHASE 2
- * - process statements
- * - Instances => record local instances for cross module AsyncReset Tree Buidling
- * - Registers => store AsyncReset bound registers for PHASE 2
- * - Wire => store AsyncReset Connections & entry points for PHASE 2
- * - Connect => store AsyncReset Connections & entry points for PHASE 2
- *
- * @param m module
- */
+ * - process ports
+ * - store connections & entry points for PHASE 2
+ * - process statements
+ * - Instances => record local instances for cross module AsyncReset Tree Buidling
+ * - Registers => store AsyncReset bound registers for PHASE 2
+ * - Wire => store AsyncReset Connections & entry points for PHASE 2
+ * - Connect => store AsyncReset Connections & entry points for PHASE 2
+ *
+ * @param m module
+ */
def processModule(m: DefModule): Unit = {
val moduleTarget = circuitTarget.module(m.name)
val localInstances = new TargetMap()
/* Recursively process a given type
- * Recursive on Bundle and Vector Type only
- * Store Register and Connections for AsyncResetType
- * @param tpe [[Type]] to be processed
- * @param target [[ReferenceTarget]] associated to the tpe
- * @param all Boolean indicating whether all subelements of the current
- * tpe should also be stored as Annotated AsyncReset entry points
- */
+ * Recursive on Bundle and Vector Type only
+ * Store Register and Connections for AsyncResetType
+ * @param tpe [[Type]] to be processed
+ * @param target [[ReferenceTarget]] associated to the tpe
+ * @param all Boolean indicating whether all subelements of the current
+ * tpe should also be stored as Annotated AsyncReset entry points
+ */
def processType(tpe: Type, target: ReferenceTarget, all: Boolean): Unit = {
- if(tpe == AsyncResetType){
+ if (tpe == AsyncResetType) {
asyncRegMap(target) = new TargetSet()
asyncCoMap(target) = new TargetSet()
if (presets.contains(target) || all) {
@@ -121,14 +129,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
} else {
tpe match {
case b: BundleType =>
- b.fields.foreach{
- (x: Field) =>
- val tar = target.field(x.name)
- processType(x.tpe, tar, (presets.contains(tar) || all))
+ b.fields.foreach { (x: Field) =>
+ val tar = target.field(x.name)
+ processType(x.tpe, tar, (presets.contains(tar) || all))
}
case v: VectorType =>
- for(i <- 0 until v.size) {
+ for (i <- 0 until v.size) {
val tar = target.index(i)
processType(v.tpe, tar, (presets.contains(tar) || all))
}
@@ -143,19 +150,19 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
/* Recursively search for the ReferenceTarget of a given Expression
- * @param e Targeted Expression
- * @param ta Local ReferenceTarget of the Targeted Expression
- * @return a ReferenceTarget in case of success, a GenericTarget otherwise
- * @throws [[InternalError]] on unexpected recursive path return results
- */
- def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false) : Target = {
+ * @param e Targeted Expression
+ * @param ta Local ReferenceTarget of the Targeted Expression
+ * @return a ReferenceTarget in case of success, a GenericTarget otherwise
+ * @throws [[InternalError]] on unexpected recursive path return results
+ */
+ def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false): Target = {
e match {
case w: WRef => moduleTarget.ref(w.name)
case w: WSubField =>
getRef(w.expr, ta, annoCo) match {
case rt: ReferenceTarget =>
- if(localInstances.contains(rt)){
- val remote_ref = circuitTarget.module(localInstances(rt))
+ if (localInstances.contains(rt)) {
+ val remote_ref = circuitTarget.module(localInstances(rt))
if (annoCo)
asyncCoMap(ta) += rt.field(w.name)
remote_ref.ref(w.name)
@@ -163,7 +170,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
rt.field(w.name)
}
case remote_target => remote_target
- }
+ }
case w: WSubIndex =>
getRef(w.expr, ta, annoCo) match {
case remote_target: ReferenceTarget =>
@@ -179,7 +186,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def processRegister(r: DefRegister): Unit = {
getRef(r.reset, moduleTarget.ref(r.name), false) match {
- case rt : ReferenceTarget =>
+ case rt: ReferenceTarget =>
if (asyncRegMap.contains(rt)) {
asyncRegMap(rt) += moduleTarget.ref(r.name)
}
@@ -189,12 +196,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
def processConnect(c: Connect): Unit = {
- getRef(c.expr, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match {
+ getRef(c.expr, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match {
case rhs: ReferenceTarget =>
if (presets.contains(rhs) || asyncRegMap.contains(rhs)) {
getRef(c.loc, rhs, true) match {
- case lhs : ReferenceTarget =>
- if(asyncRegMap.contains(rhs)){
+ case lhs: ReferenceTarget =>
+ if (asyncRegMap.contains(rhs)) {
asyncRegMap(rhs) += lhs
} else {
asyncToAnnotate += lhs
@@ -211,10 +218,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
val target = moduleTarget.ref(n.name)
processType(n.value.tpe, target, presets.contains(target))
- getRef(n.value, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match {
+ getRef(n.value, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match {
case rhs: ReferenceTarget =>
if (presets.contains(rhs) || asyncRegMap.contains(rhs)) {
- if(asyncRegMap.contains(rhs)){
+ if (asyncRegMap.contains(rhs)) {
asyncRegMap(rhs) += target
} else {
asyncToAnnotate += target
@@ -227,18 +234,18 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def processStatements(statement: Statement): Unit = {
statement match {
- case i : WDefInstance =>
+ case i: WDefInstance =>
localInstances(moduleTarget.ref(i.name)) = i.module
- case r : DefRegister => processRegister(r)
- case w : DefWire => processWire(w)
- case n : DefNode => processNode(n)
- case c : Connect => processConnect(c)
- case s => s.foreachStmt(processStatements)
+ case r: DefRegister => processRegister(r)
+ case w: DefWire => processWire(w)
+ case n: DefNode => processNode(n)
+ case c: Connect => processConnect(c)
+ case s => s.foreachStmt(processStatements)
}
}
def processPorts(port: Port): Unit = {
- if(port.tpe == AsyncResetType){
+ if (port.tpe == AsyncResetType) {
val target = moduleTarget.ref(port.name)
asyncRegMap(target) = new TargetSet()
asyncCoMap(target) = new TargetSet()
@@ -263,17 +270,17 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
/** Annotate a given target and all its children according to the asyncCoMap */
def annotateCo(ta: ReferenceTarget): Unit = {
- if (asyncCoMap.contains(ta)){
+ if (asyncCoMap.contains(ta)) {
toCleanUp += ta
- asyncCoMap(ta) foreach( (t: ReferenceTarget) => {
+ asyncCoMap(ta).foreach((t: ReferenceTarget) => {
toCleanUp += t
})
}
}
/** Annotate all registers somehow connected to the orignal annotated async reset */
- def annotateRegSet(set: TargetSet) : Unit = {
- set foreach ( (ta: ReferenceTarget) => {
+ def annotateRegSet(set: TargetSet): Unit = {
+ set.foreach((ta: ReferenceTarget) => {
annotateCo(ta)
if (asyncRegMap.contains(ta)) {
annotateRegSet(asyncRegMap(ta))
@@ -287,8 +294,8 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
* Walk AsyncReset Trees with all Annotated AsyncReset as entry points
* Annotate all leaf registers and intermediate wires, nodes, connectors along the way
*/
- def annotateAsyncSet(set: TargetSet) : Unit = {
- set foreach ((t: ReferenceTarget) => {
+ def annotateAsyncSet(set: TargetSet): Unit = {
+ set.foreach((t: ReferenceTarget) => {
annotateCo(t)
if (asyncRegMap.contains(t))
annotateRegSet(asyncRegMap(t))
@@ -300,7 +307,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
*/
cs.circuit.foreachModule(processModule) // PHASE 1 : Initialize
- annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate
+ annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate
otherAnnos ++ newAnnos
}
@@ -312,21 +319,21 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
* Clean-up useless reset tree (not relying on DCE)
* Disconnect preset registers from their reset tree
*/
- private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq) : Circuit = {
- val presetRegs = annos.collect {case a : PresetRegAnnotation => a}.groupBy(_.target)
+ private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq): Circuit = {
+ val presetRegs = annos.collect { case a: PresetRegAnnotation => a }.groupBy(_.target)
val circuitTarget = CircuitTarget(circuit.main)
def processModule(m: DefModule): DefModule = {
val moduleTarget = circuitTarget.module(m.name)
val localInstances = new TargetMap()
- def getRef(e: Expression) : Target = {
+ def getRef(e: Expression): Target = {
e match {
case w: WRef => moduleTarget.ref(w.name)
case w: WSubField =>
getRef(w.expr) match {
case rt: ReferenceTarget =>
- if(localInstances.contains(rt)){
+ if (localInstances.contains(rt)) {
circuitTarget.module(localInstances(rt)).ref(w.name)
} else {
rt.field(w.name)
@@ -341,14 +348,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
case DoPrim(op, args, _, _) =>
op match {
case AsInterval | AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset => getRef(args.head)
- case _ => Target(None, None, Seq.empty)
+ case _ => Target(None, None, Seq.empty)
}
case _ => Target(None, None, Seq.empty)
}
}
-
- def processRegister(r: DefRegister) : DefRegister = {
+ def processRegister(r: DefRegister): DefRegister = {
if (presetRegs.contains(moduleTarget.ref(r.name))) {
r.copy(reset = UIntLiteral(0))
} else {
@@ -356,7 +362,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
}
- def processWire(w: DefWire) : Statement = {
+ def processWire(w: DefWire): Statement = {
if (toCleanUp.contains(moduleTarget.ref(w.name))) {
EmptyStmt
} else {
@@ -364,12 +370,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
}
- def processNode(n: DefNode) : Statement = {
+ def processNode(n: DefNode): Statement = {
if (toCleanUp.contains(moduleTarget.ref(n.name))) {
EmptyStmt
} else {
getRef(n.value) match {
- case rt : ReferenceTarget if(toCleanUp.contains(rt)) =>
+ case rt: ReferenceTarget if (toCleanUp.contains(rt)) =>
throw TreeCleanUpOrphanException(s"Orphan (${moduleTarget.ref(n.name)}) the way.")
case _ => n
}
@@ -380,7 +386,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
getRef(c.expr) match {
case rhs: ReferenceTarget if (toCleanUp.contains(rhs)) =>
getRef(c.loc) match {
- case lhs : ReferenceTarget if(!toCleanUp.contains(lhs)) =>
+ case lhs: ReferenceTarget if (!toCleanUp.contains(lhs)) =>
throw TreeCleanUpOrphanException(s"Orphan ${lhs} connected deleted node $rhs.")
case _ => EmptyStmt
}
@@ -388,7 +394,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
}
- def processInstance(i: WDefInstance) : WDefInstance = {
+ def processInstance(i: WDefInstance): WDefInstance = {
localInstances(moduleTarget.ref(i.name)) = i.module
val tpe = i.tpe match {
case b: BundleType =>
@@ -401,12 +407,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def processStatements(statement: Statement): Statement = {
statement match {
- case i : WDefInstance => processInstance(i)
- case r : DefRegister => processRegister(r)
- case w : DefWire => processWire(w)
- case n : DefNode => processNode(n)
- case c : Connect => processConnect(c)
- case s => s.mapStmt(processStatements)
+ case i: WDefInstance => processInstance(i)
+ case r: DefRegister => processRegister(r)
+ case w: DefWire => processWire(w)
+ case n: DefNode => processNode(n)
+ case c: Connect => processConnect(c)
+ case s => s.mapStmt(processStatements)
}
}
@@ -422,10 +428,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def execute(state: CircuitState): CircuitState = {
// Collect all user-defined PresetAnnotation
- val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true ; case _ => false }
+ val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true; case _ => false }
// No PresetAnnotation => no need to walk the IR
- if (presets.isEmpty){
+ if (presets.isEmpty) {
state
} else {
// PHASE I - Propagate
diff --git a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala
index 840a3d99..ae3bc693 100644
--- a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala
+++ b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala
@@ -21,10 +21,11 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames {
* @return Some name if a rename occurred, None otherwise
* @note prefix uniqueness is not respected
*/
- override def manipulate = (n: String, ns: Namespace) => keywords.contains(n) match {
- case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords))
- case false => None
- }
+ override def manipulate = (n: String, ns: Namespace) =>
+ keywords.contains(n) match {
+ case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords))
+ case false => None
+ }
}
@@ -32,14 +33,16 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames {
class VerilogRename extends RemoveKeywordCollisions(v_keywords) {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[InlineCastsTransform],
- Dependency[LegalizeClocksTransform],
- Dependency[FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup) )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[InlineCastsTransform],
+ Dependency[LegalizeClocksTransform],
+ Dependency[FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup)
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index 6b3a9d07..8736e21b 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -18,8 +18,7 @@ import scala.collection.{immutable, mutable}
object RemoveReset extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize) )
+ Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize))
override def optionalPrerequisites = Seq.empty
@@ -58,7 +57,7 @@ object RemoveReset extends Transform with DependencyAPIMigration {
reg.copy(reset = Utils.zero, init = WRef(reg))
case reg @ DefRegister(_, rname, _, _, Utils.zero, _) =>
reg.copy(init = WRef(reg)) // canonicalize
- case reg @ DefRegister(info , rname, _, _, reset, init) if reset.tpe != AsyncResetType =>
+ case reg @ DefRegister(info, rname, _, _, reset, init) if reset.tpe != AsyncResetType =>
// Add register reset to map
resets(rname) = Reset(reset, init, info)
reg.copy(reset = Utils.zero, init = WRef(reg))
@@ -68,7 +67,7 @@ object RemoveReset extends Transform with DependencyAPIMigration {
// Use reg source locator for mux enable and true value since that's where they're defined
val infox = MultiInfo(reset.info, reset.info, info)
Connect(infox, ref, Mux(reset.cond, reset.value, expr, muxType))
- case other => other map onStmt
+ case other => other.map(onStmt)
}
}
m.map(onStmt)
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index f692e513..31fa3b6f 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -8,11 +8,11 @@ import firrtl.Utils._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._
-import firrtl.graph.{MutableDiGraph, CyclicException}
+import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
import scala.collection.mutable
-import scala.util.{Try, Success, Failure}
+import scala.util.{Failure, Success, Try}
/** Replace wires with nodes in a legal, flow-forward order
*
@@ -23,11 +23,13 @@ import scala.util.{Try, Success, Failure}
class RemoveWires extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
- Dependency(passes.ResolveKinds),
- Dependency(transforms.RemoveReset),
- Dependency[transforms.CheckCombLoops] )
+ Seq(
+ Dependency(passes.LowerTypes),
+ Dependency(passes.Legalize),
+ Dependency(passes.ResolveKinds),
+ Dependency(transforms.RemoveReset),
+ Dependency[transforms.CheckCombLoops]
+ )
override def optionalPrerequisites = Seq(Dependency[checks.CheckResets])
@@ -35,7 +37,7 @@ class RemoveWires extends Transform with DependencyAPIMigration {
override def invalidates(a: Transform) = a match {
case passes.ResolveKinds => true
- case _ => false
+ case _ => false
}
// Extract all expressions that are references to a Node, Wire, or Reg
@@ -44,7 +46,7 @@ class RemoveWires extends Transform with DependencyAPIMigration {
val refs = mutable.ArrayBuffer.empty[WRef]
def rec(e: Expression): Expression = {
e match {
- case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref
+ case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref
case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec)
case _ => // Do nothing
}
@@ -57,7 +59,8 @@ class RemoveWires extends Transform with DependencyAPIMigration {
// Transform netlist into DefNodes
private def getOrderedNodes(
netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)],
- regInfo: mutable.Map[WrappedExpression, DefRegister]): Try[Seq[Statement]] = {
+ regInfo: mutable.Map[WrappedExpression, DefRegister]
+ ): Try[Seq[Statement]] = {
val digraph = new MutableDiGraph[WrappedExpression]
for ((sink, (exprs, _)) <- netlist) {
digraph.addVertex(sink)
@@ -106,21 +109,22 @@ class RemoveWires extends Transform with DependencyAPIMigration {
case reg: DefRegister =>
val resetDep = reg.reset.tpe match {
case AsyncResetType => Some(reg.reset)
- case _ => None
+ case _ => None
}
val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself
regInfo(we(WRef(reg))) = reg
netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info)
case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires
decls += decl
- case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match {
- case WireKind =>
- // Be sure to pad the rhs since nodes get their type from the rhs
- val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
- val dinfo = wireInfo(lhs)
- netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
- case _ => otherStmts += con // Other connections just pass through
- }
+ case con @ Connect(cinfo, lhs, rhs) =>
+ kind(lhs) match {
+ case WireKind =>
+ // Be sure to pad the rhs since nodes get their type from the rhs
+ val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
+ val dinfo = wireInfo(lhs)
+ netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
+ case _ => otherStmts += con // Other connections just pass through
+ }
case invalid @ IsInvalid(info, expr) =>
kind(expr) match {
case WireKind =>
@@ -146,8 +150,10 @@ class RemoveWires extends Transform with DependencyAPIMigration {
// If we hit a CyclicException, just abort removing wires
case Failure(c: CyclicException) =>
val problematicNode = c.node
- logger.warn(s"Cycle found in module $name, " +
- s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode")
+ logger.warn(
+ s"Cycle found in module $name, " +
+ s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode"
+ )
mod
case Failure(other) => throw other
}
@@ -155,7 +161,6 @@ class RemoveWires extends Transform with DependencyAPIMigration {
}
}
-
def execute(state: CircuitState): CircuitState =
state.copy(circuit = state.circuit.map(onModule))
}
diff --git a/src/main/scala/firrtl/transforms/RenameModules.scala b/src/main/scala/firrtl/transforms/RenameModules.scala
index d37f8c39..16fd655a 100644
--- a/src/main/scala/firrtl/transforms/RenameModules.scala
+++ b/src/main/scala/firrtl/transforms/RenameModules.scala
@@ -44,7 +44,7 @@ class RenameModules extends Transform with DependencyAPIMigration {
moduleOrder.foreach(collectNameMapping(namespace.get, nameMappings))
val modulesx = state.circuit.modules.map {
- case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings)
+ case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings)
case ext: ExtModule => ext
}
diff --git a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala
index a93087b9..14c84b91 100644
--- a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala
+++ b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala
@@ -80,8 +80,7 @@ object ReplaceTruncatingArithmetic {
class ReplaceTruncatingArithmetic extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals] )
+ Seq(Dependency[BlackBoxSourceHelper], Dependency[FixAddingNegativeLiterals])
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala
index a056c7da..7790d060 100644
--- a/src/main/scala/firrtl/transforms/SimplifyMems.scala
+++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala
@@ -33,12 +33,13 @@ class SimplifyMems extends Transform with DependencyAPIMigration {
def onExpr(e: Expression): Expression = e.map(onExpr) match {
case wr @ WRef(name, _, MemKind, _) if memAdapters.contains(name) => wr.copy(kind = WireKind)
- case e => e
+ case e => e
}
def simplifyMem(mem: DefMemory): Statement = {
val adapterDecl = DefWire(mem.info, mem.name, memType(mem))
- val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType))
+ val simpleMemDecl =
+ mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType))
val oldRT = mTarget.ref(mem.name)
val adapterConnects = memType(simpleMemDecl).fields.flatMap {
case Field(pName, Flip, pType: BundleType) =>
@@ -63,8 +64,10 @@ class SimplifyMems extends Transform with DependencyAPIMigration {
def canSimplify(mem: DefMemory) = mem.dataType match {
case at: AggregateType =>
- val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask")))
- val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask")))
+ val wMasks =
+ mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask")))
+ val rwMasks =
+ mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask")))
(wMasks ++ rwMasks).flatten.isEmpty
case _ => false
}
diff --git a/src/main/scala/firrtl/transforms/TopWiring.scala b/src/main/scala/firrtl/transforms/TopWiring.scala
index f5a5e2a3..b35fed22 100644
--- a/src/main/scala/firrtl/transforms/TopWiring.scala
+++ b/src/main/scala/firrtl/transforms/TopWiring.scala
@@ -4,7 +4,7 @@ package TopWiring
import firrtl._
import firrtl.ir._
-import firrtl.passes.{InferTypes, LowerTypes, ResolveKinds, ResolveFlows, ExpandConnects}
+import firrtl.passes.{ExpandConnects, InferTypes, LowerTypes, ResolveFlows, ResolveKinds}
import firrtl.annotations._
import firrtl.Mappers._
import firrtl.analyses.InstanceKeyGraph
@@ -13,22 +13,21 @@ import firrtl.options.Dependency
import collection.mutable
-/** Annotation for optional output files, and what directory to put those files in (absolute path) **/
-case class TopWiringOutputFilesAnnotation(dirName: String,
- outputFunction: (String,Seq[((ComponentName, Type, Boolean,
- Seq[String],String), Int)],
- CircuitState) => CircuitState) extends NoTargetAnnotation
+/** Annotation for optional output files, and what directory to put those files in (absolute path) * */
+case class TopWiringOutputFilesAnnotation(
+ dirName: String,
+ outputFunction: (String, Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)],
+ CircuitState) => CircuitState)
+ extends NoTargetAnnotation
/** Annotation for indicating component to be wired, and what prefix to add to the ports that are generated */
-case class TopWiringAnnotation(target: ComponentName, prefix: String) extends
- SingleTargetAnnotation[ComponentName] {
+case class TopWiringAnnotation(target: ComponentName, prefix: String) extends SingleTargetAnnotation[ComponentName] {
def duplicate(n: ComponentName) = this.copy(target = n)
}
-
/** Punch out annotated ports out to the toplevel of the circuit.
- This also has an option to pass a function as a parmeter to generate
- custom output files as a result of the additional ports
+ * This also has an option to pass a function as a parmeter to generate
+ * custom output files as a result of the additional ports
* @note This *does* work for deduped modules
*/
class TopWiringTransform extends Transform with DependencyAPIMigration {
@@ -39,116 +38,133 @@ class TopWiringTransform extends Transform with DependencyAPIMigration {
override def invalidates(a: Transform): Boolean = a match {
case InferTypes | ResolveKinds | ResolveFlows | ExpandConnects => true
- case _ => false
+ case _ => false
}
type InstPath = Seq[String]
/** Get the names of the targets that need to be wired */
private def getSourceNames(state: CircuitState): Map[ComponentName, String] = {
- state.annotations.collect { case TopWiringAnnotation(srcname,prefix) =>
- (srcname -> prefix) }.toMap.withDefaultValue("")
+ state.annotations.collect {
+ case TopWiringAnnotation(srcname, prefix) =>
+ (srcname -> prefix)
+ }.toMap.withDefaultValue("")
}
-
/** Get the names of the modules which include the targets that need to be wired */
private def getSourceModNames(state: CircuitState): Seq[String] = {
- state.annotations.collect { case TopWiringAnnotation(ComponentName(_,ModuleName(srcmodname, _)),_) => srcmodname }
+ state.annotations.collect { case TopWiringAnnotation(ComponentName(_, ModuleName(srcmodname, _)), _) => srcmodname }
}
-
-
/** Get the Type of each wire to be connected
*
* Find the definition of each wire in sourceList, and get the type and whether or not it's a port
* Update the results in sourceMap
*/
- private def getSourceTypes(sourceList: Map[ComponentName, String],
- sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
- currentmodule: ModuleName, state: CircuitState)(s: Statement): Statement = s match {
+ private def getSourceTypes(
+ sourceList: Map[ComponentName, String],
+ sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
+ currentmodule: ModuleName,
+ state: CircuitState
+ )(s: Statement
+ ): Statement = s match {
// If target wire, add name and size to to sourceMap
case w: IsDeclaration =>
if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) {
- val (isport, tpe, prefix) = w match {
- case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
- }
- sourceMap.get(currentmodule.name) match {
- case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
- sourceMap.update(currentmodule.name, xs :+(
- (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) ))
- case None =>
- sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule),
- tpe, isport ,Seq[String](w.name), prefix))
- }
+ val (isport, tpe, prefix) = w match {
+ case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
+ }
+ sourceMap.get(currentmodule.name) match {
+ case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
+ sourceMap.update(
+ currentmodule.name,
+ xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix))
+ )
+ case None =>
+ sourceMap(currentmodule.name) = Seq(
+ (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)
+ )
+ }
}
w // Return argument unchanged (ok because DefWire has no Statement children)
// If not, apply to all children Statement
- case _ => s map getSourceTypes(sourceList, sourceMap, currentmodule, state)
+ case _ => s.map(getSourceTypes(sourceList, sourceMap, currentmodule, state))
}
-
-
/** Get the Type of each port to be connected
*
* Similar to getSourceTypes, but specifically for ports since they are not found in statements.
* Find the definition of each port in sourceList, and get the type and whether or not it's a port
* Update the results in sourceMap
*/
- private def getSourceTypesPorts(sourceList: Map[ComponentName, String], sourceMap: mutable.Map[String,
- Seq[(ComponentName, Type, Boolean, InstPath, String)]],
- currentmodule: ModuleName, state: CircuitState)(s: Port): CircuitState = s match {
+ private def getSourceTypesPorts(
+ sourceList: Map[ComponentName, String],
+ sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
+ currentmodule: ModuleName,
+ state: CircuitState
+ )(s: Port
+ ): CircuitState = s match {
// If target port, add name and size to to sourceMap
case w: IsDeclaration =>
if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) {
- val (isport, tpe, prefix) = w match {
- case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
- }
- sourceMap.get(currentmodule.name) match {
- case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
- sourceMap.update(currentmodule.name, xs :+(
- (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) ))
- case None =>
- sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule),
- tpe, isport ,Seq[String](w.name), prefix))
- }
+ val (isport, tpe, prefix) = w match {
+ case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
+ }
+ sourceMap.get(currentmodule.name) match {
+ case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
+ sourceMap.update(
+ currentmodule.name,
+ xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix))
+ )
+ case None =>
+ sourceMap(currentmodule.name) = Seq(
+ (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)
+ )
+ }
}
state // Return argument unchanged (ok because DefWire has no Statement children)
// If not, apply to all children Statement
case _ => state
}
-
/** Create a map of Module name to target wires under this module
*
* These paths are relative but cross module (they refer down through instance hierarchy)
*/
- private def getSourcesMap(state: CircuitState): Map[String,Seq[(ComponentName, Type, Boolean, InstPath, String)]] = {
+ private def getSourcesMap(state: CircuitState): Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] = {
val sSourcesModNames = getSourceModNames(state)
val sSourcesNames = getSourceNames(state)
val instGraph = firrtl.analyses.InstanceKeyGraph(state.circuit)
- val cMap = instGraph.getChildInstances.map{ case (m, wdis) =>
- (m -> wdis.map{ case wdi => (wdi.name, wdi.module) }.toSeq) }.toMap
+ val cMap = instGraph.getChildInstances.map {
+ case (m, wdis) =>
+ (m -> wdis.map { case wdi => (wdi.name, wdi.module) }.toSeq)
+ }.toMap
val topSort = instGraph.moduleOrder.reverse
// Map of component name to relative instance paths that result in a debug wire
val sourcemods: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] =
mutable.Map(sSourcesModNames.map(_ -> Seq()): _*)
- state.circuit.modules.foreach { m => m map
- getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) }
- state.circuit.modules.foreach { m => m.ports.foreach {
- p => Seq(p) map
- getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) }}
+ state.circuit.modules.foreach { m =>
+ m.map(getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state))
+ }
+ state.circuit.modules.foreach { m =>
+ m.ports.foreach { p =>
+ Seq(p).map(
+ getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state)
+ )
+ }
+ }
for (mod <- topSort) {
- val seqChildren: Seq[(ComponentName,Type,Boolean,InstPath,String)] = cMap(mod.name).flatMap {
+ val seqChildren: Seq[(ComponentName, Type, Boolean, InstPath, String)] = cMap(mod.name).flatMap {
case (inst, module) =>
- sourcemods.get(module).map( _.map { case (a,b,c,path,p) => (a,b,c, inst +: path, p)})
+ sourcemods.get(module).map(_.map { case (a, b, c, path, p) => (a, b, c, inst +: path, p) })
}.flatten
if (seqChildren.nonEmpty) {
sourcemods(mod.name) = sourcemods.getOrElse(mod.name, Seq()) ++ seqChildren
@@ -158,108 +174,113 @@ class TopWiringTransform extends Transform with DependencyAPIMigration {
sourcemods.toMap
}
-
-
/** Process a given DefModule
*
* For Modules that contain or are in the parent hierarchy to modules containing target wires
* 1. Add ports for each target wire this module is parent to
* 2. Connect these ports to ports of instances that are parents to some number of target wires
*/
- private def onModule(sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
- portnamesmap : mutable.Map[String,String],
- instgraph : firrtl.analyses.InstanceKeyGraph,
- namespacemap : Map[String, Namespace])
- (module: DefModule): DefModule = {
+ private def onModule(
+ sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
+ portnamesmap: mutable.Map[String, String],
+ instgraph: firrtl.analyses.InstanceKeyGraph,
+ namespacemap: Map[String, Namespace]
+ )(module: DefModule
+ ): DefModule = {
val namespace = namespacemap(module.name)
sources.get(module.name) match {
case Some(p) =>
- val newPorts = p.map{ case (ComponentName(cname,_), tpe, _ , path, prefix) => {
- val newportname = portnamesmap.get(prefix + path.mkString("_")) match {
- case Some(pn) => pn
- case None => {
- val npn = namespace.newName(prefix + path.mkString("_"))
- portnamesmap(prefix + path.mkString("_")) = npn
- npn
- }
+ val newPorts = p.map {
+ case (ComponentName(cname, _), tpe, _, path, prefix) => {
+ val newportname = portnamesmap.get(prefix + path.mkString("_")) match {
+ case Some(pn) => pn
+ case None => {
+ val npn = namespace.newName(prefix + path.mkString("_"))
+ portnamesmap(prefix + path.mkString("_")) = npn
+ npn
}
- Port(NoInfo, newportname, Output, tpe)
- } }
+ }
+ Port(NoInfo, newportname, Output, tpe)
+ }
+ }
// Add connections to Module
val childInstances = instgraph.getChildInstances.toMap
module match {
case m: Module =>
- val connections: Seq[Connect] = p.map { case (ComponentName(cname,_), _, _ , path, prefix) =>
+ val connections: Seq[Connect] = p.map {
+ case (ComponentName(cname, _), _, _, path, prefix) =>
val modRef = portnamesmap.get(prefix + path.mkString("_")) match {
- case Some(pn) => WRef(pn)
- case None => {
- portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_"))
- WRef(portnamesmap(prefix + path.mkString("_")))
- }
+ case Some(pn) => WRef(pn)
+ case None => {
+ portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_"))
+ WRef(portnamesmap(prefix + path.mkString("_")))
+ }
}
path.size match {
- case 1 => {
- val leafRef = WRef(path.head.mkString(""))
- Connect(NoInfo, modRef, leafRef)
- }
- case _ => {
- val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match {
- case Some(ipn) => ipn
- case None => {
- val instmod = childInstances(module.name).collectFirst {
- case wdi if wdi.name == path.head => wdi.module}.get
- val instnamespace = namespacemap(instmod)
- portnamesmap(prefix + path.tail.mkString("_")) =
- instnamespace.newName(prefix + path.tail.mkString("_"))
- portnamesmap(prefix + path.tail.mkString("_"))
- }
- }
- val instRef = WSubField(WRef(path.head), instportname)
- Connect(NoInfo, modRef, instRef)
+ case 1 => {
+ val leafRef = WRef(path.head.mkString(""))
+ Connect(NoInfo, modRef, leafRef)
+ }
+ case _ => {
+ val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match {
+ case Some(ipn) => ipn
+ case None => {
+ val instmod = childInstances(module.name).collectFirst {
+ case wdi if wdi.name == path.head => wdi.module
+ }.get
+ val instnamespace = namespacemap(instmod)
+ portnamesmap(prefix + path.tail.mkString("_")) =
+ instnamespace.newName(prefix + path.tail.mkString("_"))
+ portnamesmap(prefix + path.tail.mkString("_"))
+ }
+ }
+ val instRef = WSubField(WRef(path.head), instportname)
+ Connect(NoInfo, modRef, instRef)
}
}
}
- m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections ))
+ m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections))
case e: ExtModule =>
e.copy(ports = e.ports ++ newPorts)
- }
+ }
case None => module // unchanged if no paths
}
}
- /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */
- def topWiringDummyOutputFilesFunction(dir: String,
- mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)],
- state: CircuitState): CircuitState = {
- state
+ /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */
+ def topWiringDummyOutputFilesFunction(
+ dir: String,
+ mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)],
+ state: CircuitState
+ ): CircuitState = {
+ state
}
-
def execute(state: CircuitState): CircuitState = {
- val outputTuples: Seq[(String,
- (String,Seq[((ComponentName, Type, Boolean, InstPath, String), Int)],
- CircuitState) => CircuitState)] = state.annotations.collect {
- case TopWiringOutputFilesAnnotation(td,of) => (td, of) }
+ val outputTuples: Seq[
+ (String, (String, Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], CircuitState) => CircuitState)
+ ] = state.annotations.collect {
+ case TopWiringOutputFilesAnnotation(td, of) => (td, of)
+ }
// Do actual work of this transform
val sources = getSourcesMap(state)
val (nstate, nmappings) = if (sources.nonEmpty) {
- val portnamesmap: mutable.Map[String,String] = mutable.Map()
+ val portnamesmap: mutable.Map[String, String] = mutable.Map()
val instgraph = InstanceKeyGraph(state.circuit)
- val namespacemap = state.circuit.modules.map{ case m => (m.name -> Namespace(m)) }.toMap
- val modulesx = state.circuit.modules map onModule(sources, portnamesmap, instgraph, namespacemap)
+ val namespacemap = state.circuit.modules.map { case m => (m.name -> Namespace(m)) }.toMap
+ val modulesx = state.circuit.modules.map(onModule(sources, portnamesmap, instgraph, namespacemap))
val newCircuit = state.circuit.copy(modules = modulesx)
val mappings = sources(state.circuit.main).zipWithIndex
val annosx = state.annotations.filter {
case _: TopWiringAnnotation => false
- case _ => true
+ case _ => true
}
(state.copy(circuit = newCircuit, annotations = annosx), mappings)
- }
- else { (state, List.empty) }
+ } else { (state, List.empty) }
//Generate output files based on the mapping.
outputTuples.map { case (dir, outputfunction) => outputfunction(dir, nmappings, nstate) }
nstate
diff --git a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala
index 7370fcfb..cdbee495 100644
--- a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala
+++ b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala
@@ -1,4 +1,3 @@
-
package firrtl.transforms.formal
import firrtl.ir.{Circuit, Formal, Statement, Verification}
@@ -7,7 +6,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, Transform}
import firrtl.annotations.NoTargetAnnotation
import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption}
-
/**
* Assert Submodule Assumptions
*
@@ -16,12 +14,13 @@ import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption}
* overly restrictive assume in a child module can prevent the model checker
* from searching valid inputs and states in the parent module.
*/
-class AssertSubmoduleAssumptions extends Transform
- with RegisteredTransform
- with DependencyAPIMigration
- with PreservesAll[Transform] {
+class AssertSubmoduleAssumptions
+ extends Transform
+ with RegisteredTransform
+ with DependencyAPIMigration
+ with PreservesAll[Transform] {
- override def prerequisites: Seq[TransformDependency] = Seq.empty
+ override def prerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisiteOf: Seq[TransformDependency] =
firrtl.stage.Forms.MidEmitters
@@ -29,9 +28,10 @@ class AssertSubmoduleAssumptions extends Transform
val options = Seq(
new ShellOption[Unit](
longOption = "no-asa",
- toAnnotationSeq = (_: Unit) => Seq(
- DontAssertSubmoduleAssumptionsAnnotation),
- helpText = "Disable assert submodule assumptions" ) )
+ toAnnotationSeq = (_: Unit) => Seq(DontAssertSubmoduleAssumptionsAnnotation),
+ helpText = "Disable assert submodule assumptions"
+ )
+ )
def assertAssumption(s: Statement): Statement = s match {
case Verification(Formal.Assume, info, clk, cond, en, msg) =>
@@ -50,8 +50,7 @@ class AssertSubmoduleAssumptions extends Transform
}
def execute(state: CircuitState): CircuitState = {
- val noASA = state.annotations.contains(
- DontAssertSubmoduleAssumptionsAnnotation)
+ val noASA = state.annotations.contains(DontAssertSubmoduleAssumptionsAnnotation)
if (noASA) {
logger.info("Skipping assert submodule assumptions")
state
diff --git a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala
index ddead331..5928c79c 100644
--- a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala
+++ b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala
@@ -14,10 +14,8 @@ import firrtl.options.Dependency
object ConvertAsserts extends Transform with DependencyAPIMigration {
override def prerequisites = Nil
override def optionalPrerequisites = Nil
- override def optionalPrerequisiteOf = Seq(
- Dependency[VerilogEmitter],
- Dependency[MinimumVerilogEmitter],
- Dependency[RemoveVerificationStatements])
+ override def optionalPrerequisiteOf =
+ Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter], Dependency[RemoveVerificationStatements])
override def invalidates(a: Transform): Boolean = false
@@ -28,7 +26,7 @@ object ConvertAsserts extends Transform with DependencyAPIMigration {
val stop = Stop(i, 1, clk, gatedNPred)
msg match {
case StringLit("") => stop
- case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop)
+ case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop)
}
case s => s.mapStmt(convertAsserts)
}
diff --git a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala
index 72890c07..1e6d2c72 100644
--- a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala
+++ b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala
@@ -1,4 +1,3 @@
-
package firrtl.transforms.formal
import firrtl.ir.{Circuit, EmptyStmt, Statement, Verification}
@@ -6,7 +5,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, MinimumVerilogEmitter, Tran
import firrtl.options.{Dependency, PreservesAll, StageUtils}
import firrtl.stage.TransformManager.TransformDependency
-
/**
* Remove Verification Statements
*
@@ -14,15 +12,12 @@ import firrtl.stage.TransformManager.TransformDependency
* This is intended to be required by the Verilog emitter to ensure compatibility
* with the Verilog 2001 standard.
*/
-class RemoveVerificationStatements extends Transform
- with DependencyAPIMigration
- with PreservesAll[Transform] {
+class RemoveVerificationStatements extends Transform with DependencyAPIMigration with PreservesAll[Transform] {
- override def prerequisites: Seq[TransformDependency] = Seq.empty
+ override def prerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency(ConvertAsserts))
override def optionalPrerequisiteOf: Seq[TransformDependency] =
- Seq( Dependency[VerilogEmitter],
- Dependency[MinimumVerilogEmitter])
+ Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter])
private var removedCounter = 0
@@ -43,11 +38,13 @@ class RemoveVerificationStatements extends Transform
def execute(state: CircuitState): CircuitState = {
val newState = state.copy(circuit = run(state.circuit))
if (removedCounter > 0) {
- StageUtils.dramaticWarning(s"$removedCounter verification statements " +
- "(assert, assume or cover) " +
- "were removed when compiling to Verilog because the basic Verilog " +
- "standard does not support them. If this was not intended, compile " +
- "to System Verilog instead using the `-X sverilog` compiler flag.")
+ StageUtils.dramaticWarning(
+ s"$removedCounter verification statements " +
+ "(assert, assume or cover) " +
+ "were removed when compiling to Verilog because the basic Verilog " +
+ "standard does not support them. If this was not intended, compile " +
+ "to System Verilog instead using the `-X sverilog` compiler flag."
+ )
}
newState
}
diff --git a/src/main/scala/firrtl/traversals/Foreachers.scala b/src/main/scala/firrtl/traversals/Foreachers.scala
index fdb02399..dee74d63 100644
--- a/src/main/scala/firrtl/traversals/Foreachers.scala
+++ b/src/main/scala/firrtl/traversals/Foreachers.scala
@@ -15,19 +15,19 @@ object Foreachers {
}
private object StmtForMagnet {
implicit def forStmt(f: Statement => Unit): StmtForMagnet = new StmtForMagnet {
- def foreach(stmt: Statement): Unit = stmt foreachStmt f
+ def foreach(stmt: Statement): Unit = stmt.foreachStmt(f)
}
implicit def forExp(f: Expression => Unit): StmtForMagnet = new StmtForMagnet {
- def foreach(stmt: Statement): Unit = stmt foreachExpr f
+ def foreach(stmt: Statement): Unit = stmt.foreachExpr(f)
}
implicit def forType(f: Type => Unit): StmtForMagnet = new StmtForMagnet {
- def foreach(stmt: Statement) : Unit = stmt foreachType f
+ def foreach(stmt: Statement): Unit = stmt.foreachType(f)
}
implicit def forString(f: String => Unit): StmtForMagnet = new StmtForMagnet {
- def foreach(stmt: Statement): Unit = stmt foreachString f
+ def foreach(stmt: Statement): Unit = stmt.foreachString(f)
}
implicit def forInfo(f: Info => Unit): StmtForMagnet = new StmtForMagnet {
- def foreach(stmt: Statement): Unit = stmt foreachInfo f
+ def foreach(stmt: Statement): Unit = stmt.foreachInfo(f)
}
}
implicit class StmtForeach(val _stmt: Statement) extends AnyVal {
@@ -41,13 +41,13 @@ object Foreachers {
}
private object ExprForMagnet {
implicit def forExpr(f: Expression => Unit): ExprForMagnet = new ExprForMagnet {
- def foreach(expr: Expression): Unit = expr foreachExpr f
+ def foreach(expr: Expression): Unit = expr.foreachExpr(f)
}
implicit def forType(f: Type => Unit): ExprForMagnet = new ExprForMagnet {
- def foreach(expr: Expression): Unit = expr foreachType f
+ def foreach(expr: Expression): Unit = expr.foreachType(f)
}
implicit def forWidth(f: Width => Unit): ExprForMagnet = new ExprForMagnet {
- def foreach(expr: Expression): Unit = expr foreachWidth f
+ def foreach(expr: Expression): Unit = expr.foreachWidth(f)
}
}
implicit class ExprForeach(val _expr: Expression) extends AnyVal {
@@ -60,10 +60,10 @@ object Foreachers {
}
private object TypeForMagnet {
implicit def forType(f: Type => Unit): TypeForMagnet = new TypeForMagnet {
- def foreach(tpe: Type): Unit = tpe foreachType f
+ def foreach(tpe: Type): Unit = tpe.foreachType(f)
}
implicit def forWidth(f: Width => Unit): TypeForMagnet = new TypeForMagnet {
- def foreach(tpe: Type): Unit = tpe foreachWidth f
+ def foreach(tpe: Type): Unit = tpe.foreachWidth(f)
}
}
implicit class TypeForeach(val _tpe: Type) extends AnyVal {
@@ -76,16 +76,16 @@ object Foreachers {
}
private object ModuleForMagnet {
implicit def forStmt(f: Statement => Unit): ModuleForMagnet = new ModuleForMagnet {
- def foreach(module: DefModule): Unit = module foreachStmt f
+ def foreach(module: DefModule): Unit = module.foreachStmt(f)
}
implicit def forPorts(f: Port => Unit): ModuleForMagnet = new ModuleForMagnet {
- def foreach(module: DefModule): Unit = module foreachPort f
+ def foreach(module: DefModule): Unit = module.foreachPort(f)
}
implicit def forString(f: String => Unit): ModuleForMagnet = new ModuleForMagnet {
- def foreach(module: DefModule): Unit = module foreachString f
+ def foreach(module: DefModule): Unit = module.foreachString(f)
}
implicit def forInfo(f: Info => Unit): ModuleForMagnet = new ModuleForMagnet {
- def foreach(module: DefModule): Unit = module foreachInfo f
+ def foreach(module: DefModule): Unit = module.foreachInfo(f)
}
}
implicit class ModuleForeach(val _module: DefModule) extends AnyVal {
@@ -98,13 +98,13 @@ object Foreachers {
}
private object CircuitForMagnet {
implicit def forModules(f: DefModule => Unit): CircuitForMagnet = new CircuitForMagnet {
- def foreach(circuit: Circuit): Unit = circuit foreachModule f
+ def foreach(circuit: Circuit): Unit = circuit.foreachModule(f)
}
implicit def forString(f: String => Unit): CircuitForMagnet = new CircuitForMagnet {
- def foreach(circuit: Circuit): Unit = circuit foreachString f
+ def foreach(circuit: Circuit): Unit = circuit.foreachString(f)
}
implicit def forInfo(f: Info => Unit): CircuitForMagnet = new CircuitForMagnet {
- def foreach(circuit: Circuit): Unit = circuit foreachInfo f
+ def foreach(circuit: Circuit): Unit = circuit.foreachInfo(f)
}
}
implicit class CircuitForeach(val _circuit: Circuit) extends AnyVal {
diff --git a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
index 1557bb0c..2ac5b035 100644
--- a/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
+++ b/src/main/scala/firrtl/util/BackendCompilationUtilities.scala
@@ -14,6 +14,7 @@ import firrtl.FileUtils
import scala.sys.process.{ProcessBuilder, ProcessLogger, _}
object BackendCompilationUtilities extends LazyLogging {
+
/** Parent directory for tests */
lazy val TestDirectory = new File("test_run_dir")
@@ -69,12 +70,7 @@ object BackendCompilationUtilities extends LazyLogging {
* @return true if compiler completed successfully
*/
def firrtlToVerilog(prefix: String, dir: File): ProcessBuilder = {
- Process(
- Seq("firrtl",
- "-i", s"$prefix.fir",
- "-o", s"$prefix.v",
- "-X", "verilog"),
- dir)
+ Process(Seq("firrtl", "-i", s"$prefix.fir", "-o", s"$prefix.v", "-X", "verilog"), dir)
}
/** Generates a Verilator invocation to convert Verilog sources to C++
@@ -103,11 +99,11 @@ object BackendCompilationUtilities extends LazyLogging {
* @param extraCmdLineArgs list of additional command line arguments
*/
def verilogToCpp(
- dutFile: String,
- dir: File,
- vSources: Seq[File],
- cppHarness: File,
- suppressVcd: Boolean = false,
+ dutFile: String,
+ dir: File,
+ vSources: Seq[File],
+ cppHarness: File,
+ suppressVcd: Boolean = false,
resourceFileName: String = firrtl.transforms.BlackBoxSourceHelper.defaultFileListName,
extraCmdLineArgs: Seq[String] = Seq.empty
): ProcessBuilder = {
@@ -116,10 +112,9 @@ object BackendCompilationUtilities extends LazyLogging {
val list_file = new File(dir, resourceFileName)
val blackBoxVerilogList = {
- if(list_file.exists()) {
+ if (list_file.exists()) {
Seq("-f", list_file.getAbsolutePath)
- }
- else {
+ } else {
Seq.empty[String]
}
}
@@ -128,37 +123,39 @@ object BackendCompilationUtilities extends LazyLogging {
// If it's in the main .f resource file, don't explicitly include it on the command line.
// Build a set of canonical file paths to use as a filter to exclude already included additional Verilog sources.
val blackBoxHelperFiles: Set[String] = {
- if(list_file.exists()) {
+ if (list_file.exists()) {
FileUtils.getLines(list_file).toSet
- }
- else {
+ } else {
Set.empty
}
}
val vSourcesFiltered = vSources.filterNot(f => blackBoxHelperFiles.contains(f.getCanonicalPath))
val command = Seq(
"verilator",
- "--cc", s"${dir.getAbsolutePath}/$dutFile.v"
+ "--cc",
+ s"${dir.getAbsolutePath}/$dutFile.v"
) ++
extraCmdLineArgs ++
blackBoxVerilogList ++
vSourcesFiltered.flatMap(file => Seq("-v", file.getCanonicalPath)) ++
- Seq("--assert",
- "-Wno-fatal",
- "-Wno-WIDTH",
- "-Wno-STMTDLY"
- ) ++
- { if(suppressVcd) { Seq.empty } else { Seq("--trace")} } ++
+ Seq("--assert", "-Wno-fatal", "-Wno-WIDTH", "-Wno-STMTDLY") ++ {
+ if (suppressVcd) { Seq.empty }
+ else { Seq("--trace") }
+ } ++
Seq(
"-O1",
- "--top-module", topModule,
+ "--top-module",
+ topModule,
"+define+TOP_TYPE=V" + dutFile,
s"+define+PRINTF_COND=!$topModule.reset",
s"+define+STOP_COND=!$topModule.reset",
"-CFLAGS",
s"""-Wno-undefined-bool-conversion -O1 -DTOP_TYPE=V$dutFile -DVL_USER_FINISH -include V$dutFile.h""",
- "-Mdir", dir.getAbsolutePath,
- "--exe", cppHarness.getAbsolutePath)
+ "-Mdir",
+ dir.getAbsolutePath,
+ "--exe",
+ cppHarness.getAbsolutePath
+ )
logger.info(s"${command.mkString(" ")}")
command
}
@@ -167,17 +164,20 @@ object BackendCompilationUtilities extends LazyLogging {
Seq("make", "-C", dir.toString, "-j", "-f", s"V$prefix.mk", s"V$prefix")
def executeExpectingFailure(
- prefix: String,
- dir: File,
- assertionMsg: String = ""): Boolean = {
+ prefix: String,
+ dir: File,
+ assertionMsg: String = ""
+ ): Boolean = {
var triggered = false
val assertionMessageSupplied = assertionMsg != ""
val e = Process(s"./V$prefix", dir) !
- ProcessLogger(line => {
- triggered = triggered || (assertionMessageSupplied && line.contains(assertionMsg))
- logger.info(line)
- },
- logger.warn(_))
+ ProcessLogger(
+ line => {
+ triggered = triggered || (assertionMessageSupplied && line.contains(assertionMsg))
+ logger.info(line)
+ },
+ logger.warn(_)
+ )
// Fail if a line contained an assertion or if we get a non-zero exit code
// or, we get a SIGABRT (assertion failure) and we didn't provide a specific assertion message
triggered || (e != 0 && (e != 134 || !assertionMessageSupplied))
@@ -201,10 +201,7 @@ object BackendCompilationUtilities extends LazyLogging {
* @param timesteps the maximum number of timesteps for Yosys equivalence
* checking to consider
*/
- def yosysExpectSuccess(customTop: String,
- referenceTop: String,
- testDir: File,
- timesteps: Int = 1): Boolean = {
+ def yosysExpectSuccess(customTop: String, referenceTop: String, testDir: File, timesteps: Int = 1): Boolean = {
!yosysExpectFailure(customTop, referenceTop, testDir, timesteps)
}
@@ -222,31 +219,26 @@ object BackendCompilationUtilities extends LazyLogging {
* @param timesteps the maximum number of timesteps for Yosys equivalence
* checking to consider
*/
- def yosysExpectFailure(customTop: String,
- referenceTop: String,
- testDir: File,
- timesteps: Int = 1): Boolean = {
+ def yosysExpectFailure(customTop: String, referenceTop: String, testDir: File, timesteps: Int = 1): Boolean = {
val scriptFileName = s"${testDir.getAbsolutePath}/yosys_script"
val yosysScriptWriter = new PrintWriter(scriptFileName)
- yosysScriptWriter.write(
- s"""read_verilog ${testDir.getAbsolutePath}/$customTop.v
- |prep -flatten -top $customTop; proc; opt; memory
- |design -stash custom
- |read_verilog ${testDir.getAbsolutePath}/$referenceTop.v
- |prep -flatten -top $referenceTop; proc; opt; memory
- |design -stash reference
- |design -copy-from custom -as custom $customTop
- |design -copy-from reference -as reference $referenceTop
- |equiv_make custom reference equiv
- |hierarchy -top equiv
- |prep -flatten -top equiv
- |clean -purge
- |equiv_simple -seq $timesteps
- |equiv_induct -seq $timesteps
- |equiv_status -assert
- """
- .stripMargin)
+ yosysScriptWriter.write(s"""read_verilog ${testDir.getAbsolutePath}/$customTop.v
+ |prep -flatten -top $customTop; proc; opt; memory
+ |design -stash custom
+ |read_verilog ${testDir.getAbsolutePath}/$referenceTop.v
+ |prep -flatten -top $referenceTop; proc; opt; memory
+ |design -stash reference
+ |design -copy-from custom -as custom $customTop
+ |design -copy-from reference -as reference $referenceTop
+ |equiv_make custom reference equiv
+ |hierarchy -top equiv
+ |prep -flatten -top equiv
+ |clean -purge
+ |equiv_simple -seq $timesteps
+ |equiv_induct -seq $timesteps
+ |equiv_status -assert
+ """.stripMargin)
yosysScriptWriter.close()
val resultFileName = testDir.getAbsolutePath + "/yosys_results"
@@ -258,28 +250,32 @@ object BackendCompilationUtilities extends LazyLogging {
@deprecated("use object BackendCompilationUtilities", "1.3")
trait BackendCompilationUtilities extends LazyLogging {
lazy val TestDirectory = BackendCompilationUtilities.TestDirectory
- def timeStamp: String = BackendCompilationUtilities.timeStamp
+ def timeStamp: String = BackendCompilationUtilities.timeStamp
def loggingProcessLogger: ProcessLogger = BackendCompilationUtilities.loggingProcessLogger
- def copyResourceToFile(name: String, file: File): Unit = BackendCompilationUtilities.copyResourceToFile(name, file)
+ def copyResourceToFile(name: String, file: File): Unit = BackendCompilationUtilities.copyResourceToFile(name, file)
def createTestDirectory(testName: String): File = BackendCompilationUtilities.createTestDirectory(testName)
- def makeHarness(template: String => String, post: String)(f: File): File = BackendCompilationUtilities.makeHarness(template, post)(f)
- def firrtlToVerilog(prefix: String, dir: File): ProcessBuilder = BackendCompilationUtilities.firrtlToVerilog(prefix, dir)
+ def makeHarness(template: String => String, post: String)(f: File): File =
+ BackendCompilationUtilities.makeHarness(template, post)(f)
+ def firrtlToVerilog(prefix: String, dir: File): ProcessBuilder =
+ BackendCompilationUtilities.firrtlToVerilog(prefix, dir)
def verilogToCpp(
- dutFile: String,
- dir: File,
- vSources: Seq[File],
- cppHarness: File,
- suppressVcd: Boolean = false,
- resourceFileName: String = firrtl.transforms.BlackBoxSourceHelper.defaultFileListName
- ): ProcessBuilder = {
+ dutFile: String,
+ dir: File,
+ vSources: Seq[File],
+ cppHarness: File,
+ suppressVcd: Boolean = false,
+ resourceFileName: String = firrtl.transforms.BlackBoxSourceHelper.defaultFileListName
+ ): ProcessBuilder = {
BackendCompilationUtilities.verilogToCpp(dutFile, dir, vSources, cppHarness, suppressVcd, resourceFileName)
}
def cppToExe(prefix: String, dir: File): ProcessBuilder = BackendCompilationUtilities.cppToExe(prefix, dir)
def executeExpectingFailure(
- prefix: String,
- dir: File,
- assertionMsg: String = ""): Boolean = {
+ prefix: String,
+ dir: File,
+ assertionMsg: String = ""
+ ): Boolean = {
BackendCompilationUtilities.executeExpectingFailure(prefix, dir, assertionMsg)
}
- def executeExpectingSuccess(prefix: String, dir: File): Boolean = BackendCompilationUtilities.executeExpectingSuccess(prefix, dir)
+ def executeExpectingSuccess(prefix: String, dir: File): Boolean =
+ BackendCompilationUtilities.executeExpectingSuccess(prefix, dir)
}
diff --git a/src/main/scala/firrtl/util/ClassUtils.scala b/src/main/scala/firrtl/util/ClassUtils.scala
index 1b388035..34ff60fc 100644
--- a/src/main/scala/firrtl/util/ClassUtils.scala
+++ b/src/main/scala/firrtl/util/ClassUtils.scala
@@ -1,18 +1,20 @@
package firrtl.util
object ClassUtils {
+
/** Determine if a named class is loaded.
*
* @param name - name of the class: "foo.bar" or "org.foo.bar"
* @return true if the class has been loaded (is accessible), false otherwise.
*/
def isClassLoaded(name: String): Boolean = {
- val found = try {
- Class.forName(name, false, getClass.getClassLoader) != null
- } catch {
- case e: ClassNotFoundException => false
- case x: Throwable => throw x
- }
+ val found =
+ try {
+ Class.forName(name, false, getClass.getClassLoader) != null
+ } catch {
+ case e: ClassNotFoundException => false
+ case x: Throwable => throw x
+ }
// println(s"isClassLoaded: %s $name".format(if (found) "found" else "didn't find"))
found
}
diff --git a/src/main/scala/logger/Logger.scala b/src/main/scala/logger/Logger.scala
index 9cf645fa..e002db92 100644
--- a/src/main/scala/logger/Logger.scala
+++ b/src/main/scala/logger/Logger.scala
@@ -4,7 +4,7 @@ package logger
import java.io.{ByteArrayOutputStream, File, FileOutputStream, PrintStream}
-import firrtl.{ExecutionOptionsManager, AnnotationSeq}
+import firrtl.{AnnotationSeq, ExecutionOptionsManager}
import firrtl.options.Viewer.view
import logger.phases.{AddDefaults, Checks}
@@ -38,7 +38,7 @@ object LogLevel extends Enumeration {
case "info" => LogLevel.Info
case "debug" => LogLevel.Debug
case "trace" => LogLevel.Trace
- case level => throw new Exception(s"Unknown LogLevel '$level'")
+ case level => throw new Exception(s"Unknown LogLevel '$level'")
}
}
@@ -58,8 +58,8 @@ private class LoggerState {
val classLevels = new scala.collection.mutable.HashMap[String, LogLevel.Value]
val classToLevelCache = new scala.collection.mutable.HashMap[String, LogLevel.Value]
var logClassNames = false
- var stream: PrintStream = System.out
- var fromInvoke: Boolean = false // this is used to not have invokes re-create run-state
+ var stream: PrintStream = System.out
+ var fromInvoke: Boolean = false // this is used to not have invokes re-create run-state
var stringBufferOption: Option[Logger.OutputCaptor] = None
override def toString: String = {
@@ -137,10 +137,9 @@ object Logger {
@deprecated("Use makescope(opts: FirrtlOptions)", "1.2")
def makeScope[A](args: Array[String] = Array.empty)(codeBlock: => A): A = {
val executionOptionsManager = new ExecutionOptionsManager("logger")
- if(executionOptionsManager.parse(args)) {
+ if (executionOptionsManager.parse(args)) {
makeScope(executionOptionsManager)(codeBlock)
- }
- else {
+ } else {
throw new Exception(s"logger invoke failed to parse args ${args.mkString(", ")}")
}
}
@@ -154,10 +153,9 @@ object Logger {
def makeScope[A](options: AnnotationSeq)(codeBlock: => A): A = {
val runState: LoggerState = {
val newRunState = updatableLoggerState.value.getOrElse(new LoggerState)
- if(newRunState.fromInvoke) {
+ if (newRunState.fromInvoke) {
newRunState
- }
- else {
+ } else {
val forcedNewRunState = new LoggerState
forcedNewRunState.fromInvoke = true
forcedNewRunState
@@ -179,39 +177,41 @@ object Logger {
*/
private def testPackageNameMatch(className: String, level: LogLevel.Value): Option[Boolean] = {
val classLevels = state.classLevels
- if(classLevels.isEmpty) return None
+ if (classLevels.isEmpty) return None
// If this class name in cache just use that value
- val levelForThisClassName = state.classToLevelCache.getOrElse(className, {
- // otherwise break up the class name in to full package path as list and find most specific entry you can
- val packageNameList = className.split("""\.""").toList
- /*
- * start with full class path, lopping off from the tail until nothing left
- */
- def matchPathToFindLevel(packageList: List[String]): LogLevel.Value = {
- if(packageList.isEmpty) {
- LogLevel.None
+ val levelForThisClassName = state.classToLevelCache.getOrElse(
+ className, {
+ // otherwise break up the class name in to full package path as list and find most specific entry you can
+ val packageNameList = className.split("""\.""").toList
+ /*
+ * start with full class path, lopping off from the tail until nothing left
+ */
+ def matchPathToFindLevel(packageList: List[String]): LogLevel.Value = {
+ if (packageList.isEmpty) {
+ LogLevel.None
+ } else {
+ val partialName = packageList.mkString(".")
+ val level = classLevels.getOrElse(
+ partialName, {
+ matchPathToFindLevel(packageList.reverse.tail.reverse)
+ }
+ )
+ level
+ }
}
- else {
- val partialName = packageList.mkString(".")
- val level = classLevels.getOrElse(partialName, {
- matchPathToFindLevel(packageList.reverse.tail.reverse)
- })
- level
- }
- }
- val levelSpecified = matchPathToFindLevel(packageNameList)
- if(levelSpecified != LogLevel.None) {
- state.classToLevelCache(className) = levelSpecified
+ val levelSpecified = matchPathToFindLevel(packageNameList)
+ if (levelSpecified != LogLevel.None) {
+ state.classToLevelCache(className) = levelSpecified
+ }
+ levelSpecified
}
- levelSpecified
- })
+ )
- if(levelForThisClassName != LogLevel.None) {
+ if (levelForThisClassName != LogLevel.None) {
Some(levelForThisClassName >= level)
- }
- else {
+ } else {
None
}
}
@@ -226,19 +226,20 @@ object Logger {
*/
private def showMessage(level: LogLevel.Value, className: String, message: => String): Unit = {
def logIt(): Unit = {
- if(state.logClassNames) {
+ if (state.logClassNames) {
state.stream.println(s"[$level:$className] $message")
- }
- else {
+ } else {
state.stream.println(message)
}
}
testPackageNameMatch(className, level) match {
- case Some(true) => logIt()
+ case Some(true) => logIt()
case Some(false) =>
case None =>
- if((state.globalLevel == LogLevel.None && level == LogLevel.Error) ||
- (state.globalLevel != LogLevel.None && state.globalLevel >= level)) {
+ if (
+ (state.globalLevel == LogLevel.None && level == LogLevel.Error) ||
+ (state.globalLevel != LogLevel.None && state.globalLevel >= level)
+ ) {
logIt()
}
}
@@ -247,6 +248,7 @@ object Logger {
def getGlobalLevel: LogLevel.Value = {
state.globalLevel
}
+
/**
* This resets everything in the current Logger environment, including the destination
* use this with caution. Unexpected things can happen
@@ -309,7 +311,7 @@ object Logger {
def clearStringBuffer(): Unit = {
state.stringBufferOption match {
case Some(x) => x.byteArrayOutputStream.reset()
- case None =>
+ case None =>
}
}
@@ -360,16 +362,16 @@ object Logger {
*/
def setOptions(inputAnnotations: AnnotationSeq): Unit = {
val annotations =
- Seq( new AddDefaults, Checks )
- .foldLeft(inputAnnotations)((a, p) => p.transform(a))
+ Seq(new AddDefaults, Checks)
+ .foldLeft(inputAnnotations)((a, p) => p.transform(a))
val lopts = view[LoggerOptions](annotations)
state.globalLevel = (state.globalLevel, lopts.globalLogLevel) match {
case (LogLevel.None, LogLevel.None) => LogLevel.None
- case (x, LogLevel.None) => x
- case (LogLevel.None, x) => x
- case (_, x) => x
- case _ => LogLevel.Error
+ case (x, LogLevel.None) => x
+ case (LogLevel.None, x) => x
+ case (_, x) => x
+ case _ => LogLevel.Error
}
setClassLogLevels(lopts.classLogLevels)
@@ -386,6 +388,7 @@ object Logger {
* @param containerClass passed in from the LazyLogging trait in order to provide class level logging granularity
*/
class Logger(containerClass: String) {
+
/**
* Log message at Error level
* @param message message generator to be invoked if level is right
@@ -393,6 +396,7 @@ class Logger(containerClass: String) {
def error(message: => String): Unit = {
Logger.showMessage(LogLevel.Error, containerClass, message)
}
+
/**
* Log message at Warn level
* @param message message generator to be invoked if level is right
@@ -400,6 +404,7 @@ class Logger(containerClass: String) {
def warn(message: => String): Unit = {
Logger.showMessage(LogLevel.Warn, containerClass, message)
}
+
/**
* Log message at Inof level
* @param message message generator to be invoked if level is right
@@ -407,6 +412,7 @@ class Logger(containerClass: String) {
def info(message: => String): Unit = {
Logger.showMessage(LogLevel.Info, containerClass, message)
}
+
/**
* Log message at Debug level
* @param message message generator to be invoked if level is right
@@ -414,6 +420,7 @@ class Logger(containerClass: String) {
def debug(message: => String): Unit = {
Logger.showMessage(LogLevel.Debug, containerClass, message)
}
+
/**
* Log message at Trace level
* @param message message generator to be invoked if level is right
diff --git a/src/main/scala/logger/LoggerAnnotations.scala b/src/main/scala/logger/LoggerAnnotations.scala
index f4dc6b38..b345d617 100644
--- a/src/main/scala/logger/LoggerAnnotations.scala
+++ b/src/main/scala/logger/LoggerAnnotations.scala
@@ -5,7 +5,6 @@ package logger
import firrtl.annotations.{Annotation, NoTargetAnnotation}
import firrtl.options.{HasShellOptions, ShellOption}
-
/** An annotation associated with a Logger command line option */
sealed trait LoggerOption { this: Annotation => }
@@ -14,7 +13,9 @@ sealed trait LoggerOption { this: Annotation => }
* - if unset, a [[LogLevelAnnotation]] with the default log level will be emitted
* @param level the level of logging
*/
-case class LogLevelAnnotation(globalLogLevel: LogLevel.Value = LogLevel.Warn) extends NoTargetAnnotation with LoggerOption
+case class LogLevelAnnotation(globalLogLevel: LogLevel.Value = LogLevel.Warn)
+ extends NoTargetAnnotation
+ with LoggerOption
object LogLevelAnnotation extends HasShellOptions {
@@ -24,7 +25,9 @@ object LogLevelAnnotation extends HasShellOptions {
toAnnotationSeq = (a: String) => Seq(LogLevelAnnotation(LogLevel(a))),
helpText = s"Set global logging verbosity (default: ${new LoggerOptions().globalLogLevel}",
shortOption = Some("ll"),
- helpValueName = Some("{error|warn|info|debug|trace}") ) )
+ helpValueName = Some("{error|warn|info|debug|trace}")
+ )
+ )
}
@@ -33,20 +36,26 @@ object LogLevelAnnotation extends HasShellOptions {
* @param name the class name to log
* @param level the verbosity level
*/
-case class ClassLogLevelAnnotation(className: String, level: LogLevel.Value) extends NoTargetAnnotation with LoggerOption
+case class ClassLogLevelAnnotation(className: String, level: LogLevel.Value)
+ extends NoTargetAnnotation
+ with LoggerOption
object ClassLogLevelAnnotation extends HasShellOptions {
val options = Seq(
new ShellOption[Seq[String]](
longOption = "class-log-level",
- toAnnotationSeq = (a: Seq[String]) => a.map { aa =>
- val className :: levelName :: _ = aa.split(":").toList
- val level = LogLevel(levelName)
- ClassLogLevelAnnotation(className, level) },
+ toAnnotationSeq = (a: Seq[String]) =>
+ a.map { aa =>
+ val className :: levelName :: _ = aa.split(":").toList
+ val level = LogLevel(levelName)
+ ClassLogLevelAnnotation(className, level)
+ },
helpText = "Set per-class logging verbosity",
shortOption = Some("cll"),
- helpValueName = Some("<FullClassName:{error|warn|info|debug|trace}>...") ) )
+ helpValueName = Some("<FullClassName:{error|warn|info|debug|trace}>...")
+ )
+ )
}
@@ -63,7 +72,9 @@ object LogFileAnnotation extends HasShellOptions {
longOption = "log-file",
toAnnotationSeq = (a: String) => Seq(LogFileAnnotation(Some(a))),
helpText = "Log to a file instead of STDOUT",
- helpValueName = Some("<file>") ) )
+ helpValueName = Some("<file>")
+ )
+ )
}
@@ -77,6 +88,8 @@ case object LogClassNamesAnnotation extends NoTargetAnnotation with LoggerOption
longOption = "log-class-names",
toAnnotationSeq = (a: Unit) => Seq(LogClassNamesAnnotation),
helpText = "Show class names and log level in logging output",
- shortOption = Some("lcn") ) )
+ shortOption = Some("lcn")
+ )
+ )
}
diff --git a/src/main/scala/logger/LoggerOptions.scala b/src/main/scala/logger/LoggerOptions.scala
index 299382f0..6cc745b9 100644
--- a/src/main/scala/logger/LoggerOptions.scala
+++ b/src/main/scala/logger/LoggerOptions.scala
@@ -9,23 +9,25 @@ package logger
* @param logToFile if true, log to a file
* @param logClassNames indicates logging verbosity on a class-by-class basis
*/
-class LoggerOptions private [logger] (
- val globalLogLevel: LogLevel.Value = LogLevelAnnotation().globalLogLevel,
+class LoggerOptions private[logger] (
+ val globalLogLevel: LogLevel.Value = LogLevelAnnotation().globalLogLevel,
val classLogLevels: Map[String, LogLevel.Value] = Map.empty,
- val logClassNames: Boolean = false,
- val logFileName: Option[String] = None) {
+ val logClassNames: Boolean = false,
+ val logFileName: Option[String] = None) {
- private [logger] def copy(
- globalLogLevel: LogLevel.Value = globalLogLevel,
+ private[logger] def copy(
+ globalLogLevel: LogLevel.Value = globalLogLevel,
classLogLevels: Map[String, LogLevel.Value] = classLogLevels,
- logClassNames: Boolean = logClassNames,
- logFileName: Option[String] = logFileName): LoggerOptions = {
+ logClassNames: Boolean = logClassNames,
+ logFileName: Option[String] = logFileName
+ ): LoggerOptions = {
new LoggerOptions(
globalLogLevel = globalLogLevel,
classLogLevels = classLogLevels,
logClassNames = logClassNames,
- logFileName = logFileName)
+ logFileName = logFileName
+ )
}
diff --git a/src/main/scala/logger/phases/AddDefaults.scala b/src/main/scala/logger/phases/AddDefaults.scala
index 660de579..ec673637 100644
--- a/src/main/scala/logger/phases/AddDefaults.scala
+++ b/src/main/scala/logger/phases/AddDefaults.scala
@@ -5,10 +5,10 @@ package logger.phases
import firrtl.AnnotationSeq
import firrtl.options.Phase
-import logger.{LoggerOption, LogLevelAnnotation}
+import logger.{LogLevelAnnotation, LoggerOption}
/** Add default logger [[Annotation]]s */
-private [logger] class AddDefaults extends Phase {
+private[logger] class AddDefaults extends Phase {
override def prerequisites = Seq.empty
override def optionalPrerequisiteOf = Seq.empty
@@ -20,12 +20,12 @@ private [logger] class AddDefaults extends Phase {
*/
def transform(annotations: AnnotationSeq): AnnotationSeq = {
var ll = true
- annotations.collect{ case a: LoggerOption => a }.map{
+ annotations.collect { case a: LoggerOption => a }.map {
case _: LogLevelAnnotation => ll = false
- case _ =>
+ case _ =>
}
annotations ++
- (if (ll) Seq(LogLevelAnnotation()) else Seq() )
+ (if (ll) Seq(LogLevelAnnotation()) else Seq())
}
}
diff --git a/src/main/scala/logger/phases/Checks.scala b/src/main/scala/logger/phases/Checks.scala
index e945fa98..0109c7ad 100644
--- a/src/main/scala/logger/phases/Checks.scala
+++ b/src/main/scala/logger/phases/Checks.scala
@@ -6,12 +6,13 @@ import firrtl.AnnotationSeq
import firrtl.annotations.Annotation
import firrtl.options.{Dependency, Phase}
-import logger.{LogLevelAnnotation, LogFileAnnotation, LoggerException}
+import logger.{LogFileAnnotation, LogLevelAnnotation, LoggerException}
import scala.collection.mutable
/** Check that an [[firrtl.AnnotationSeq AnnotationSeq]] has all necessary [[firrtl.annotations.Annotation Annotation]]s
- * for a [[Logger]] */
+ * for a [[Logger]]
+ */
object Checks extends Phase {
override def prerequisites = Seq(Dependency[AddDefaults])
@@ -26,20 +27,22 @@ object Checks extends Phase {
*/
def transform(annotations: AnnotationSeq): AnnotationSeq = {
val ll, lf = mutable.ListBuffer[Annotation]()
- annotations.foreach(
- _ match {
- case a: LogLevelAnnotation => ll += a
- case a: LogFileAnnotation => lf += a
- case _ => })
+ annotations.foreach(_ match {
+ case a: LogLevelAnnotation => ll += a
+ case a: LogFileAnnotation => lf += a
+ case _ =>
+ })
if (ll.size > 1) {
- val l = ll.map{ case LogLevelAnnotation(x) => x }
+ val l = ll.map { case LogLevelAnnotation(x) => x }
throw new LoggerException(
s"""|At most one log level can be specified, but found '${l.mkString(", ")}' specified via:
- | - an option or annotation: -ll, --log-level, LogLevelAnnotation""".stripMargin )}
+ | - an option or annotation: -ll, --log-level, LogLevelAnnotation""".stripMargin
+ )
+ }
if (lf.size > 1) {
- throw new LoggerException(
- s"""|At most one log file can be specified, but found ${lf.size} combinations of:
- | - an options or annotation: -ltf, --log-to-file, --log-file, LogFileAnnotation""".stripMargin )}
+ throw new LoggerException(s"""|At most one log file can be specified, but found ${lf.size} combinations of:
+ | - an options or annotation: -ltf, --log-to-file, --log-file, LogFileAnnotation""".stripMargin)
+ }
annotations
}
diff --git a/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala
index 48427af8..72f50461 100644
--- a/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala
+++ b/src/main/scala/tutorial/lesson1-circuit-traversal/AnalyzeCircuit.scala
@@ -4,9 +4,9 @@ package tutorial
package lesson1
// Compiler Infrastructure
-import firrtl.{Transform, LowForm, CircuitState, Utils}
+import firrtl.{CircuitState, LowForm, Transform, Utils}
// Firrtl IR classes
-import firrtl.ir.{DefModule, Statement, Expression, Mux}
+import firrtl.ir.{DefModule, Expression, Mux, Statement}
// Map functions
import firrtl.Mappers._
// Scala's mutable collections
@@ -26,11 +26,11 @@ class Ledger {
private val modules = mutable.Set[String]()
private val moduleMuxMap = mutable.Map[String, Int]()
def foundMux(): Unit = moduleName match {
- case None => sys.error("Module name not defined in Ledger!")
+ case None => sys.error("Module name not defined in Ledger!")
case Some(name) => moduleMuxMap(name) = moduleMuxMap.getOrElse(name, 0) + 1
}
def getModuleName: String = moduleName match {
- case None => Utils.error("Module name not defined in Ledger!")
+ case None => Utils.error("Module name not defined in Ledger!")
case Some(name) => name
}
def setModuleName(myName: String): Unit = {
@@ -38,9 +38,9 @@ class Ledger {
moduleName = Some(myName)
}
def serialize: String = {
- modules map { myName =>
+ modules.map { myName =>
s"$myName => ${moduleMuxMap.getOrElse(myName, 0)} muxes!"
- } mkString "\n"
+ }.mkString("\n")
}
}
@@ -68,8 +68,10 @@ class Ledger {
* - https://github.com/ucb-bar/firrtl/wiki/Common-Pass-Idioms
*/
class AnalyzeCircuit extends Transform {
+
/** Requires the [[firrtl.ir.Circuit Circuit]] form to be "low" */
def inputForm = LowForm
+
/** Indicates the output [[firrtl.ir.Circuit Circuit]] form to be "low" */
def outputForm = LowForm
@@ -88,7 +90,7 @@ class AnalyzeCircuit extends Transform {
* - "map" - classic functional programming concept
* - discard the returned new [[firrtl.ir.Circuit Circuit]] because circuit is unmodified
*/
- circuit map walkModule(ledger)
+ circuit.map(walkModule(ledger))
// Print our ledger
println(ledger.serialize)
@@ -106,7 +108,7 @@ class AnalyzeCircuit extends Transform {
* - return the new [[firrtl.ir.DefModule DefModule]] (in this case, its identical to m)
* - if m does not contain [[firrtl.ir.Statement Statement]], map returns m.
*/
- m map walkStatement(ledger)
+ m.map(walkStatement(ledger))
}
/** Deeply visits every [[firrtl.ir.Statement Statement]] and [[firrtl.ir.Expression Expression]] in s. */
@@ -116,13 +118,13 @@ class AnalyzeCircuit extends Transform {
* - discard the new [[firrtl.ir.Statement Statement]] (in this case, its identical to s)
* - if s does not contain [[firrtl.ir.Expression Expression]], map returns s.
*/
- s map walkExpression(ledger)
+ s.map(walkExpression(ledger))
/* Execute the function walkStatement(ledger) on every [[firrtl.ir.Statement Statement]] in s.
* - return the new [[firrtl.ir.Statement Statement]] (in this case, its identical to s)
* - if s does not contain [[firrtl.ir.Statement Statement]], map returns s.
*/
- s map walkStatement(ledger)
+ s.map(walkStatement(ledger))
}
/** Deeply visits every [[firrtl.ir.Expression Expression]] in e.
@@ -135,7 +137,7 @@ class AnalyzeCircuit extends Transform {
* - return the new [[firrtl.ir.Expression Expression]] (in this case, its identical to e)
* - if s does not contain [[firrtl.ir.Expression Expression]], map returns e.
*/
- val visited = e map walkExpression(ledger)
+ val visited = e.map(walkExpression(ledger))
visited match {
// If e is a [[firrtl.ir.Mux Mux]], increment our ledger and return e.
diff --git a/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala b/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala
index 523be723..11b4519c 100644
--- a/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala
+++ b/src/main/scala/tutorial/lesson2-ir-fields/AnalyzeCircuit.scala
@@ -4,9 +4,9 @@ package tutorial
package lesson2
// Compiler Infrastructure
-import firrtl.{Transform, LowForm, CircuitState}
+import firrtl.{CircuitState, LowForm, Transform}
// Firrtl IR classes
-import firrtl.ir.{DefModule, Statement, Expression, Mux, DefInstance}
+import firrtl.ir.{DefInstance, DefModule, Expression, Mux, Statement}
// Map functions
import firrtl.Mappers._
// Scala's mutable collections
@@ -27,7 +27,7 @@ class Ledger {
private val moduleMuxMap = mutable.Map[String, Int]()
private val moduleInstanceMap = mutable.Map[String, Seq[String]]()
def getModuleName: String = moduleName match {
- case None => sys.error("Module name not defined in Ledger!")
+ case None => sys.error("Module name not defined in Ledger!")
case Some(name) => name
}
def setModuleName(myName: String): Unit = {
@@ -47,14 +47,14 @@ class Ledger {
private def countMux(myName: String): Int = {
val myMuxes = moduleMuxMap.getOrElse(myName, 0)
val myInstanceMuxes =
- moduleInstanceMap.getOrElse(myName, Nil).foldLeft(0) {
- (total, name) => total + countMux(name)
+ moduleInstanceMap.getOrElse(myName, Nil).foldLeft(0) { (total, name) =>
+ total + countMux(name)
}
myMuxes + myInstanceMuxes
}
// Display recursive total of muxes
def serialize: String = {
- modules map { myName => s"$myName => ${countMux(myName)} muxes!" } mkString "\n"
+ modules.map { myName => s"$myName => ${countMux(myName)} muxes!" }.mkString("\n")
}
}
@@ -76,7 +76,6 @@ class Ledger {
* - Kind -> ExpKind
* - Flow -> UnknownFlow
* - Type -> UnknownType
- *
*/
class AnalyzeCircuit extends Transform {
def inputForm = LowForm
@@ -88,7 +87,7 @@ class AnalyzeCircuit extends Transform {
val circuit = state.circuit
// Execute the function walkModule(ledger) on all [[DefModule]] in circuit
- circuit map walkModule(ledger)
+ circuit.map(walkModule(ledger))
// Print our ledger
println(ledger.serialize)
@@ -103,13 +102,13 @@ class AnalyzeCircuit extends Transform {
ledger.setModuleName(m.name)
// Execute the function walkStatement(ledger) on every [[Statement]] in m.
- m map walkStatement(ledger)
+ m.map(walkStatement(ledger))
}
// Deeply visits every [[Statement]] and [[Expression]] in s.
def walkStatement(ledger: Ledger)(s: Statement): Statement = {
// Map the functions walkStatement(ledger) and walkExpression(ledger)
- val visited = s map walkStatement(ledger) map walkExpression(ledger)
+ val visited = s.map(walkStatement(ledger)).map(walkExpression(ledger))
visited match {
case DefInstance(info, name, module, tpe) =>
ledger.foundInstance(module)
@@ -122,7 +121,7 @@ class AnalyzeCircuit extends Transform {
def walkExpression(ledger: Ledger)(e: Expression): Expression = {
// Execute the function walkExpression(ledger) on every [[Expression]] in e,
// then handle if a [[Mux]].
- e map walkExpression(ledger) match {
+ e.map(walkExpression(ledger)) match {
case mux: Mux =>
ledger.foundMux()
mux