aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/transforms
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/transforms')
-rw-r--r--src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala55
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala93
-rw-r--r--src/main/scala/firrtl/transforms/CombineCats.scala33
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala463
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala130
-rw-r--r--src/main/scala/firrtl/transforms/Dedup.scala297
-rw-r--r--src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala25
-rw-r--r--src/main/scala/firrtl/transforms/Flatten.scala179
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala30
-rw-r--r--src/main/scala/firrtl/transforms/GroupComponents.scala123
-rw-r--r--src/main/scala/firrtl/transforms/InferResets.scala93
-rw-r--r--src/main/scala/firrtl/transforms/InlineBitExtractions.scala48
-rw-r--r--src/main/scala/firrtl/transforms/InlineCasts.scala43
-rw-r--r--src/main/scala/firrtl/transforms/LegalizeClocks.scala16
-rw-r--r--src/main/scala/firrtl/transforms/LegalizeReductions.scala8
-rw-r--r--src/main/scala/firrtl/transforms/ManipulateNames.scala287
-rw-r--r--src/main/scala/firrtl/transforms/OptimizationAnnotations.scala18
-rw-r--r--src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala168
-rw-r--r--src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala27
-rw-r--r--src/main/scala/firrtl/transforms/RemoveReset.scala7
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala49
-rw-r--r--src/main/scala/firrtl/transforms/RenameModules.scala2
-rw-r--r--src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala3
-rw-r--r--src/main/scala/firrtl/transforms/SimplifyMems.scala11
-rw-r--r--src/main/scala/firrtl/transforms/TopWiring.scala269
-rw-r--r--src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala23
-rw-r--r--src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala8
-rw-r--r--src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala23
28 files changed, 1368 insertions, 1163 deletions
diff --git a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala
index a57973d5..5000e07a 100644
--- a/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala
+++ b/src/main/scala/firrtl/transforms/BlackBoxSourceHelper.scala
@@ -2,7 +2,7 @@
package firrtl.transforms
-import java.io.{File, FileNotFoundException, FileInputStream, FileOutputStream, PrintWriter}
+import java.io.{File, FileInputStream, FileNotFoundException, FileOutputStream, PrintWriter}
import firrtl._
import firrtl.annotations._
@@ -11,31 +11,32 @@ import scala.collection.immutable.ListSet
sealed trait BlackBoxHelperAnno extends Annotation
-case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno
- with NoTargetAnnotation {
+case class BlackBoxTargetDirAnno(targetDir: String) extends BlackBoxHelperAnno with NoTargetAnnotation {
override def serialize: String = s"targetDir\n$targetDir"
}
-case class BlackBoxResourceAnno(target: ModuleName, resourceId: String) extends BlackBoxHelperAnno
+case class BlackBoxResourceAnno(target: ModuleName, resourceId: String)
+ extends BlackBoxHelperAnno
with SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(target = n)
override def serialize: String = s"resource\n$resourceId"
}
-case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String) extends BlackBoxHelperAnno
+case class BlackBoxInlineAnno(target: ModuleName, name: String, text: String)
+ extends BlackBoxHelperAnno
with SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(target = n)
override def serialize: String = s"inline\n$name\n$text"
}
-case class BlackBoxPathAnno(target: ModuleName, path: String) extends BlackBoxHelperAnno
+case class BlackBoxPathAnno(target: ModuleName, path: String)
+ extends BlackBoxHelperAnno
with SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(target = n)
override def serialize: String = s"path\n$path"
}
-case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno
- with NoTargetAnnotation {
+case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackBoxHelperAnno with NoTargetAnnotation {
override def serialize: String = s"resourceFileName\n$resourceFileName"
}
@@ -43,8 +44,10 @@ case class BlackBoxResourceFileNameAnno(resourceFileName: String) extends BlackB
* @param fileName the name of the BlackBox file (only used for error message generation)
* @param e an underlying exception that generated this
*/
-class BlackBoxNotFoundException(fileName: String, message: String) extends FirrtlUserException(
- s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message")
+class BlackBoxNotFoundException(fileName: String, message: String)
+ extends FirrtlUserException(
+ s"BlackBox '$fileName' not found. Did you misspell it? Is it in src/{main,test}/resources?\n$message"
+ )
/** Handle source for Verilog ExtModules (BlackBoxes)
*
@@ -72,15 +75,16 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration {
*/
def collectAnnos(annos: Seq[Annotation]): (ListSet[BlackBoxHelperAnno], File, File) =
annos.foldLeft((ListSet.empty[BlackBoxHelperAnno], DefaultTargetDir, new File(defaultFileListName))) {
- case ((acc, tdir, flistName), anno) => anno match {
- case BlackBoxTargetDirAnno(dir) =>
- val targetDir = new File(dir)
- if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) }
- (acc, targetDir, flistName)
- case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName))
- case a: BlackBoxHelperAnno => (acc + a, tdir, flistName)
- case _ => (acc, tdir, flistName)
- }
+ case ((acc, tdir, flistName), anno) =>
+ anno match {
+ case BlackBoxTargetDirAnno(dir) =>
+ val targetDir = new File(dir)
+ if (!targetDir.exists()) { FileUtils.makeDirectory(targetDir.getAbsolutePath) }
+ (acc, targetDir, flistName)
+ case BlackBoxResourceFileNameAnno(fileName) => (acc, tdir, new File(fileName))
+ case a: BlackBoxHelperAnno => (acc + a, tdir, flistName)
+ case _ => (acc, tdir, flistName)
+ }
}
/**
@@ -112,14 +116,15 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration {
case BlackBoxInlineAnno(_, name, text) =>
val outFile = new File(targetDir, name)
(text, outFile)
- }.map { case (text, file) =>
- writeTextToFile(text, file)
- file
+ }.map {
+ case (text, file) =>
+ writeTextToFile(text, file)
+ file
}
// Issue #917 - We don't want to list Verilog header files ("*.vh") in our file list - they will automatically be included by reference.
def isHeader(name: String) = name.endsWith(".h") || name.endsWith(".vh") || name.endsWith(".svh")
- val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot{ f => isHeader(f.getName()) }
+ val verilogSourcesOnly = (resourceFiles ++ inlineFiles).filterNot { f => isHeader(f.getName()) }
val filelistFile = if (flistName.isAbsolute()) flistName else new File(targetDir, flistName.getName())
// We need the canonical path here, so verilator will create a path to the file that works from the targetDir,
@@ -137,12 +142,14 @@ class BlackBoxSourceHelper extends Transform with DependencyAPIMigration {
}
object BlackBoxSourceHelper {
+
/** Safely access a file converting [[FileNotFoundException]]s and [[NullPointerException]]s into
* [[BlackBoxNotFoundException]]s
* @param fileName the name of the file to be accessed (only used for error message generation)
* @param code some code to run
*/
- private def safeFile[A](fileName: String)(code: => A) = try { code } catch {
+ private def safeFile[A](fileName: String)(code: => A) = try { code }
+ catch {
case e @ (_: FileNotFoundException | _: NullPointerException) =>
throw new BlackBoxNotFoundException(fileName, e.getMessage)
}
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index 6403be23..ee4c1d0b 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -24,6 +24,7 @@ import firrtl.options.{Dependency, RegisteredTransform, ShellOption}
case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)
object LogicNode {
+
/**
* Construct a LogicNode from a *Low FIRRTL* reference or subfield that refers to a component.
* Since aggregate types appear in Low FIRRTL only as the full types of instances or memories,
@@ -39,11 +40,11 @@ object LogicNode {
case s: WSubField =>
s.expr match {
case modref: WRef =>
- LogicNode(s.name,Some(modref.name))
+ LogicNode(s.name, Some(modref.name))
case memport: WSubField =>
memport.expr match {
case memref: WRef =>
- LogicNode(s.name,Some(memref.name),Some(memport.name))
+ LogicNode(s.name, Some(memref.name), Some(memport.name))
case _ => throwInternalError(s"LogicNode: unrecognized subsubfield expression - $memport")
}
case _ => throwInternalError(s"LogicNode: unrecognized subfield expression - $s")
@@ -56,9 +57,8 @@ object CheckCombLoops {
type ConnMap = DiGraph[LogicNode] with EdgeData[LogicNode, Info]
type MutableConnMap = MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info]
-
- class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException(
- s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n"))
+ class CombLoopException(info: Info, mname: String, cycle: Seq[String])
+ extends PassException(s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n"))
}
case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation
@@ -73,7 +73,7 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge
override def update(renames: RenameMap): Seq[Annotation] = {
val sources = renames.get(source).getOrElse(Seq(source))
val sinks = renames.get(sink).getOrElse(Seq(sink))
- val paths = sources flatMap { s => sinks.map((s, _)) }
+ val paths = sources.flatMap { s => sinks.map((s, _)) }
paths.collect {
case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink)
}
@@ -82,8 +82,8 @@ case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarge
case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget]) extends Annotation {
override def update(renames: RenameMap): Seq[Annotation] = {
- val newSources = sources.flatMap { s => renames(s) }.collect {case x: ReferenceTarget if x.isLocal => x}
- val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x}
+ val newSources = sources.flatMap { s => renames(s) }.collect { case x: ReferenceTarget if x.isLocal => x }
+ val newSinks = renames(sink).collect { case x: ReferenceTarget if x.isLocal => x }
newSinks.map(snk => CombinationalPath(snk, newSources))
}
}
@@ -98,14 +98,10 @@ case class CombinationalPath(sink: ReferenceTarget, sources: Seq[ReferenceTarget
* @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules
* @note The pass will throw exceptions on "false paths"
*/
-class CheckCombLoops extends Transform
- with RegisteredTransform
- with DependencyAPIMigration {
+class CheckCombLoops extends Transform with RegisteredTransform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
- Dependency(firrtl.transforms.RemoveReset) )
+ Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize), Dependency(firrtl.transforms.RemoveReset))
override def optionalPrerequisites = Seq.empty
@@ -119,17 +115,21 @@ class CheckCombLoops extends Transform
new ShellOption[Unit](
longOption = "no-check-comb-loops",
toAnnotationSeq = (_: Unit) => Seq(DontCheckCombLoopsAnnotation),
- helpText = "Disable combinational loop checking" ) )
+ helpText = "Disable combinational loop checking"
+ )
+ )
private def getExprDeps(deps: MutableConnMap, v: LogicNode, info: Info)(e: Expression): Unit = e match {
- case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info)
+ case r: WRef => deps.addEdgeIfValid(v, LogicNode(r), info)
case s: WSubField => deps.addEdgeIfValid(v, LogicNode(s), info)
case _ => e.foreach(getExprDeps(deps, v, info))
}
private def getStmtDeps(
simplifiedModules: mutable.Map[String, AbstractConnMap],
- deps: MutableConnMap)(s: Statement): Unit = s match {
+ deps: MutableConnMap
+ )(s: Statement
+ ): Unit = s match {
case Connect(info, loc, expr) =>
val lhs = LogicNode(loc)
if (deps.contains(lhs)) {
@@ -152,9 +152,9 @@ class CheckCombLoops extends Transform
case i: WDefInstance =>
val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
iGraph.getVertices.foreach(deps.addVertex(_))
- iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
+ iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v, _) } })
case _ =>
- s.foreach(getStmtDeps(simplifiedModules,deps))
+ s.foreach(getStmtDeps(simplifiedModules, deps))
}
// Pretty-print a LogicNode with a prepended hierarchical path
@@ -169,24 +169,26 @@ class CheckCombLoops extends Transform
* recovered.
*/
private def expandInstancePaths(
- m: String,
+ m: String,
moduleGraphs: mutable.Map[String, ConnMap],
- moduleDeps: Map[String, Map[String, String]],
- hierPrefix: Seq[String],
- path: Seq[LogicNode]): Seq[String] = {
+ moduleDeps: Map[String, Map[String, String]],
+ hierPrefix: Seq[String],
+ path: Seq[LogicNode]
+ ): Seq[String] = {
// Recover info from edge data, add to error string
def info(u: LogicNode, v: LogicNode): String =
moduleGraphs(m).getEdgeData(u, v).map(_.toString).mkString("\t", "", "")
// lhs comes after rhs
- val pathNodes = (path zip path.tail) map { case (rhs, lhs) =>
- if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) {
- val child = moduleDeps(m)(lhs.inst.get)
- val newHierPrefix = hierPrefix :+ lhs.inst.get
- val subpath = moduleGraphs(child).path(lhs.copy(inst=None),rhs.copy(inst=None)).reverse
- expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath)
- } else {
- Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs))
- }
+ val pathNodes = (path.zip(path.tail)).map {
+ case (rhs, lhs) =>
+ if (lhs.inst.isDefined && !lhs.memport.isDefined && lhs.inst == rhs.inst) {
+ val child = moduleDeps(m)(lhs.inst.get)
+ val newHierPrefix = hierPrefix :+ lhs.inst.get
+ val subpath = moduleGraphs(child).path(lhs.copy(inst = None), rhs.copy(inst = None)).reverse
+ expandInstancePaths(child, moduleGraphs, moduleDeps, newHierPrefix, subpath)
+ } else {
+ Seq(prettyPrintAbsoluteRef(hierPrefix, lhs) ++ info(lhs, rhs))
+ }
}
pathNodes.flatten
}
@@ -238,12 +240,13 @@ class CheckCombLoops extends Transform
val errors = new Errors()
val extModulePaths = state.annotations.groupBy {
case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module)
- case ann: Annotation => CircuitTarget(c.main)
+ case ann: Annotation => CircuitTarget(c.main)
}
- val moduleMap = c.modules.map({m => (m.name,m) }).toMap
+ val moduleMap = c.modules.map({ m => (m.name, m) }).toMap
val iGraph = InstanceKeyGraph(c).graph
- val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
- val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
+ val moduleDeps =
+ iGraph.getEdgeMap.map({ case (k, v) => (k.module, (v.map { i => (i.name, i.module) }).toMap) }).toMap
+ val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse.map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String, ConnMap]
val simplifiedModuleGraphs = new mutable.HashMap[String, AbstractConnMap]
topoSortedModules.foreach {
@@ -252,7 +255,8 @@ class CheckCombLoops extends Transform
val extModuleDeps = new MutableDiGraph[LogicNode] with MutableEdgeData[LogicNode, Info]
portSet.foreach(extModuleDeps.addVertex(_))
extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect {
- case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
+ case a: ExtModulePathAnnotation =>
+ extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
}
moduleGraphs(em.name) = extModuleDeps
simplifiedModuleGraphs(em.name) = extModuleDeps.simplify(portSet)
@@ -270,7 +274,7 @@ class CheckCombLoops extends Transform
for (scc <- internalDeps.findSCCs.filter(_.length > 1)) {
val sccSubgraph = internalDeps.subgraph(scc.toSet)
val cycle = findCycleInSCC(sccSubgraph)
- (cycle zip cycle.tail).foreach({ case (a,b) => require(internalDeps.getEdges(a).contains(b)) })
+ (cycle.zip(cycle.tail)).foreach({ case (a, b) => require(internalDeps.getEdges(a).contains(b)) })
// Reverse to make sure LHS comes after RHS, print repeated vertex at start for legibility
val intuitiveCycle = cycle.reverse
val repeatedInitial = prettyPrintAbsoluteRef(Seq(m.name), intuitiveCycle.head)
@@ -280,10 +284,11 @@ class CheckCombLoops extends Transform
case m => throwInternalError(s"Module ${m.name} has unrecognized type")
}
val mt = ModuleTarget(c.main, c.main)
- val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty =>
- val sink = mt.ref(from.name)
- val sources = tos.map(to => mt.ref(to.name))
- CombinationalPath(sink, sources.toSeq)
+ val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect {
+ case (from, tos) if tos.nonEmpty =>
+ val sink = mt.ref(from.name)
+ val sources = tos.map(to => mt.ref(to.name))
+ CombinationalPath(sink, sources.toSeq)
}
(state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs, moduleGraphs)
}
@@ -291,7 +296,7 @@ class CheckCombLoops extends Transform
/**
* Returns a Map from Module name to port connectivity
*/
- def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = {
+ def analyze(state: CircuitState): collection.Map[String, DiGraph[String]] = {
val (result, errors, connectivity, _) = run(state)
connectivity.map {
case (k, v) => (k, v.transformNodes(ln => ln.name))
@@ -301,7 +306,7 @@ class CheckCombLoops extends Transform
/**
* Returns a Map from Module name to complete netlist connectivity
*/
- def analyzeFull(state: CircuitState): collection.Map[String,DiGraph[LogicNode]] = {
+ def analyzeFull(state: CircuitState): collection.Map[String, DiGraph[LogicNode]] = {
run(state)._4
}
diff --git a/src/main/scala/firrtl/transforms/CombineCats.scala b/src/main/scala/firrtl/transforms/CombineCats.scala
index 7fa01e46..3014d0e3 100644
--- a/src/main/scala/firrtl/transforms/CombineCats.scala
+++ b/src/main/scala/firrtl/transforms/CombineCats.scala
@@ -1,4 +1,3 @@
-
package firrtl
package transforms
@@ -14,26 +13,30 @@ import scala.collection.mutable
case class MaxCatLenAnnotation(maxCatLen: Int) extends NoTargetAnnotation
object CombineCats {
+
/** Mapping from references to the [[firrtl.ir.Expression Expression]]s that drive them paired with their Cat length */
type Netlist = mutable.HashMap[WrappedExpression, (Int, Expression)]
def expandCatArgs(maxCatLen: Int, netlist: Netlist)(expr: Expression): (Int, Expression) = expr match {
- case cat@DoPrim(Cat, args, _, _) =>
+ case cat @ DoPrim(Cat, args, _, _) =>
val (a0Len, a0Expanded) = expandCatArgs(maxCatLen - 1, netlist)(args.head)
val (a1Len, a1Expanded) = expandCatArgs(maxCatLen - a0Len, netlist)(args(1))
(a0Len + a1Len, cat.copy(args = Seq(a0Expanded, a1Expanded)).asInstanceOf[Expression])
case other =>
- netlist.get(we(expr)).collect {
- case (len, cat@DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat)
- }.getOrElse((1, other))
+ netlist
+ .get(we(expr))
+ .collect {
+ case (len, cat @ DoPrim(Cat, _, _, _)) if maxCatLen >= len => expandCatArgs(maxCatLen, netlist)(cat)
+ }
+ .getOrElse((1, other))
}
def onStmt(maxCatLen: Int, netlist: Netlist)(stmt: Statement): Statement = {
stmt.map(onStmt(maxCatLen, netlist)) match {
- case node@DefNode(_, name, value) =>
+ case node @ DefNode(_, name, value) =>
val catLenAndVal = value match {
- case cat@DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat)
- case other => (1, other)
+ case cat @ DoPrim(Cat, _, _, _) => expandCatArgs(maxCatLen, netlist)(cat)
+ case other => (1, other)
}
netlist(we(WRef(name))) = catLenAndVal
node.copy(value = catLenAndVal._2)
@@ -55,16 +58,16 @@ object CombineCats {
class CombineCats extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions) )
+ Seq(
+ Dependency(passes.RemoveValidIf),
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions)
+ )
override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Seq(
- Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ override def optionalPrerequisiteOf = Seq(Dependency[SystemVerilogEmitter], Dependency[VerilogEmitter])
override def invalidates(a: Transform) = false
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index ce36dd72..dc9b2bbe 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -28,7 +28,7 @@ object ConstantPropagation {
/** Pads e to the width of t */
def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
- case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
+ case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
case (we, wt) if we == wt => e
}
@@ -44,38 +44,40 @@ object ConstantPropagation {
case lit: Literal =>
require(hi >= lo)
UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), getWidth(e.tpe))
- case x if bitWidth(e.tpe) == bitWidth(x.tpe) => x.tpe match {
- case t: UIntType => x
- case _ => asUInt(x, e.tpe)
- }
+ case x if bitWidth(e.tpe) == bitWidth(x.tpe) =>
+ x.tpe match {
+ case t: UIntType => x
+ case _ => asUInt(x, e.tpe)
+ }
case _ => e
}
}
def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
- case x => e.args.head match {
- // TODO when amount >= x.width, return a zero-width wire
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1))
- // take sign bit if shift amount is larger than arg width
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1))
- case _ => e
- }
+ case x =>
+ e.args.head match {
+ // TODO when amount >= x.width, return a zero-width wire
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x).max(1)))
+ // take sign bit if shift amount is larger than arg width
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x).max(1)))
+ case _ => e
+ }
}
-
- /**********************************************
- * REGISTER CONSTANT PROPAGATION HELPER TYPES *
- **********************************************/
+ /** ********************************************
+ * REGISTER CONSTANT PROPAGATION HELPER TYPES *
+ * ********************************************
+ */
// A utility class that is somewhat like an Option but with two variants containing Nothing.
// for register constant propagation (register or literal).
private abstract class ConstPropBinding[+T] {
def resolve[V >: T](that: ConstPropBinding[V]): ConstPropBinding[V] = (this, that) match {
- case (x, y) if (x == y) => x
+ case (x, y) if (x == y) => x
case (x, UnboundConstant) => x
case (UnboundConstant, y) => y
- case _ => NonConstant
+ case _ => NonConstant
}
}
@@ -103,21 +105,23 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
override def prerequisites =
((new mutable.LinkedHashSet())
- ++ firrtl.stage.Forms.LowForm
- - Dependency(firrtl.passes.Legalize)
- + Dependency(firrtl.passes.RemoveValidIf)).toSeq
+ ++ firrtl.stage.Forms.LowForm
+ - Dependency(firrtl.passes.Legalize)
+ + Dependency(firrtl.passes.RemoveValidIf)).toSeq
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf =
- Seq( Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[SystemVerilogEmitter],
- Dependency[VerilogEmitter] )
+ Seq(
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[SystemVerilogEmitter],
+ Dependency[VerilogEmitter]
+ )
override def invalidates(a: Transform): Boolean = a match {
case firrtl.passes.Legalize => true
- case _ => false
+ case _ => false
}
override val annotationClasses: Traversable[Class[_]] = Seq(classOf[DontTouchAnnotation])
@@ -130,7 +134,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
sealed trait FoldCommutativeOp extends SimplifyBinaryOp {
- def fold(c1: Literal, c2: Literal): Expression
+ def fold(c1: Literal, c2: Literal): Expression
def simplify(e: Expression, lhs: Literal, rhs: Expression): Expression
override def apply(e: DoPrim): Expression = (e.args.head, e.args(1)) match {
@@ -138,7 +142,7 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case (lhs: Literal, rhs) => pad(simplify(e, lhs, rhs), e.tpe)
case (lhs, rhs: Literal) => pad(simplify(e, rhs, lhs), e.tpe)
case (lhs, rhs) if (lhs == rhs) => matchingArgsValue(e, lhs)
- case _ => e
+ case _ => e
}
}
@@ -177,20 +181,20 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
*/
def apply(prim: DoPrim): Expression = prim.args.head match {
case a: Literal => simplifyLiteral(a)
- case _ => prim
+ case _ => prim
}
}
object FoldADD extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
- case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
- case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
+ case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1))
+ case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width.max(c2.width)) + IntWidth(1))
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, w) if v == BigInt(0) => rhs
case SIntLiteral(v, w) if v == BigInt(0) => rhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = e
}
@@ -209,77 +213,81 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
object FoldAND extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked(c1.value & c2.value, width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
- case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
+ case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
+ case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => rhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked((c1.value | c2.value), width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
- case UIntLiteral(v, _) if v == BigInt(0) => rhs
- case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
+ case UIntLiteral(v, _) if v == BigInt(0) => rhs
+ case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
case UIntLiteral(v, IntWidth(w)) if v == (BigInt(1) << bitWidth(rhs.tpe).toInt) - 1 => lhs
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = asUInt(arg, e.tpe)
}
object FoldXOR extends FoldCommutativeOp {
def fold(c1: Literal, c2: Literal) = {
- val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ val width = (c1.width.max(c2.width)).asInstanceOf[IntWidth]
UIntLiteral.masked((c1.value ^ c2.value), width)
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
- case _ => e
+ case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe))
}
object FoldEqual extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) =>
+ DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(1)
}
object FoldNotEqual extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
+ def fold(c1: Literal, c2: Literal) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1))
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, IntWidth(w)) if v == BigInt(0) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => rhs
- case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) => DoPrim(Not, Seq(rhs), Nil, e.tpe)
+ case UIntLiteral(v, IntWidth(w)) if v == BigInt(1) && w == BigInt(1) && bitWidth(rhs.tpe) == BigInt(1) =>
+ DoPrim(Not, Seq(rhs), Nil, e.tpe)
case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0)
}
private def foldConcat(e: DoPrim) = (e.args.head, e.args(1)) match {
- case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
+ case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) =>
+ UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw))
case _ => e
}
private def foldShiftLeft(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
- case x => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
- case _ => e
- }
+ case x =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x))
+ case _ => e
+ }
}
private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match {
@@ -296,53 +304,55 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case _ => e
}
-
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
def isUInt(e: Expression): Boolean = e.tpe match {
case UIntType(_) => true
- case _ => false
+ case _ => false
}
def isZero(e: Expression) = e match {
- case UIntLiteral(value, _) => value == BigInt(0)
- case SIntLiteral(value, _) => value == BigInt(0)
- case _ => false
- }
+ case UIntLiteral(value, _) => value == BigInt(0)
+ case SIntLiteral(value, _) => value == BigInt(0)
+ case _ => false
+ }
x match {
- case DoPrim(Lt, Seq(a,b),_,_) if isUInt(a) && isZero(b) => zero
- case DoPrim(Leq, Seq(a,b),_,_) if isZero(a) && isUInt(b) => one
- case DoPrim(Gt, Seq(a,b),_,_) if isZero(a) && isUInt(b) => zero
- case DoPrim(Geq, Seq(a,b),_,_) if isUInt(a) && isZero(b) => one
- case ex => ex
+ case DoPrim(Lt, Seq(a, b), _, _) if isUInt(a) && isZero(b) => zero
+ case DoPrim(Leq, Seq(a, b), _, _) if isZero(a) && isUInt(b) => one
+ case DoPrim(Gt, Seq(a, b), _, _) if isZero(a) && isUInt(b) => zero
+ case DoPrim(Geq, Seq(a, b), _, _) if isUInt(a) && isZero(b) => one
+ case ex => ex
}
}
def foldIfOutsideRange(x: Expression): Expression = {
//Note, only abides by a partial ordering
case class Range(min: BigInt, max: BigInt) {
- def === (that: Range) =
+ def ===(that: Range) =
Seq(this.min, this.max, that.min, that.max)
- .sliding(2,1)
+ .sliding(2, 1)
.map(x => x.head == x(1))
.reduce(_ && _)
- def > (that: Range) = this.min > that.max
- def >= (that: Range) = this.min >= that.max
- def < (that: Range) = this.max < that.min
- def <= (that: Range) = this.max <= that.min
+ def >(that: Range) = this.min > that.max
+ def >=(that: Range) = this.min >= that.max
+ def <(that: Range) = this.max < that.min
+ def <=(that: Range) = this.max <= that.min
}
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
- case _ => e.tpe match {
- case SIntType(IntWidth(width)) => Range(
- min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
- max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
- )
- case UIntType(IntWidth(width)) => Range(
- min = BigInt(0),
- max = BigInt(2).pow(width.toInt) - BigInt(1)
- )
- }
+ case _ =>
+ e.tpe match {
+ case SIntType(IntWidth(width)) =>
+ Range(
+ min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
+ max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
+ )
+ case UIntType(IntWidth(width)) =>
+ Range(
+ min = BigInt(0),
+ max = BigInt(2).pow(width.toInt) - BigInt(1)
+ )
+ }
}
// Calculates an expression's range of values
x match {
@@ -351,27 +361,28 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
def r1 = range(ex.args(1))
ex.op match {
// Always true
- case Lt if r0 < r1 => one
+ case Lt if r0 < r1 => one
case Leq if r0 <= r1 => one
- case Gt if r0 > r1 => one
+ case Gt if r0 > r1 => one
case Geq if r0 >= r1 => one
// Always false
- case Lt if r0 >= r1 => zero
+ case Lt if r0 >= r1 => zero
case Leq if r0 > r1 => zero
- case Gt if r0 <= r1 => zero
+ case Gt if r0 <= r1 => zero
case Geq if r0 < r1 => zero
- case _ => ex
+ case _ => ex
}
case ex => ex
}
}
def foldIfMatchingArgs(x: Expression) = x match {
- case DoPrim(op, Seq(a, b), _, _) if (a == b) => op match {
- case (Lt | Gt) => zero
- case (Leq | Geq) => one
- case _ => x
- }
+ case DoPrim(op, Seq(a, b), _, _) if (a == b) =>
+ op match {
+ case (Lt | Gt) => zero
+ case (Leq | Geq) => one
+ case _ => x
+ }
case _ => x
}
foldIfZeroedArg(foldIfOutsideRange(foldIfMatchingArgs(e)))
@@ -393,43 +404,47 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
private def constPropPrim(e: DoPrim): Expression = e.op match {
- case Shl => foldShiftLeft(e)
- case Dshl => foldDynamicShiftLeft(e)
- case Shr => foldShiftRight(e)
- case Dshr => foldDynamicShiftRight(e)
- case Cat => foldConcat(e)
- case Add => FoldADD(e)
- case Sub => SimplifySUB(e)
- case Div => SimplifyDIV(e)
- case Rem => SimplifyREM(e)
- case And => FoldAND(e)
- case Or => FoldOR(e)
- case Xor => FoldXOR(e)
- case Eq => FoldEqual(e)
- case Neq => FoldNotEqual(e)
- case Andr => FoldANDR(e)
- case Orr => FoldORR(e)
- case Xorr => FoldXORR(e)
+ case Shl => foldShiftLeft(e)
+ case Dshl => foldDynamicShiftLeft(e)
+ case Shr => foldShiftRight(e)
+ case Dshr => foldDynamicShiftRight(e)
+ case Cat => foldConcat(e)
+ case Add => FoldADD(e)
+ case Sub => SimplifySUB(e)
+ case Div => SimplifyDIV(e)
+ case Rem => SimplifyREM(e)
+ case And => FoldAND(e)
+ case Or => FoldOR(e)
+ case Xor => FoldXOR(e)
+ case Eq => FoldEqual(e)
+ case Neq => FoldNotEqual(e)
+ case Andr => FoldANDR(e)
+ case Orr => FoldORR(e)
+ case Xorr => FoldXORR(e)
case (Lt | Leq | Gt | Geq) => foldComparison(e)
- case Not => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
- case _ => e
- }
+ case Not =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w))
+ case _ => e
+ }
case AsUInt =>
e.args.head match {
case SIntLiteral(v, IntWidth(w)) => UIntLiteral(v + (if (v < 0) BigInt(1) << w.toInt else 0), IntWidth(w))
- case arg => arg.tpe match {
- case _: UIntType => arg
- case _ => e
- }
+ case arg =>
+ arg.tpe match {
+ case _: UIntType => arg
+ case _ => e
+ }
}
- case AsSInt => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt-1)) << w.toInt), IntWidth(w))
- case arg => arg.tpe match {
- case _: SIntType => arg
- case _ => e
+ case AsSInt =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => SIntLiteral(v - ((v >> (w.toInt - 1)) << w.toInt), IntWidth(w))
+ case arg =>
+ arg.tpe match {
+ case _: SIntType => arg
+ case _ => e
+ }
}
- }
case AsClock =>
val arg = e.args.head
arg.tpe match {
@@ -442,25 +457,27 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
case AsyncResetType => arg
case _ => e
}
- case Pad => e.args.head match {
- case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head max w))
- case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head max w))
- case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
- case _ => e
- }
+ case Pad =>
+ e.args.head match {
+ case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v, IntWidth(e.consts.head.max(w)))
+ case _ if bitWidth(e.args.head.tpe) >= e.consts.head => e.args.head
+ case _ => e
+ }
case (Bits | Head | Tail) => constPropBitExtract(e)
- case _ => e
+ case _ => e
}
private def constPropMuxCond(m: Mux) = m.cond match {
case UIntLiteral(c, _) => pad(if (c == BigInt(1)) m.tval else m.fval, m.tpe)
- case _ => m
+ case _ => m
}
private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match {
case _ if m.tval == m.fval => m.tval
case (t: UIntLiteral, f: UIntLiteral)
- if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) => m.cond
+ if t.value == BigInt(1) && f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) =>
+ m.cond
case (t: UIntLiteral, _) if t.value == BigInt(1) && bitWidth(m.tpe) == BigInt(1) =>
DoPrim(Or, Seq(m.cond, m.fval), Nil, m.tpe)
case (_, f: UIntLiteral) if f.value == BigInt(0) && bitWidth(m.tpe) == BigInt(1) =>
@@ -479,15 +496,22 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Is "a" a "better name" than "b"?
private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_')
- def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
- def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
-
- private def constPropExpression(nodeMap: NodeMap, instMap: collection.Map[Instance, OfModule], constSubOutputs: Map[OfModule, Map[String, Literal]])(e: Expression): Expression = {
- val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
+ def optimize(e: Expression): Expression =
+ constPropExpression(new NodeMap(), Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+ def optimize(e: Expression, nodeMap: NodeMap): Expression =
+ constPropExpression(nodeMap, Map.empty[Instance, OfModule], Map.empty[OfModule, Map[String, Literal]])(e)
+
+ private def constPropExpression(
+ nodeMap: NodeMap,
+ instMap: collection.Map[Instance, OfModule],
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ )(e: Expression
+ ): Expression = {
+ val old = e.map(constPropExpression(nodeMap, instMap, constSubOutputs))
val propagated = old match {
case p: DoPrim => constPropPrim(p)
- case m: Mux => constPropMux(m)
- case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
+ case m: Mux => constPropMux(m)
+ case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, SourceFlow) =>
val module = instMap(inst.Instance)
@@ -506,17 +530,17 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
* @todo generalize source locator propagation across Expressions and delete this method
* @todo is the `orElse` the way we want to do propagation here?
*/
- private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])
- (stmt: Statement): Statement = stmt match {
- // We check rname because inlining it would cause the original declaration to go away
- case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
- val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
- node.copy(info = InfoExpr.orElse(info1, info0))
- case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
- val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
- con.copy(info = InfoExpr.orElse(info1, info0))
- case other => other
- }
+ private def propagateDirectConnectionInfoOnly(nodeMap: NodeMap, dontTouch: Set[String])(stmt: Statement): Statement =
+ stmt match {
+ // We check rname because inlining it would cause the original declaration to go away
+ case node @ DefNode(info0, name, WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ node.copy(info = InfoExpr.orElse(info1, info0))
+ case con @ Connect(info0, lhs, rref @ WRef(rname, _, NodeKind, _)) if !dontTouch(rname) =>
+ val (info1, _) = InfoExpr.unwrap(nodeMap(rname))
+ con.copy(info = InfoExpr.orElse(info1, info0))
+ case other => other
+ }
/* Constant propagate a Module
*
@@ -538,12 +562,12 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
*/
@tailrec
private def constPropModule(
- m: Module,
- dontTouches: Set[String],
- instMap: collection.Map[Instance, OfModule],
- constInputs: Map[String, Literal],
- constSubOutputs: Map[OfModule, Map[String, Literal]]
- ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
+ m: Module,
+ dontTouches: Set[String],
+ instMap: collection.Map[Instance, OfModule],
+ constInputs: Map[String, Literal],
+ constSubOutputs: Map[OfModule, Map[String, Literal]]
+ ): (Module, Map[String, Literal], Map[OfModule, Map[String, Seq[Literal]]]) = {
var nPropagated = 0L
val nodeMap = new NodeMap()
@@ -571,13 +595,13 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// to constant wires, we don't need to worry about propagating primops or muxes since we'll do
// that on the next iteration if necessary
def backPropExpr(expr: Expression): Expression = {
- val old = expr map backPropExpr
+ val old = expr.map(backPropExpr)
val propagated = old match {
// When swapping, we swap both rhs and lhs
- case ref @ WRef(rname, _,_,_) if swapMap.contains(rname) =>
+ case ref @ WRef(rname, _, _, _) if swapMap.contains(rname) =>
ref.copy(name = swapMap(rname))
// Only const prop on the rhs
- case ref @ WRef(rname, _,_, SourceFlow) if nodeMap.contains(rname) =>
+ case ref @ WRef(rname, _, _, SourceFlow) if nodeMap.contains(rname) =>
constPropNodeRef(ref, InfoExpr.unwrap(nodeMap(rname))._2)
case x => x
}
@@ -590,27 +614,29 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
def backPropStmt(stmt: Statement): Statement = stmt match {
case reg: DefRegister if (WrappedExpression.weq(reg.init, WRef(reg))) =>
// Self-init reset is an idiom for "no reset," and must be handled separately
- swapMap.get(reg.name)
- .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName)))
- .getOrElse(reg)
- case s => s map backPropExpr match {
- case decl: IsDeclaration if swapMap.contains(decl.name) =>
- val newName = swapMap(decl.name)
- nPropagated += 1
- decl match {
- case node: DefNode => node.copy(name = newName)
- case wire: DefWire => wire.copy(name = newName)
- case reg: DefRegister => reg.copy(name = newName)
- case other => throwInternalError()
- }
- case other => other map backPropStmt
- }
+ swapMap
+ .get(reg.name)
+ .map(newName => reg.copy(name = newName, init = WRef(reg).copy(name = newName)))
+ .getOrElse(reg)
+ case s =>
+ s.map(backPropExpr) match {
+ case decl: IsDeclaration if swapMap.contains(decl.name) =>
+ val newName = swapMap(decl.name)
+ nPropagated += 1
+ decl match {
+ case node: DefNode => node.copy(name = newName)
+ case wire: DefWire => wire.copy(name = newName)
+ case reg: DefRegister => reg.copy(name = newName)
+ case other => throwInternalError()
+ }
+ case other => other.map(backPropStmt)
+ }
}
// When propagating a reference, check if we want to keep the name that would be deleted
def propagateRef(lname: String, value: Expression, info: Info): Unit = {
value match {
- case WRef(rname,_,kind,_) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind =>
+ case WRef(rname, _, kind, _) if betterName(lname, rname) && !swapMap.contains(rname) && kind != PortKind =>
assert(!swapMap.contains(lname)) // <- Shouldn't be possible because lname is either a
// node declaration or the single connection to a wire or register
swapMap += (lname -> rname, rname -> lname)
@@ -639,25 +665,24 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Const prop registers that are driven by a mux tree containing only instances of one constant or self-assigns
// This requires that reset has been made explicit
case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), rhs) if !dontTouches(lname) =>
-
- /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
- * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of
- * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can
- * be resolved: each side is allowed one candidate register and one candidate literal to
- * appear in their source trees, referring to the potential constant propagation case that
- * they could allow. If the two are compatible (no different bound sources of either of
- * the two types), they can be resolved by combining sources. Otherwise, they propagate
- * NonConstant values. When encountering a node reference, it expands the node by to its
- * RHS assignment and recurses.
- *
- * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This
- * can stifle register constant propagations, which looks at drivers through value-preserving
- * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into
- * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole
- * de-optimization that does not actually appear in the output code.
- *
- * @return a RegCPEntry describing the constant prop-compatible sources driving this expression
- */
+ /* Checks if an RHS expression e of a register assignment is convertible to a constant assignment.
+ * Here, this means that e must be 1) a literal, 2) a self-connect, or 3) a mux tree of
+ * cases (1) and (2). In case (3), it also recursively checks that the two mux cases can
+ * be resolved: each side is allowed one candidate register and one candidate literal to
+ * appear in their source trees, referring to the potential constant propagation case that
+ * they could allow. If the two are compatible (no different bound sources of either of
+ * the two types), they can be resolved by combining sources. Otherwise, they propagate
+ * NonConstant values. When encountering a node reference, it expands the node by to its
+ * RHS assignment and recurses.
+ *
+ * @note Some optimization of Mux trees turn 1-bit mux operators into boolean operators. This
+ * can stifle register constant propagations, which looks at drivers through value-preserving
+ * Muxes and Connects only. By speculatively expanding some 1-bit Or and And operations into
+ * muxes, we can obtain the best possible insight on the value of the mux with a simple peephole
+ * de-optimization that does not actually appear in the output code.
+ *
+ * @return a RegCPEntry describing the constant prop-compatible sources driving this expression
+ */
val unbound = RegCPEntry(UnboundConstant, UnboundConstant)
val selfBound = RegCPEntry(BoundConstant(lname), UnboundConstant)
@@ -684,11 +709,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Updates nodeMap after analyzing the returned value from regConstant
def updateNodeMapIfConstant(e: Expression): Unit = regConstant(e, selfBound) match {
- case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero)
+ case RegCPEntry(UnboundConstant, UnboundConstant) => nodeMap(lname) = padCPExp(zero)
case RegCPEntry(BoundConstant(_), UnboundConstant) => nodeMap(lname) = padCPExp(zero)
- case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
+ case RegCPEntry(UnboundConstant, BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
case RegCPEntry(BoundConstant(_), BoundConstant(lit)) => nodeMap(lname) = padCPExp(lit)
- case _ =>
+ case _ =>
}
def padCPExp(e: Expression) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(e, ltpe))
@@ -733,11 +758,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Unify two maps using f to combine values of duplicate keys
private def unify[K, V](a: Map[K, V], b: Map[K, V])(f: (V, V) => V): Map[K, V] =
- b.foldLeft(a) { case (acc, (k, v)) =>
- acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
+ b.foldLeft(a) {
+ case (acc, (k, v)) =>
+ acc + (k -> acc.get(k).map(f(_, v)).getOrElse(v))
}
-
private def run(c: Circuit, dontTouchMap: Map[OfModule, Set[String]]): Circuit = {
val iGraph = InstanceKeyGraph(c)
val moduleDeps = iGraph.getChildInstanceMap
@@ -754,9 +779,11 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// are driven with the same constant value. Then, if we find a Module input where each instance
// is driven with the same constant (and not seen in a previous iteration), we iterate again
@tailrec
- def iterate(toVisit: Set[OfModule],
- modules: Map[OfModule, Module],
- constInputs: Map[OfModule, Map[String, Literal]]): Map[OfModule, DefModule] = {
+ def iterate(
+ toVisit: Set[OfModule],
+ modules: Map[OfModule, Module],
+ constInputs: Map[OfModule, Map[String, Literal]]
+ ): Map[OfModule, DefModule] = {
if (toVisit.isEmpty) modules
else {
// Order from leaf modules to root so that any module driving an output
@@ -767,31 +794,36 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
// Aggreagte Module outputs that are driven constant for use by instaniating Modules
// Aggregate submodule inputs driven constant for checking later
val (modulesx, _, constInputsx) =
- order.foldLeft((modules,
- Map[OfModule, Map[String, Literal]](),
- Map[OfModule, Map[String, Seq[Literal]]]())) {
+ order.foldLeft((modules, Map[OfModule, Map[String, Literal]](), Map[OfModule, Map[String, Seq[Literal]]]())) {
case ((mmap, constOutputs, constInputsAcc), mname) =>
val dontTouches = dontTouchMap.getOrElse(mname, Set.empty)
- val (mx, mco, mci) = constPropModule(modules(mname), dontTouches, moduleDeps(mname),
- constInputs.getOrElse(mname, Map.empty), constOutputs)
+ val (mx, mco, mci) = constPropModule(
+ modules(mname),
+ dontTouches,
+ moduleDeps(mname),
+ constInputs.getOrElse(mname, Map.empty),
+ constOutputs
+ )
// Accumulate all Literals used to drive a particular Module port
val constInputsx = unify(constInputsAcc, mci)((a, b) => unify(a, b)((c, d) => c ++ d))
(mmap + (mname -> mx), constOutputs + (mname -> mco), constInputsx)
}
// Determine which module inputs have all of the same, new constants driving them
- val newProppedInputs = constInputsx.flatMap { case (mname, ports) =>
- val portsx = ports.flatMap { case (pname, lits) =>
- val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false)
- val isModule = modules.contains(mname) // ExtModules are not contained in modules
- val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1
- if (isModule && newPort && allSameConst) Some(pname -> lits.head)
- else None
- }
- if (portsx.nonEmpty) Some(mname -> portsx) else None
+ val newProppedInputs = constInputsx.flatMap {
+ case (mname, ports) =>
+ val portsx = ports.flatMap {
+ case (pname, lits) =>
+ val newPort = !constInputs.get(mname).map(_.contains(pname)).getOrElse(false)
+ val isModule = modules.contains(mname) // ExtModules are not contained in modules
+ val allSameConst = lits.size == instCount(mname) && lits.toSet.size == 1
+ if (isModule && newPort && allSameConst) Some(pname -> lits.head)
+ else None
+ }
+ if (portsx.nonEmpty) Some(mname -> portsx) else None
}
val modsWithConstInputs = newProppedInputs.keySet
val newToVisit = modsWithConstInputs ++
- modsWithConstInputs.flatMap(parentGraph.reachableFrom)
+ modsWithConstInputs.flatMap(parentGraph.reachableFrom)
// Combine const inputs (there can't be duplicate values in the inner maps)
val nextConstInputs = unify(constInputs, newProppedInputs)((a, b) => a ++ b)
iterate(newToVisit.toSet, modulesx, nextConstInputs)
@@ -805,7 +837,6 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
c.modules.map(m => mmap.getOrElse(m.OfModule, m))
}
-
Circuit(c.info, modulesx, c.main)
}
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index c883bdfb..fb1bd1f6 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -1,4 +1,3 @@
-
package firrtl.transforms
import firrtl._
@@ -8,7 +7,7 @@ import firrtl.annotations._
import firrtl.graph._
import firrtl.analyses.InstanceKeyGraph
import firrtl.Mappers._
-import firrtl.Utils.{throwInternalError, kind}
+import firrtl.Utils.{kind, throwInternalError}
import firrtl.MemoizedHash._
import firrtl.options.{Dependency, RegisteredTransform, ShellOption}
@@ -29,29 +28,34 @@ import collection.mutable
* circumstances of their instantiation in their parent module, they will still not be removed. To
* remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication.
*/
-class DeadCodeElimination extends Transform
+class DeadCodeElimination
+ extends Transform
with ResolvedAnnotationPaths
with RegisteredTransform
with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats],
- Dependency(passes.CommonSubexpressionElimination) )
+ Seq(
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[firrtl.transforms.CombineCats],
+ Dependency(passes.CommonSubexpressionElimination)
+ )
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf =
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename],
- Dependency(passes.VerilogPrep),
- Dependency[firrtl.AddDescriptionNodes] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename],
+ Dependency(passes.VerilogPrep),
+ Dependency[firrtl.AddDescriptionNodes]
+ )
override def invalidates(a: Transform) = false
@@ -59,7 +63,9 @@ class DeadCodeElimination extends Transform
new ShellOption[Unit](
longOption = "no-dce",
toAnnotationSeq = (_: Unit) => Seq(NoDCEAnnotation),
- helpText = "Disable dead code elimination" ) )
+ helpText = "Disable dead code elimination"
+ )
+ )
/** Based on LogicNode ins CheckCombLoops, currently kind of faking it */
private type LogicNode = MemoizedHash[WrappedExpression]
@@ -72,6 +78,7 @@ class DeadCodeElimination extends Transform
val loweredName = LowerTypes.loweredName(component.name.split('.'))
apply(component.module.name, WRef(loweredName))
}
+
/** External Modules are representated as a single node driven by all inputs and driving all
* outputs
*/
@@ -87,7 +94,7 @@ class DeadCodeElimination extends Transform
def rec(e: Expression): Expression = {
e match {
case ref @ (_: WRef | _: WSubField) => refs += ref
- case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.map(rec)
case ignore @ (_: Literal) => // Do nothing
case unexpected => throwInternalError()
}
@@ -98,9 +105,7 @@ class DeadCodeElimination extends Transform
}
// Gets all dependencies and constructs LogicNodes from them
- private def getDepsImpl(mname: String,
- instMap: collection.Map[String, String])
- (expr: Expression): Seq[LogicNode] =
+ private def getDepsImpl(mname: String, instMap: collection.Map[String, String])(expr: Expression): Seq[LogicNode] =
extractRefs(expr).map { e =>
if (kind(e) == InstanceKind) {
val (inst, tail) = Utils.splitRef(e)
@@ -110,11 +115,12 @@ class DeadCodeElimination extends Transform
}
}
-
/** Construct the dependency graph within this module */
- private def setupDepGraph(depGraph: MutableDiGraph[LogicNode],
- instMap: collection.Map[String, String])
- (mod: Module): Unit = {
+ private def setupDepGraph(
+ depGraph: MutableDiGraph[LogicNode],
+ instMap: collection.Map[String, String]
+ )(mod: Module
+ ): Unit = {
def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
def onStmt(stmt: Statement): Unit = stmt match {
@@ -150,7 +156,7 @@ class DeadCodeElimination extends Transform
val node = getDeps(loc) match { case Seq(elt) => elt }
getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref))
// Simulation constructs are treated as top-level outputs
- case Stop(_,_, clk, en) =>
+ case Stop(_, _, clk, en) =>
Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
case Print(_, _, args, clk, en) =>
(args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
@@ -172,12 +178,14 @@ class DeadCodeElimination extends Transform
}
// TODO Make immutable?
- private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]],
- doTouchExtMods: Set[String],
- c: Circuit): MutableDiGraph[LogicNode] = {
+ private def createDependencyGraph(
+ instMaps: collection.Map[String, collection.Map[String, String]],
+ doTouchExtMods: Set[String],
+ c: Circuit
+ ): MutableDiGraph[LogicNode] = {
val depGraph = new MutableDiGraph[LogicNode]
c.modules.foreach {
- case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
+ case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
case ext: ExtModule =>
// Connect all inputs to all outputs
val node = LogicNode(ext)
@@ -205,23 +213,25 @@ class DeadCodeElimination extends Transform
depGraph
}
- private def deleteDeadCode(instMap: collection.Map[String, String],
- deadNodes: collection.Set[LogicNode],
- moduleMap: collection.Map[String, DefModule],
- renames: RenameMap,
- topName: String,
- doTouchExtMods: Set[String])
- (mod: DefModule): Option[DefModule] = {
+ private def deleteDeadCode(
+ instMap: collection.Map[String, String],
+ deadNodes: collection.Set[LogicNode],
+ moduleMap: collection.Map[String, DefModule],
+ renames: RenameMap,
+ topName: String,
+ doTouchExtMods: Set[String]
+ )(mod: DefModule
+ ): Option[DefModule] = {
// For log-level debug
def deleteMsg(decl: IsDeclaration): String = {
val tpe = decl match {
- case _: DefNode => "node"
+ case _: DefNode => "node"
case _: DefRegister => "reg"
- case _: DefWire => "wire"
- case _: Port => "port"
- case _: DefMemory => "mem"
+ case _: DefWire => "wire"
+ case _: Port => "port"
+ case _: DefMemory => "mem"
case (_: DefInstance | _: WDefInstance) => "inst"
- case _: Module => "module"
+ case _: Module => "module"
case _: ExtModule => "extmodule"
}
val ref = decl match {
@@ -237,7 +247,7 @@ class DeadCodeElimination extends Transform
def deleteIfNotEnabled(stmt: Statement, en: Expression): Statement = en match {
case UIntLiteral(v, _) if v == BigInt(0) => EmptyStmt
- case _ => stmt
+ case _ => stmt
}
def onStmt(stmt: Statement): Statement = {
@@ -256,12 +266,11 @@ class DeadCodeElimination extends Transform
logger.debug(deleteMsg(decl))
renames.delete(decl.name)
EmptyStmt
- }
- else decl
- case print: Print => deleteIfNotEnabled(print, print.en)
- case stop: Stop => deleteIfNotEnabled(stop, stop.en)
+ } else decl
+ case print: Print => deleteIfNotEnabled(print, print.en)
+ case stop: Stop => deleteIfNotEnabled(stop, stop.en)
case formal: Verification => deleteIfNotEnabled(formal, formal.en)
- case con: Connect =>
+ case con: Connect =>
val node = getDeps(con.loc) match { case Seq(elt) => elt }
if (deadNodes.contains(node)) EmptyStmt else con
case Attach(info, exprs) => // If any exprs are dead then all are
@@ -270,7 +279,7 @@ class DeadCodeElimination extends Transform
case IsInvalid(info, expr) =>
val node = getDeps(expr) match { case Seq(elt) => elt }
if (deadNodes.contains(node)) EmptyStmt else IsInvalid(info, expr)
- case block: Block => block map onStmt
+ case block: Block => block.map(onStmt)
case other => other
}
stmtx match { // Check if module empty
@@ -300,8 +309,7 @@ class DeadCodeElimination extends Transform
if (portsx.isEmpty && doTouchExtMods.contains(ext.name)) {
logger.debug(deleteMsg(mod))
None
- }
- else {
+ } else {
if (ext.ports != portsx) throwInternalError() // Sanity check
Some(ext.copy(ports = portsx))
}
@@ -309,14 +317,13 @@ class DeadCodeElimination extends Transform
}
- def run(state: CircuitState,
- dontTouches: Seq[LogicNode],
- doTouchExtMods: Set[String]): CircuitState = {
+ def run(state: CircuitState, dontTouches: Seq[LogicNode], doTouchExtMods: Set[String]): CircuitState = {
val c = state.circuit
val moduleMap = c.modules.map(m => m.name -> m).toMap
val iGraph = InstanceKeyGraph(c)
- val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) =>
- k.module -> v.map(i => i.name -> i.module).toMap
+ val moduleDeps = iGraph.graph.getEdgeMap.map({
+ case (k, v) =>
+ k.module -> v.map(i => i.name -> i.module).toMap
})
val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
@@ -347,11 +354,12 @@ class DeadCodeElimination extends Transform
// themselves. We iterate over the modules in a topological order from leaves to the top. The
// current status of the modulesxMap is used to either delete instances or update their types
val modulesxMap = mutable.HashMap.empty[String, DefModule]
- topoSortedModules.foreach { case mod =>
- deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match {
- case Some(m) => modulesxMap += m.name -> m
- case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
- }
+ topoSortedModules.foreach {
+ case mod =>
+ deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match {
+ case Some(m) => modulesxMap += m.name -> m
+ case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
+ }
}
// Preserve original module order
diff --git a/src/main/scala/firrtl/transforms/Dedup.scala b/src/main/scala/firrtl/transforms/Dedup.scala
index 627af11f..18e32cbc 100644
--- a/src/main/scala/firrtl/transforms/Dedup.scala
+++ b/src/main/scala/firrtl/transforms/Dedup.scala
@@ -20,7 +20,6 @@ import scala.annotation.tailrec
// Datastructures
import scala.collection.mutable
-
/** A component, e.g. register etc. Must be declared only once under the TopAnnotation */
case class NoDedupAnnotation(target: ModuleTarget) extends SingleTargetAnnotation[ModuleTarget] {
def duplicate(n: ModuleTarget): NoDedupAnnotation = NoDedupAnnotation(n)
@@ -36,7 +35,9 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt
new ShellOption[Unit](
longOption = "no-dedup",
toAnnotationSeq = _ => Seq(NoCircuitDedupAnnotation),
- helpText = "Do NOT dedup modules" ) )
+ helpText = "Do NOT dedup modules"
+ )
+ )
}
@@ -46,12 +47,13 @@ case object NoCircuitDedupAnnotation extends NoTargetAnnotation with HasShellOpt
* @param original Original module
* @param index the normalized position of the original module in the original module list, fraction between 0 and 1
*/
-case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double) extends MultiTargetAnnotation {
+case class DedupedResult(original: ModuleTarget, duplicate: Option[IsModule], index: Double)
+ extends MultiTargetAnnotation {
override val targets: Seq[Seq[Target]] = Seq(Seq(original), duplicate.toList)
override def duplicate(n: Seq[Seq[Target]]): Annotation = {
n.toList match {
case Seq(_, List(dup: IsModule)) => DedupedResult(original, Some(dup), index)
- case _ => DedupedResult(original, None, -1)
+ case _ => DedupedResult(original, None, -1)
}
}
}
@@ -96,7 +98,7 @@ class DedupModules extends Transform with DependencyAPIMigration {
val noDedups = state.circuit.main +: state.annotations.collect { case NoDedupAnnotation(ModuleTarget(_, m)) => m }
val (remainingAnnotations, dupResults) = state.annotations.partition {
case _: DupedResult => false
- case _ => true
+ case _ => true
}
val previouslyDupedMap = dupResults.flatMap {
case DupedResult(newModules, original) =>
@@ -114,9 +116,11 @@ class DedupModules extends Transform with DependencyAPIMigration {
* @param noDedups Modules not to dedup
* @return Deduped Circuit and corresponding RenameMap
*/
- def run(c: Circuit,
- noDedups: Seq[String],
- previouslyDupedMap: Map[String, String]): (Circuit, RenameMap, AnnotationSeq) = {
+ def run(
+ c: Circuit,
+ noDedups: Seq[String],
+ previouslyDupedMap: Map[String, String]
+ ): (Circuit, RenameMap, AnnotationSeq) = {
// RenameMap
val componentRenameMap = RenameMap()
@@ -124,13 +128,16 @@ class DedupModules extends Transform with DependencyAPIMigration {
// Maps module name to corresponding dedup module
val dedupMap = DedupModules.deduplicate(c, noDedups.toSet, previouslyDupedMap, componentRenameMap)
- val dedupCliques = dedupMap.foldLeft(Map.empty[String, Set[String]]) {
- case (dedupCliqueMap, (orig: String, dupMod: DefModule)) =>
- val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig
- dedupCliqueMap + (dupMod.name -> set)
- }.flatMap { case (dedupName, set) =>
- set.map { _ -> set }
- }
+ val dedupCliques = dedupMap
+ .foldLeft(Map.empty[String, Set[String]]) {
+ case (dedupCliqueMap, (orig: String, dupMod: DefModule)) =>
+ val set = dedupCliqueMap.getOrElse(dupMod.name, Set.empty[String]) + dupMod.name + orig
+ dedupCliqueMap + (dupMod.name -> set)
+ }
+ .flatMap {
+ case (dedupName, set) =>
+ set.map { _ -> set }
+ }
// Use old module list to preserve ordering
// Lookup what a module deduped to, if its a duplicate, remove it
@@ -149,9 +156,10 @@ class DedupModules extends Transform with DependencyAPIMigration {
val ct = CircuitTarget(c.main)
- val map = dedupMap.map { case (from, to) =>
- logger.debug(s"[Dedup] $from -> ${to.name}")
- ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name))
+ val map = dedupMap.map {
+ case (from, to) =>
+ logger.debug(s"[Dedup] $from -> ${to.name}")
+ ct.module(from).asInstanceOf[CompleteTarget] -> Seq(ct.module(to.name))
}
val moduleRenameMap = RenameMap()
moduleRenameMap.recordAll(map)
@@ -159,15 +167,19 @@ class DedupModules extends Transform with DependencyAPIMigration {
// Build instanceify renaming map
val instanceGraph = InstanceKeyGraph(c)
val instanceify = RenameMap()
- val moduleName2Index = c.modules.map(_.name).zipWithIndex.map { case (n, i) =>
- {
- c.modules.size match {
- case 0 => (n, 0.0)
- case 1 => (n, 1.0)
- case d => (n, i.toDouble / (d - 1))
+ val moduleName2Index = c.modules
+ .map(_.name)
+ .zipWithIndex
+ .map {
+ case (n, i) => {
+ c.modules.size match {
+ case 0 => (n, 0.0)
+ case 1 => (n, 1.0)
+ case d => (n, i.toDouble / (d - 1))
+ }
}
}
- }.toMap
+ .toMap
// get the ordered set of instances a module, includes new Deduped modules
val getChildrenInstances = {
@@ -182,56 +194,62 @@ class DedupModules extends Transform with DependencyAPIMigration {
}
val instanceNameMap: Map[OfModule, Map[Instance, Instance]] = {
- dedupMap.map { case (oldName, dedupedMod) =>
- val key = OfModule(oldName)
- val value = getChildrenInstances(oldName).zip(getChildrenInstances(dedupedMod.name)).map {
- case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name)
- }.toMap
- key -> value
+ dedupMap.map {
+ case (oldName, dedupedMod) =>
+ val key = OfModule(oldName)
+ val value = getChildrenInstances(oldName)
+ .zip(getChildrenInstances(dedupedMod.name))
+ .map {
+ case (oldInst, newInst) => Instance(oldInst.name) -> Instance(newInst.name)
+ }
+ .toMap
+ key -> value
}.toMap
}
- val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap { case mt@ModuleTarget(c, m) if dedupCliques(m).size > 1 =>
- dedupMap.get(m) match {
- case None => Nil
- case Some(module: DefModule) =>
- val paths = instanceGraph.findInstancesInHierarchy(m)
- // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option
- val newTargets = paths.map { path =>
- val root: IsModule = ct.module(c)
- path.foldLeft(root -> root) { case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) =>
- if(mod == c) {
- val mod = CircuitTarget(c).module(c)
- mod -> mod
- } else {
- val enclosingMod = oldRelPath match {
- case i: InstanceTarget => i.ofModule
- case m: ModuleTarget => m.module
- }
- val instMap = instanceNameMap(OfModule(enclosingMod))
- val newInstName = instMap(Instance(name)).value
- val old = oldRelPath.instOf(name, mod)
- old -> newRelPath.instOf(newInstName, mod)
+ val dedupAnnotations = c.modules.map(_.name).map(ct.module).flatMap {
+ case mt @ ModuleTarget(c, m) if dedupCliques(m).size > 1 =>
+ dedupMap.get(m) match {
+ case None => Nil
+ case Some(module: DefModule) =>
+ val paths = instanceGraph.findInstancesInHierarchy(m)
+ // If dedupedAnnos is exactly annos, contains is because dedupedAnnos is type Option
+ val newTargets = paths.map { path =>
+ val root: IsModule = ct.module(c)
+ path.foldLeft(root -> root) {
+ case ((oldRelPath, newRelPath), InstanceKeyGraph.InstanceKey(name, mod)) =>
+ if (mod == c) {
+ val mod = CircuitTarget(c).module(c)
+ mod -> mod
+ } else {
+ val enclosingMod = oldRelPath match {
+ case i: InstanceTarget => i.ofModule
+ case m: ModuleTarget => m.module
+ }
+ val instMap = instanceNameMap(OfModule(enclosingMod))
+ val newInstName = instMap(Instance(name)).value
+ val old = oldRelPath.instOf(name, mod)
+ old -> newRelPath.instOf(newInstName, mod)
+ }
}
}
- }
- // Add all relative paths to referredModule to map to new instances
- def addRecord(from: IsMember, to: IsMember): Unit = from match {
- case x: ModuleTarget =>
- instanceify.record(x, to)
- case x: IsComponent =>
- instanceify.record(x, to)
- addRecord(x.stripHierarchy(1), to)
- }
- // Instanceify deduped Modules!
- if (dedupCliques(module.name).size > 1) {
- newTargets.foreach { case (from, to) => addRecord(from, to) }
- }
- // Return Deduped Results
- if (newTargets.size == 1) {
- Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m)))
- } else Nil
- }
+ // Add all relative paths to referredModule to map to new instances
+ def addRecord(from: IsMember, to: IsMember): Unit = from match {
+ case x: ModuleTarget =>
+ instanceify.record(x, to)
+ case x: IsComponent =>
+ instanceify.record(x, to)
+ addRecord(x.stripHierarchy(1), to)
+ }
+ // Instanceify deduped Modules!
+ if (dedupCliques(module.name).size > 1) {
+ newTargets.foreach { case (from, to) => addRecord(from, to) }
+ }
+ // Return Deduped Results
+ if (newTargets.size == 1) {
+ Seq(DedupedResult(mt, newTargets.headOption.map(_._1), moduleName2Index(m)))
+ } else Nil
+ }
case noDedups => Nil
}
@@ -242,6 +260,7 @@ class DedupModules extends Transform with DependencyAPIMigration {
/** Utility functions for [[DedupModules]] */
object DedupModules extends LazyLogging {
+
/** Change's a module's internal signal names, types, infos, and modules.
* @param rename Function to rename a signal. Called on declaration and references.
* @param retype Function to retype a signal. Called on declaration, references, and subfields
@@ -250,14 +269,16 @@ object DedupModules extends LazyLogging {
* @param module Module to change internals
* @return Changed Module
*/
- def changeInternals(rename: String=>String,
- retype: String=>Type=>Type,
- reinfo: Info=>Info,
- renameOfModule: (String, String)=>String,
- renameExps: Boolean = true
- )(module: DefModule): DefModule = {
+ def changeInternals(
+ rename: String => String,
+ retype: String => Type => Type,
+ reinfo: Info => Info,
+ renameOfModule: (String, String) => String,
+ renameExps: Boolean = true
+ )(module: DefModule
+ ): DefModule = {
def onPort(p: Port): Port = Port(reinfo(p.info), rename(p.name), p.direction, retype(p.name)(p.tpe))
- def onExp(e: Expression): Expression = e match {
+ def onExp(e: Expression): Expression = e match {
case WRef(n, t, k, g) => WRef(rename(n), retype(n)(t), k, g)
case WSubField(expr, n, tpe, kind) =>
val fieldIndex = expr.tpe.asInstanceOf[BundleType].fields.indexWhere(f => f.name == n)
@@ -266,12 +287,12 @@ object DedupModules extends LazyLogging {
val finalExpr = WSubField(newExpr, newField.name, newField.tpe, kind)
//TODO: renameMap.rename(e.serialize, finalExpr.serialize)
finalExpr
- case other => other map onExp
+ case other => other.map(onExp)
}
def onStmt(s: Statement): Statement = s match {
case DefNode(info, name, value) =>
retype(name)(value.tpe)
- if(renameExps) DefNode(reinfo(info), rename(name), onExp(value))
+ if (renameExps) DefNode(reinfo(info), rename(name), onExp(value))
else DefNode(reinfo(info), rename(name), value)
case WDefInstance(i, n, m, t) =>
val newmod = renameOfModule(n, m)
@@ -283,12 +304,18 @@ object DedupModules extends LazyLogging {
val oldType = MemPortUtils.memType(d)
val newType = retype(d.name)(oldType)
val index = oldType
- .asInstanceOf[BundleType].fields.headOption
- .map(_.tpe.asInstanceOf[BundleType].fields.indexWhere(
- {
- case Field("data" | "wdata" | "rdata", _, _) => true
- case _ => false
- }))
+ .asInstanceOf[BundleType]
+ .fields
+ .headOption
+ .map(
+ _.tpe
+ .asInstanceOf[BundleType]
+ .fields
+ .indexWhere({
+ case Field("data" | "wdata" | "rdata", _, _) => true
+ case _ => false
+ })
+ )
val newDataType = index match {
case Some(i) =>
//If index nonempty, then there exists a port
@@ -299,15 +326,15 @@ object DedupModules extends LazyLogging {
// associate it with the type of the memory (as the memory type is different than the datatype)
retype(d.name + ";&*^$")(d.dataType)
}
- d.copy(dataType = newDataType) map rename map reinfo
+ d.copy(dataType = newDataType).map(rename).map(reinfo)
case h: IsDeclaration =>
- val temp = h map rename map retype(h.name) map reinfo
- if(renameExps) temp map onExp else temp
+ val temp = h.map(rename).map(retype(h.name)).map(reinfo)
+ if (renameExps) temp.map(onExp) else temp
case other =>
- val temp = other map reinfo map onStmt
- if(renameExps) temp map onExp else temp
+ val temp = other.map(reinfo).map(onStmt)
+ if (renameExps) temp.map(onExp) else temp
}
- module map onPort map onStmt
+ module.map(onPort).map(onStmt)
}
/** Dedup a module's instances based on dedup map
@@ -321,11 +348,13 @@ object DedupModules extends LazyLogging {
* @param renameMap Will be modified to keep track of renames in this function
* @return fixed up module deduped instances
*/
- def dedupInstances(top: CircuitTarget,
- originalModule: String,
- moduleMap: Map[String, DefModule],
- name2name: Map[String, String],
- renameMap: RenameMap): DefModule = {
+ def dedupInstances(
+ top: CircuitTarget,
+ originalModule: String,
+ moduleMap: Map[String, DefModule],
+ name2name: Map[String, String],
+ renameMap: RenameMap
+ ): DefModule = {
val module = moduleMap(originalModule)
// If black box, return it (it has no instances)
@@ -340,7 +369,8 @@ object DedupModules extends LazyLogging {
}
val typeMap = mutable.HashMap[String, Type]()
def retype(name: String)(tpe: Type): Type = {
- if (typeMap.contains(name)) typeMap(name) else {
+ if (typeMap.contains(name)) typeMap(name)
+ else {
if (instanceModuleMap.contains(name)) {
val newType = Utils.module_type(getNewModule(instanceModuleMap(name)))
typeMap(name) = newType
@@ -360,7 +390,7 @@ object DedupModules extends LazyLogging {
def renameOfModule(instance: String, ofModule: String): String = {
name2name(ofModule)
}
- changeInternals({n => n}, retype, {i => i}, renameOfModule)(module)
+ changeInternals({ n => n }, retype, { i => i }, renameOfModule)(module)
}
@tailrec
@@ -415,10 +445,11 @@ object DedupModules extends LazyLogging {
* @return A map from tag to names of modules with the same structure and
* a RenameMap which maps Module names to their Tag.
*/
- def buildRTLTags(top: CircuitTarget,
- moduleLinearization: Seq[DefModule],
- noDedups: Set[String]
- ): (collection.Map[String, collection.Set[String]], RenameMap) = {
+ def buildRTLTags(
+ top: CircuitTarget,
+ moduleLinearization: Seq[DefModule],
+ noDedups: Set[String]
+ ): (collection.Map[String, collection.Set[String]], RenameMap) = {
// maps hash code to human readable tag
val hashToTag = mutable.HashMap[ir.HashCode, String]()
@@ -449,9 +480,9 @@ object DedupModules extends LazyLogging {
moduleNameToTag(originalModule.name) = hashToTag(hash)
}
- val tag2all = hashToNames.map{ case (hash, names) => hashToTag(hash) -> names.toSet }
+ val tag2all = hashToNames.map { case (hash, names) => hashToTag(hash) -> names.toSet }
val tagMap = RenameMap()
- moduleNameToTag.foreach{ case (name, tag) => tagMap.record(top.module(name), top.module(tag)) }
+ moduleNameToTag.foreach { case (name, tag) => tagMap.record(top.module(name), top.module(tag)) }
(tag2all, tagMap)
}
@@ -461,10 +492,12 @@ object DedupModules extends LazyLogging {
* @param renameMap rename map to populate when deduping
* @return Map of original Module name -> Deduped Module
*/
- def deduplicate(circuit: Circuit,
- noDedups: Set[String],
- previousDupResults: Map[String, String],
- renameMap: RenameMap): Map[String, DefModule] = {
+ def deduplicate(
+ circuit: Circuit,
+ noDedups: Set[String],
+ previousDupResults: Map[String, String],
+ renameMap: RenameMap
+ ): Map[String, DefModule] = {
val (moduleMap, moduleLinearization) = {
val iGraph = InstanceKeyGraph(circuit)
@@ -479,13 +512,14 @@ object DedupModules extends LazyLogging {
val (tag2all, tagMap) = buildRTLTags(top, moduleLinearization, noDedups)
// Set tag2name to be the best dedup module name
- val moduleIndex = circuit.modules.zipWithIndex.map{case (m, i) => m.name -> i}.toMap
+ val moduleIndex = circuit.modules.zipWithIndex.map { case (m, i) => m.name -> i }.toMap
// returns the module matching the circuit name or the module with lower index otherwise
def order(l: String, r: String): String = {
if (l == main) l
else if (r == main) r
- else if (moduleIndex(l) < moduleIndex(r)) l else r
+ else if (moduleIndex(l) < moduleIndex(r)) l
+ else r
}
// Maps a module's tag to its deduplicated module
@@ -499,7 +533,7 @@ object DedupModules extends LazyLogging {
tag2name(tag) = dedupName
val dedupModule = moduleMap(dedupWithoutOldName) match {
case e: ExtModule => e.copy(name = dedupName)
- case e: Module => e.copy(name = dedupName)
+ case e: Module => e.copy(name = dedupName)
}
dedupName -> dedupModule
}.toMap
@@ -508,32 +542,32 @@ object DedupModules extends LazyLogging {
val name2name = moduleMap.keysIterator.map { originalModule =>
tagMap.get(top.module(originalModule)) match {
case Some(Seq(Target(_, Some(tag), Nil))) => originalModule -> tag2name(tag)
- case None => originalModule -> originalModule
- case other => throwInternalError(other.toString)
+ case None => originalModule -> originalModule
+ case other => throwInternalError(other.toString)
}
}.toMap
// Build Remap for modules with deduped module references
val dedupedName2module = tag2name.map {
- case (tag, name) => name -> DedupModules.dedupInstances(
- top, name, moduleMapWithOldNames, name2name, renameMap)
+ case (tag, name) => name -> DedupModules.dedupInstances(top, name, moduleMapWithOldNames, name2name, renameMap)
}
// Build map from original name to corresponding deduped module
// It is important to flatMap before looking up the DefModules so that they aren't hashed
val name2module: Map[String, DefModule] =
tag2all.flatMap { case (tag, names) => names.map(_ -> tag) }
- .mapValues(tag => dedupedName2module(tag2name(tag)))
- .toMap
+ .mapValues(tag => dedupedName2module(tag2name(tag)))
+ .toMap
// Build renameMap
val indexedTargets = mutable.HashMap[String, IndexedSeq[ReferenceTarget]]()
- name2module.foreach { case (originalName, depModule) =>
- if(originalName != depModule.name) {
- val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule))
- val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName))
- computeRenameMap(fromSeq, toSeq, renameMap)
- }
+ name2module.foreach {
+ case (originalName, depModule) =>
+ if (originalName != depModule.name) {
+ val toSeq = indexedTargets.getOrElseUpdate(depModule.name, computeIndexedNames(circuit.main, depModule))
+ val fromSeq = computeIndexedNames(circuit.main, moduleMap(originalName))
+ computeRenameMap(fromSeq, toSeq, renameMap)
+ }
}
name2module
@@ -549,18 +583,21 @@ object DedupModules extends LazyLogging {
tpe
}
- changeInternals(rename, retype, {i => i}, {(x, y) => x}, renameExps = false)(m)
+ changeInternals(rename, retype, { i => i }, { (x, y) => x }, renameExps = false)(m)
refs.toIndexedSeq
}
- def computeRenameMap(originalNames: IndexedSeq[ReferenceTarget],
- dedupedNames: IndexedSeq[ReferenceTarget],
- renameMap: RenameMap): Unit = {
+ def computeRenameMap(
+ originalNames: IndexedSeq[ReferenceTarget],
+ dedupedNames: IndexedSeq[ReferenceTarget],
+ renameMap: RenameMap
+ ): Unit = {
originalNames.zip(dedupedNames).foreach {
- case (o, d) => if (o.component != d.component || o.ref != d.ref) {
- renameMap.record(o, d.copy(module = o.module))
- }
+ case (o, d) =>
+ if (o.component != d.component || o.ref != d.ref) {
+ renameMap.record(o, d.copy(module = o.module))
+ }
}
}
diff --git a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala
index a1e49d62..bfab31bf 100644
--- a/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala
+++ b/src/main/scala/firrtl/transforms/FixAddingNegativeLiteralsTransform.scala
@@ -33,7 +33,7 @@ object FixAddingNegativeLiterals {
*/
def fixupModule(m: DefModule): DefModule = {
val namespace = Namespace(m)
- m map fixupStatement(namespace)
+ m.map(fixupStatement(namespace))
}
/** Returns a statement with fixed additions of negative literals
@@ -43,8 +43,8 @@ object FixAddingNegativeLiterals {
*/
def fixupStatement(namespace: Namespace)(s: Statement): Statement = {
val stmtBuffer = mutable.ListBuffer[Statement]()
- val ret = s map fixupStatement(namespace) map fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer)
- if(stmtBuffer.isEmpty) {
+ val ret = s.map(fixupStatement(namespace)).map(fixupOnExpr(Utils.get_info(s), namespace, stmtBuffer))
+ if (stmtBuffer.isEmpty) {
ret
} else {
stmtBuffer += ret
@@ -58,8 +58,7 @@ object FixAddingNegativeLiterals {
* @param e expression to fixup
* @return generated statements and the fixed expression
*/
- def fixupExpression(info: Info, namespace: Namespace)
- (e: Expression): (Seq[Statement], Expression) = {
+ def fixupExpression(info: Info, namespace: Namespace)(e: Expression): (Seq[Statement], Expression) = {
val stmtBuffer = mutable.ListBuffer[Statement]()
val retExpr = fixupOnExpr(info, namespace, stmtBuffer)(e)
(stmtBuffer.toList, retExpr)
@@ -72,12 +71,16 @@ object FixAddingNegativeLiterals {
* @param e expression to fixup
* @return fixed expression
*/
- private def fixupOnExpr(info: Info, namespace: Namespace, stmtBuffer: mutable.ListBuffer[Statement])
- (e: Expression): Expression = {
+ private def fixupOnExpr(
+ info: Info,
+ namespace: Namespace,
+ stmtBuffer: mutable.ListBuffer[Statement]
+ )(e: Expression
+ ): Expression = {
// Helper function to create the subtraction expression
def fixupAdd(expr: Expression, litValue: BigInt, litWidth: BigInt): DoPrim = {
- if(litValue == minNegValue(litWidth)) {
+ if (litValue == minNegValue(litWidth)) {
val posLiteral = SIntLiteral(-litValue)
assert(posLiteral.width.asInstanceOf[IntWidth].width - 1 == litWidth)
val sub = DefNode(info, namespace.newTemp, setType(DoPrim(Sub, Seq(expr, posLiteral), Nil, UnknownType)))
@@ -91,10 +94,10 @@ object FixAddingNegativeLiterals {
}
}
- e map fixupOnExpr(info, namespace, stmtBuffer) match {
- case DoPrim(Add, Seq(arg, lit@SIntLiteral(value, w@IntWidth(width))), Nil, t: SIntType) if value < 0 =>
+ e.map(fixupOnExpr(info, namespace, stmtBuffer)) match {
+ case DoPrim(Add, Seq(arg, lit @ SIntLiteral(value, w @ IntWidth(width))), Nil, t: SIntType) if value < 0 =>
fixupAdd(arg, value, width)
- case DoPrim(Add, Seq(lit@SIntLiteral(value, w@IntWidth(width)), arg), Nil, t: SIntType) if value < 0 =>
+ case DoPrim(Add, Seq(lit @ SIntLiteral(value, w @ IntWidth(width)), arg), Nil, t: SIntType) if value < 0 =>
fixupAdd(arg, value, width)
case other => other
}
diff --git a/src/main/scala/firrtl/transforms/Flatten.scala b/src/main/scala/firrtl/transforms/Flatten.scala
index cc5b3504..36e71470 100644
--- a/src/main/scala/firrtl/transforms/Flatten.scala
+++ b/src/main/scala/firrtl/transforms/Flatten.scala
@@ -7,7 +7,7 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.annotations._
import scala.collection.mutable
-import firrtl.passes.{InlineInstances,PassException}
+import firrtl.passes.{InlineInstances, PassException}
import firrtl.stage.Forms
/** Tags an annotation to be consumed by this transform */
@@ -25,101 +25,114 @@ case class FlattenAnnotation(target: Named) extends SingleTargetAnnotation[Named
*/
class Flatten extends Transform with DependencyAPIMigration {
- override def prerequisites = Forms.LowForm
- override def optionalPrerequisites = Seq.empty
- override def optionalPrerequisiteOf = Forms.LowEmitters
+ override def prerequisites = Forms.LowForm
+ override def optionalPrerequisites = Seq.empty
+ override def optionalPrerequisiteOf = Forms.LowEmitters
override def invalidates(a: Transform) = false
- val inlineTransform = new InlineInstances
-
- private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
- anns.foldLeft( (Set.empty[ModuleName], Set.empty[ComponentName]) ) {
- case ((modNames, instNames), ann) => ann match {
- case FlattenAnnotation(CircuitName(c)) =>
- (circuit.modules.collect {
- case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
- }.toSet, instNames)
- case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
- case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
- case _ => throw new PassException("Annotation must be a FlattenAnnotation")
- }
- }
-
- /**
+ val inlineTransform = new InlineInstances
+
+ private def collectAnns(circuit: Circuit, anns: Iterable[Annotation]): (Set[ModuleName], Set[ComponentName]) =
+ anns.foldLeft((Set.empty[ModuleName], Set.empty[ComponentName])) {
+ case ((modNames, instNames), ann) =>
+ ann match {
+ case FlattenAnnotation(CircuitName(c)) =>
+ (
+ circuit.modules.collect {
+ case Module(_, name, _, _) if name != circuit.main => ModuleName(name, CircuitName(c))
+ }.toSet,
+ instNames
+ )
+ case FlattenAnnotation(ModuleName(mod, cir)) => (modNames + ModuleName(mod, cir), instNames)
+ case FlattenAnnotation(ComponentName(com, mod)) => (modNames, instNames + ComponentName(com, mod))
+ case _ => throw new PassException("Annotation must be a FlattenAnnotation")
+ }
+ }
+
+ /**
* Modifies the circuit by replicating the hierarchy under the annotated objects (mods and insts) and
* by rewriting the original circuit to refer to the new modules that will be inlined later.
* @return modified circuit and ModuleNames to inline
*/
- def duplicateSubCircuitsFromAnno(c: Circuit, mods: Set[ModuleName], insts: Set[ComponentName]): (Circuit, Set[ModuleName]) = {
- val modMap = c.modules.map(m => m.name->m).toMap
- val seedMods = mutable.Map.empty[String, String]
- val newModDefs = mutable.Set.empty[DefModule]
- val nsp = Namespace(c)
-
- /**
+ def duplicateSubCircuitsFromAnno(
+ c: Circuit,
+ mods: Set[ModuleName],
+ insts: Set[ComponentName]
+ ): (Circuit, Set[ModuleName]) = {
+ val modMap = c.modules.map(m => m.name -> m).toMap
+ val seedMods = mutable.Map.empty[String, String]
+ val newModDefs = mutable.Set.empty[DefModule]
+ val nsp = Namespace(c)
+
+ /**
* We start with rewriting DefInstances in the modules with annotations to refer to replicated modules to be created later.
* It populates seedMods where we capture the mapping between the original module name of the instances came from annotation
* to a new module name that we will create as a replica of the original one.
* Note: We replace old modules with it replicas so that other instances of the same module can be left unchanged.
*/
- def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
- case _: Block => x map rewriteMod(parent)
- case WDefInstance(info, instName, moduleName, instTpe) =>
- if (insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main))))
- || mods.contains(ModuleName(parent.name, CircuitName(c.main)))) {
- val newModName = if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
- seedMods += moduleName -> newModName
- WDefInstance(info, instName, newModName, instTpe)
- } else x
- case _ => x
- }
-
- val modifMods = c.modules map { m => m map rewriteMod(m) }
-
- /**
+ def rewriteMod(parent: DefModule)(x: Statement): Statement = x match {
+ case _: Block => x.map(rewriteMod(parent))
+ case WDefInstance(info, instName, moduleName, instTpe) =>
+ if (
+ insts.contains(ComponentName(instName, ModuleName(parent.name, CircuitName(c.main))))
+ || mods.contains(ModuleName(parent.name, CircuitName(c.main)))
+ ) {
+ val newModName =
+ if (seedMods.contains(moduleName)) seedMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN")
+ seedMods += moduleName -> newModName
+ WDefInstance(info, instName, newModName, instTpe)
+ } else x
+ case _ => x
+ }
+
+ val modifMods = c.modules.map { m => m.map(rewriteMod(m)) }
+
+ /**
* Recursively rewrites modules in the hierarchy starting with modules in seedMods (originally annotations).
* Populates newModDefs, which are replicated modules used in the subcircuit that we create
* by recursively traversing modules captured inside seedMods and replicating them
*/
- def recDupMods(mods: Map[String, String]): Unit = {
- val replMods = mutable.Map.empty[String, String]
-
- def dupMod(x: Statement): Statement = x match {
- case _: Block => x map dupMod
- case WDefInstance(info, instName, moduleName, instTpe) => modMap(moduleName) match {
- case m: Module =>
- val newModName = if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName+"_TO_FLATTEN")
- replMods += moduleName -> newModName
- WDefInstance(info, instName, newModName, instTpe)
- case _ => x // Ignore extmodules
- }
- case _ => x
- }
-
- def dupName(name: String): String = mods(name)
- val newMods = mods map { case (origName, newName) => modMap(origName) map dupMod map dupName }
-
- newModDefs ++= newMods
-
- if(replMods.size > 0) recDupMods(replMods.toMap)
-
- }
- recDupMods(seedMods.toMap)
-
- //convert newly created modules to ModuleName for inlining next (outside this function)
- val modsToInline = newModDefs map { m => ModuleName(m.name, CircuitName(c.main)) }
- (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
- }
-
- override def execute(state: CircuitState): CircuitState = {
- val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
- annos match {
- case Nil => state
- case myAnnotations =>
- val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
- // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
- val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
- inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations)
- }
- }
+ def recDupMods(mods: Map[String, String]): Unit = {
+ val replMods = mutable.Map.empty[String, String]
+
+ def dupMod(x: Statement): Statement = x match {
+ case _: Block => x.map(dupMod)
+ case WDefInstance(info, instName, moduleName, instTpe) =>
+ modMap(moduleName) match {
+ case m: Module =>
+ val newModName =
+ if (replMods.contains(moduleName)) replMods(moduleName) else nsp.newName(moduleName + "_TO_FLATTEN")
+ replMods += moduleName -> newModName
+ WDefInstance(info, instName, newModName, instTpe)
+ case _ => x // Ignore extmodules
+ }
+ case _ => x
+ }
+
+ def dupName(name: String): String = mods(name)
+ val newMods = mods.map { case (origName, newName) => modMap(origName).map(dupMod).map(dupName) }
+
+ newModDefs ++= newMods
+
+ if (replMods.size > 0) recDupMods(replMods.toMap)
+
+ }
+ recDupMods(seedMods.toMap)
+
+ //convert newly created modules to ModuleName for inlining next (outside this function)
+ val modsToInline = newModDefs.map { m => ModuleName(m.name, CircuitName(c.main)) }
+ (c.copy(modules = modifMods ++ newModDefs), modsToInline.toSet)
+ }
+
+ override def execute(state: CircuitState): CircuitState = {
+ val annos = state.annotations.collect { case a @ FlattenAnnotation(_) => a }
+ annos match {
+ case Nil => state
+ case myAnnotations =>
+ val (modNames, instNames) = collectAnns(state.circuit, myAnnotations)
+ // take incoming annotation and produce annotations for InlineInstances, i.e. traverse circuit down to find all instances to inline
+ val (newc, modsToInline) = duplicateSubCircuitsFromAnno(state.circuit, modNames, instNames)
+ inlineTransform.run(newc, modsToInline.toSet, Set.empty[ComponentName], state.annotations)
+ }
+ }
}
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index a2399b5a..b582fe2a 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -119,7 +119,7 @@ object FlattenRegUpdate {
def rec(e: Expression): (Info, Expression) = {
val (info, expr) = kind(e) match {
case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e))
- case _ => unwrap(e)
+ case _ => unwrap(e)
}
expr match {
case Mux(cond, tval, fval, tpe) =>
@@ -128,16 +128,18 @@ object FlattenRegUpdate {
val infox = combineInfos(info, tinfo, finfo)
(infox, Mux(cond, tvalx, fvalx, tpe))
// Return the original expression to end flattening
- case _ => unwrap(e)
+ case _ => unwrap(e)
}
}
rec(start)
}
def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
- case reg @ DefRegister(_, rname, _,_, resetCond, _) =>
- assert(resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
- "Synchronous reset should have already been made explicit!")
+ case reg @ DefRegister(_, rname, _, _, resetCond, _) =>
+ assert(
+ resetCond.tpe == AsyncResetType || resetCond == Utils.zero,
+ "Synchronous reset should have already been made explicit!"
+ )
val ref = WRef(reg)
val (info, rhs) = constructRegUpdate(netlist.getOrElse(ref, ref))
val update = Connect(info, ref, rhs)
@@ -145,7 +147,7 @@ object FlattenRegUpdate {
reg
// Remove connections to Registers so we preserve LowFirrtl single-connection semantics
case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt
- case other => other
+ case other => other
}
val bodyx = onStmt(mod.body)
@@ -163,12 +165,14 @@ object FlattenRegUpdate {
class FlattenRegUpdate extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[InlineCastsTransform],
- Dependency[LegalizeClocksTransform] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[InlineCastsTransform],
+ Dependency[LegalizeClocksTransform]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -181,7 +185,7 @@ class FlattenRegUpdate extends Transform with DependencyAPIMigration {
def execute(state: CircuitState): CircuitState = {
val modulesx = state.circuit.modules.map {
- case mod: Module => FlattenRegUpdate.flattenReg(mod)
+ case mod: Module => FlattenRegUpdate.flattenReg(mod)
case ext: ExtModule => ext
}
state.copy(circuit = state.circuit.copy(modules = modulesx))
diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala
index 166feba0..0db67f1e 100644
--- a/src/main/scala/firrtl/transforms/GroupComponents.scala
+++ b/src/main/scala/firrtl/transforms/GroupComponents.scala
@@ -10,7 +10,6 @@ import firrtl.stage.Forms
import scala.collection.mutable
-
/**
* Specifies a group of components, within a module, to pull out into their own module
* Components that are only connected to a group's components will also be included
@@ -21,8 +20,14 @@ import scala.collection.mutable
* @param outputSuffix suggested suffix of any output ports of the new module
* @param inputSuffix suggested suffix of any input ports of the new module
*/
-case class GroupAnnotation(components: Seq[ComponentName], newModule: String, newInstance: String, outputSuffix: Option[String] = None, inputSuffix: Option[String] = None) extends Annotation {
- if(components.nonEmpty) {
+case class GroupAnnotation(
+ components: Seq[ComponentName],
+ newModule: String,
+ newInstance: String,
+ outputSuffix: Option[String] = None,
+ inputSuffix: Option[String] = None)
+ extends Annotation {
+ if (components.nonEmpty) {
require(components.forall(_.module == components.head.module), "All components must be in the same module.")
require(components.forall(!_.name.contains('.')), "No components can be a subcomponent.")
}
@@ -35,7 +40,7 @@ case class GroupAnnotation(components: Seq[ComponentName], newModule: String, ne
/* Only keeps components renamed to components */
def update(renames: RenameMap): Seq[Annotation] = {
- val newComponents = components.flatMap{c => renames.get(c).getOrElse(Seq(c))}.collect {
+ val newComponents = components.flatMap { c => renames.get(c).getOrElse(Seq(c)) }.collect {
case c: ComponentName => c
}
Seq(GroupAnnotation(newComponents, newModule, newInstance, outputSuffix, inputSuffix))
@@ -58,7 +63,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
}
override def execute(state: CircuitState): CircuitState = {
- val groups = state.annotations.collect {case g: GroupAnnotation => g}
+ val groups = state.annotations.collect { case g: GroupAnnotation => g }
val module2group = groups.groupBy(_.currentModule)
val mnamespace = Namespace(state.circuit)
val newModules = state.circuit.modules.flatMap {
@@ -74,13 +79,12 @@ class GroupComponents extends Transform with DependencyAPIMigration {
val namespace = Namespace(m)
val groupRoots = groups.map(_.components.map(_.name))
val totalSum = groupRoots.map(_.size).sum
- val union = groupRoots.foldLeft(Set.empty[String]){(all, set) => all.union(set.toSet)}
+ val union = groupRoots.foldLeft(Set.empty[String]) { (all, set) => all.union(set.toSet) }
- require(groupRoots.forall{_.forall{namespace.contains}}, "All names should be in this module")
+ require(groupRoots.forall { _.forall { namespace.contains } }, "All names should be in this module")
require(totalSum == union.size, "No name can be in more than one group")
require(groupRoots.forall(_.nonEmpty), "All groupRoots must by non-empty")
-
// Order of groups, according to their label. The label is the first root in the group
val labelOrder = groups.collect({ case g: GroupAnnotation => g.components.head.name })
@@ -90,8 +94,8 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Group roots, by label
// The label "" indicates the original module, and components belonging to that group will remain
// in the original module (not get moved into a new module)
- val label2group: Map[String, MSet[String]] = groups.collect{
- case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name):_*)
+ val label2group: Map[String, MSet[String]] = groups.collect {
+ case GroupAnnotation(set, module, instance, _, _) => set.head.name -> mutable.Set(set.map(_.name): _*)
}.toMap + ("" -> mutable.Set(""))
// Name of new module containing each group, by label
@@ -105,7 +109,6 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Build set of components not in set
val notSet = label2group.map { case (key, value) => key -> union.diff(value) }
-
// Get all dependencies between components
val deps = getComponentConnectivity(m)
@@ -114,13 +117,14 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// For each group (by label), add connectivity between nodes in set
// Populate reachableNodes with reachability, where blacklist is their notSet
- label2group.foreach { case (label, set) =>
- set.foreach { x =>
- deps.addPairWithEdge(label, x)
- }
- deps.reachableFrom(label, notSet(label)) foreach { node =>
- reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label
- }
+ label2group.foreach {
+ case (label, set) =>
+ set.foreach { x =>
+ deps.addPairWithEdge(label, x)
+ }
+ deps.reachableFrom(label, notSet(label)).foreach { node =>
+ reachableNodes.getOrElseUpdate(node, mutable.Set.empty[String]) += label
+ }
}
// Unused nodes are not reachable from any group nor the root--add them to root group
@@ -129,12 +133,13 @@ class GroupComponents extends Transform with DependencyAPIMigration {
}
// Add nodes who are reached by a single group, to that group
- reachableNodes.foreach { case (node, membership) =>
- if(membership.size == 1) {
- label2group(membership.head) += node
- } else {
- label2group("") += node
- }
+ reachableNodes.foreach {
+ case (node, membership) =>
+ if (membership.size == 1) {
+ label2group(membership.head) += node
+ } else {
+ label2group("") += node
+ }
}
applyGrouping(m, labelOrder, label2group, label2module, label2instance, label2annotation)
@@ -150,19 +155,21 @@ class GroupComponents extends Transform with DependencyAPIMigration {
* @param label2annotation annotation specifying the group, by label
* @return new modules, including each group's module and the new split module
*/
- def applyGrouping( m: Module,
- labelOrder: Seq[String],
- label2group: Map[String, MSet[String]],
- label2module: Map[String, String],
- label2instance: Map[String, String],
- label2annotation: Map[String, GroupAnnotation]
- ): Seq[Module] = {
+ def applyGrouping(
+ m: Module,
+ labelOrder: Seq[String],
+ label2group: Map[String, MSet[String]],
+ label2module: Map[String, String],
+ label2instance: Map[String, String],
+ label2annotation: Map[String, GroupAnnotation]
+ ): Seq[Module] = {
// Maps node to group
val byNode = mutable.HashMap[String, String]()
- label2group.foreach { case (group, nodes) =>
- nodes.foreach { node =>
- byNode(node) = group
- }
+ label2group.foreach {
+ case (group, nodes) =>
+ nodes.foreach { node =>
+ byNode(node) = group
+ }
}
val groupNamespace = label2group.map { case (head, set) => head -> Namespace(set.toSeq) }
@@ -180,7 +187,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
val portNames = groupPortNames(group)
val suffix = d match {
case Output => label2annotation(group).outputSuffix.getOrElse("")
- case Input => label2annotation(group).inputSuffix.getOrElse("")
+ case Input => label2annotation(group).inputSuffix.getOrElse("")
}
val newName = groupNamespace(group).newName(source + suffix)
val portName = portNames.getOrElseUpdate(source, newName)
@@ -192,7 +199,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
val portName = addPort(group, exp, Output)
val connectStatement = exp.tpe match {
case AnalogType(_) => Attach(NoInfo, Seq(WRef(portName), exp))
- case _ => Connect(NoInfo, WRef(portName), exp)
+ case _ => Connect(NoInfo, WRef(portName), exp)
}
groupStatements(group) += connectStatement
portName
@@ -201,7 +208,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Given the sink is in a group, tidy up source references
def inGroupFixExps(group: String, added: mutable.ArrayBuffer[Statement])(e: Expression): Expression = e match {
case _: Literal => e
- case _: DoPrim | _: Mux | _: ValidIf => e map inGroupFixExps(group, added)
+ case _: DoPrim | _: Mux | _: ValidIf => e.map(inGroupFixExps(group, added))
case otherExp: Expression =>
val wref = getWRef(otherExp)
val source = wref.name
@@ -238,10 +245,10 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// Given the sink is in the parent module, tidy up source references belonging to groups
def inTopFixExps(e: Expression): Expression = e match {
- case _: DoPrim | _: Mux | _: ValidIf => e map inTopFixExps
+ case _: DoPrim | _: Mux | _: ValidIf => e.map(inTopFixExps)
case otherExp: Expression =>
val wref = getWRef(otherExp)
- if(byNode(wref.name) != "") {
+ if (byNode(wref.name) != "") {
// Get the name of source's group
val otherGroup = byNode(wref.name)
@@ -260,7 +267,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
case r: IsDeclaration if byNode(r.name) != "" =>
val topStmts = mutable.ArrayBuffer[Statement]()
val group = byNode(r.name)
- groupStatements(group) += r mapExpr inGroupFixExps(group, topStmts)
+ groupStatements(group) += r.mapExpr(inGroupFixExps(group, topStmts))
Block(topStmts.toSeq)
case c: Connect if byNode(getWRef(c.loc).name) != "" =>
// Sink is in a group
@@ -276,20 +283,26 @@ class GroupComponents extends Transform with DependencyAPIMigration {
// TODO Attach if all are in a group?
case _: IsDeclaration | _: Connect | _: Attach =>
// Sink is in Top
- val ret = s mapExpr inTopFixExps
+ val ret = s.mapExpr(inTopFixExps)
ret
- case other => other map onStmt
+ case other => other.map(onStmt)
}
}
-
// Build datastructures
- val newTopBody = Block(labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body)))
+ val newTopBody = Block(
+ labelOrder.map(g => WDefInstance(NoInfo, label2instance(g), label2module(g), UnknownType)) ++ Seq(onStmt(m.body))
+ )
val finalTopBody = Block(Utils.squashEmpty(newTopBody).asInstanceOf[Block].stmts.distinct)
// For all group labels (not including the original module label), return a new Module.
- val newModules = labelOrder.filter(_ != "") map { group =>
- Module(NoInfo, label2module(group), groupPorts(group).distinct.toSeq, Block(groupStatements(group).distinct.toSeq))
+ val newModules = labelOrder.filter(_ != "").map { group =>
+ Module(
+ NoInfo,
+ label2module(group),
+ groupPorts(group).distinct.toSeq,
+ Block(groupStatements(group).distinct.toSeq)
+ )
}
Seq(m.copy(body = finalTopBody)) ++ newModules
}
@@ -298,7 +311,7 @@ class GroupComponents extends Transform with DependencyAPIMigration {
case w: WRef => w
case other =>
var w = WRef("")
- other mapExpr { e => w = getWRef(e); e}
+ other.mapExpr { e => w = getWRef(e); e }
w
}
@@ -317,25 +330,25 @@ class GroupComponents extends Transform with DependencyAPIMigration {
bidirGraph.addPairWithEdge(sink.name, name)
bidirGraph.addPairWithEdge(name, sink.name)
w
- case other => other map onExpr(sink)
+ case other => other.map(onExpr(sink))
}
def onStmt(stmt: Statement): Unit = stmt match {
case w: WDefInstance =>
case h: IsDeclaration =>
bidirGraph.addVertex(h.name)
- h map onExpr(WRef(h.name))
+ h.map(onExpr(WRef(h.name)))
case Attach(_, exprs) => // Add edge between each expression
- exprs.tail map onExpr(getWRef(exprs.head))
+ exprs.tail.map(onExpr(getWRef(exprs.head)))
case Connect(_, loc, expr) =>
onExpr(getWRef(loc))(expr)
- case q @ Stop(_,_, clk, en) =>
+ case q @ Stop(_, _, clk, en) =>
val simName = simNamespace.newTemp
simulations(simName) = q
- Seq(clk, en) map onExpr(WRef(simName))
+ Seq(clk, en).map(onExpr(WRef(simName)))
case q @ Print(_, _, args, clk, en) =>
val simName = simNamespace.newTemp
simulations(simName) = q
- (args :+ clk :+ en) map onExpr(WRef(simName))
+ (args :+ clk :+ en).map(onExpr(WRef(simName)))
case Block(stmts) => stmts.foreach(onStmt)
case ignore @ (_: IsInvalid | EmptyStmt) => // do nothing
case other => throw new Exception(s"Unexpected Statement $other")
@@ -358,7 +371,7 @@ class GroupAndDedup extends GroupComponents {
override def invalidates(a: Transform): Boolean = a match {
case _: DedupModules => true
- case _ => super.invalidates(a)
+ case _ => super.invalidates(a)
}
}
diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala
index dd073001..376382cc 100644
--- a/src/main/scala/firrtl/transforms/InferResets.scala
+++ b/src/main/scala/firrtl/transforms/InferResets.scala
@@ -7,9 +7,9 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.annotations.{ReferenceTarget, TargetToken}
-import firrtl.Utils.{toTarget, throwInternalError}
+import firrtl.Utils.{throwInternalError, toTarget}
import firrtl.options.Dependency
-import firrtl.passes.{Pass, PassException, InferTypes}
+import firrtl.passes.{InferTypes, Pass, PassException}
import firrtl.graph.MutableDiGraph
import scala.collection.mutable
@@ -83,14 +83,13 @@ object InferResets {
// Vectors must all have the same type, so we only process Index 0
// If the subtype is an aggregate, there can be multiple of each index
val ts = tokens.collect { case (TargetToken.Index(0) +: tail, tpe) => (tail, tpe) }
- VectorTree(fromTokens(ts:_*))
+ VectorTree(fromTokens(ts: _*))
// BundleTree
case (TargetToken.Field(_) +: _, _) +: _ =>
val fields =
- tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }
- .mapValues { ts =>
- fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }:_*)
- }.toMap
+ tokens.groupBy { case (TargetToken.Field(n) +: t, _) => n }.mapValues { ts =>
+ fromTokens(ts.map { case (_ +: t, tpe) => (t, tpe) }: _*)
+ }.toMap
BundleTree(fields)
}
}
@@ -113,14 +112,16 @@ object InferResets {
class InferResets extends Transform with DependencyAPIMigration {
override def prerequisites =
- Seq( Dependency(passes.ResolveKinds),
- Dependency(passes.InferTypes),
- Dependency(passes.ResolveFlows),
- Dependency[passes.InferWidths] ) ++ stage.Forms.WorkingIR
+ Seq(
+ Dependency(passes.ResolveKinds),
+ Dependency(passes.InferTypes),
+ Dependency(passes.ResolveFlows),
+ Dependency[passes.InferWidths]
+ ) ++ stage.Forms.WorkingIR
override def invalidates(a: Transform): Boolean = a match {
case _: checks.CheckResets | passes.CheckTypes => true
- case _ => false
+ case _ => false
}
import InferResets._
@@ -138,7 +139,7 @@ class InferResets extends Transform with DependencyAPIMigration {
val mod = instMap(target.ref)
val port = target.component.head match {
case TargetToken.Field(name) => name
- case bad => Utils.throwInternalError(s"Unexpected token $bad")
+ case bad => Utils.throwInternalError(s"Unexpected token $bad")
}
target.copy(module = mod, ref = port, component = target.component.tail)
case _ => target
@@ -148,17 +149,18 @@ class InferResets extends Transform with DependencyAPIMigration {
// Mark driver of a ResetType leaf
def markResetDriver(lhs: Expression, rhs: Expression): Unit = {
val con = Utils.flow(lhs) match {
- case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs))
+ case SinkFlow if lhs.tpe == ResetType => Some((lhs, rhs))
case SourceFlow if rhs.tpe == ResetType => Some((rhs, lhs))
// If sink is not ResetType, do nothing
- case _ => None
+ case _ => None
}
- con.foreach { case (loc, exp) =>
- val driver = exp.tpe match {
- case ResetType => TargetDriver(makeTarget(exp))
- case tpe => TypeDriver(tpe, () => makeTarget(exp))
- }
- map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver
+ con.foreach {
+ case (loc, exp) =>
+ val driver = exp.tpe match {
+ case ResetType => TargetDriver(makeTarget(exp))
+ case tpe => TypeDriver(tpe, () => makeTarget(exp))
+ }
+ map.getOrElseUpdate(makeTarget(loc), mutable.ListBuffer()) += driver
}
}
stmt match {
@@ -227,7 +229,7 @@ class InferResets extends Transform with DependencyAPIMigration {
private def resolve(map: Map[ReferenceTarget, List[ResetDriver]]): Try[Map[ReferenceTarget, Type]] = {
val graph = new MutableDiGraph[Node]
val asyncNode = Typ(AsyncResetType)
- val syncNode = Typ(Utils.BoolType)
+ val syncNode = Typ(Utils.BoolType)
for ((target, drivers) <- map) {
val v = Var(target)
drivers.foreach {
@@ -247,7 +249,7 @@ class InferResets extends Transform with DependencyAPIMigration {
// do the actual inference, the check is simply if syncNode is reachable from asyncNode
graph.addPairWithEdge(v, u)
case InvalidDriver =>
- graph.addVertex(v) // Must be in the graph or won't be inferred
+ graph.addVertex(v) // Must be in the graph or won't be inferred
}
}
val async = graph.reachableFrom(asyncNode)
@@ -257,7 +259,7 @@ class InferResets extends Transform with DependencyAPIMigration {
case (a, _) if a.contains(syncNode) => throw InferResetsException(graph.path(asyncNode, syncNode))
case (a, s) =>
(a.view.collect { case Var(t) => t -> asyncNode.tpe } ++
- s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap
+ s.view.collect { case Var(t) => t -> syncNode.tpe }).toMap
}
}
}
@@ -265,34 +267,40 @@ class InferResets extends Transform with DependencyAPIMigration {
private def fixupType(tpe: Type, tree: TypeTree): Type = (tpe, tree) match {
case (BundleType(fields), BundleTree(map)) =>
val fieldsx =
- fields.map(f => map.get(f.name) match {
- case Some(t) => f.copy(tpe = fixupType(f.tpe, t))
- case None => f
- })
+ fields.map(f =>
+ map.get(f.name) match {
+ case Some(t) => f.copy(tpe = fixupType(f.tpe, t))
+ case None => f
+ }
+ )
BundleType(fieldsx)
case (VectorType(vtpe, size), VectorTree(t)) =>
VectorType(fixupType(vtpe, t), size)
case (_, GroundTree(t)) => t
- case x => throw new Exception(s"Error! Unexpected pair $x")
+ case x => throw new Exception(s"Error! Unexpected pair $x")
}
// Assumes all ReferenceTargets are in the same module
private def makeDeclMap(map: Map[ReferenceTarget, Type]): Map[String, TypeTree] =
- map.groupBy(_._1.ref).mapValues { ts =>
- TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }:_*)
- }.toMap
+ map
+ .groupBy(_._1.ref)
+ .mapValues { ts =>
+ TypeTree.fromTokens(ts.toSeq.map { case (target, tpe) => (target.component, tpe) }: _*)
+ }
+ .toMap
private def implPort(map: Map[String, TypeTree])(port: Port): Port =
- map.get(port.name)
- .map(tree => port.copy(tpe = fixupType(port.tpe, tree)))
- .getOrElse(port)
+ map
+ .get(port.name)
+ .map(tree => port.copy(tpe = fixupType(port.tpe, tree)))
+ .getOrElse(port)
private def implStmt(map: Map[String, TypeTree])(stmt: Statement): Statement =
stmt.map(implStmt(map)) match {
case decl: IsDeclaration if map.contains(decl.name) =>
val tree = map(decl.name)
decl match {
- case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree))
- case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree))
+ case reg: DefRegister => reg.copy(tpe = fixupType(reg.tpe, tree))
+ case wire: DefWire => wire.copy(tpe = fixupType(wire.tpe, tree))
// TODO Can this really happen?
case mem: DefMemory => mem.copy(dataType = fixupType(mem.dataType, tree))
case other => other
@@ -303,10 +311,13 @@ class InferResets extends Transform with DependencyAPIMigration {
private def implement(c: Circuit, map: Map[ReferenceTarget, Type]): Circuit = {
val modMaps = map.groupBy(_._1.module)
def onMod(mod: DefModule): DefModule = {
- modMaps.get(mod.name).map { tmap =>
- val declMap = makeDeclMap(tmap)
- mod.map(implPort(declMap)).map(implStmt(declMap))
- }.getOrElse(mod)
+ modMaps
+ .get(mod.name)
+ .map { tmap =>
+ val declMap = makeDeclMap(tmap)
+ mod.map(implPort(declMap)).map(implStmt(declMap))
+ }
+ .getOrElse(mod)
}
c.map(onMod)
}
diff --git a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
index 515bf407..100b598f 100644
--- a/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
+++ b/src/main/scala/firrtl/transforms/InlineBitExtractions.scala
@@ -6,7 +6,7 @@ package transforms
import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.Dependency
-import firrtl.PrimOps.{Bits, Head, Tail, Shr}
+import firrtl.PrimOps.{Bits, Head, Shr, Tail}
import firrtl.Utils.{isBitExtract, isTemp}
import firrtl.WrappedExpression._
@@ -19,8 +19,8 @@ object InlineBitExtractionsTransform {
// Note that this can have false negatives but MUST NOT have false positives.
private def isSimpleExpr(expr: Expression): Boolean = expr match {
case _: WRef | _: Literal | _: WSubField => true
- case DoPrim(op, args, _,_) if isBitExtract(op) => args.forall(isSimpleExpr)
- case _ => false
+ case DoPrim(op, args, _, _) if isBitExtract(op) => args.forall(isSimpleExpr)
+ case _ => false
}
// replace Head/Tail/Shr with Bits for easier back-to-back Bits Extractions
@@ -28,12 +28,12 @@ object InlineBitExtractionsTransform {
case DoPrim(Head, rhs, c, tpe) if isSimpleExpr(expr) =>
val msb = bitWidth(rhs.head.tpe) - 1
val lsb = bitWidth(rhs.head.tpe) - c.head
- DoPrim(Bits, rhs, Seq(msb,lsb), tpe)
+ DoPrim(Bits, rhs, Seq(msb, lsb), tpe)
case DoPrim(Tail, rhs, c, tpe) if isSimpleExpr(expr) =>
val msb = bitWidth(rhs.head.tpe) - c.head - 1
- DoPrim(Bits, rhs, Seq(msb,0), tpe)
+ DoPrim(Bits, rhs, Seq(msb, 0), tpe)
case DoPrim(Shr, rhs, c, tpe) if isSimpleExpr(expr) =>
- DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe)-1, c.head), tpe)
+ DoPrim(Bits, rhs, Seq(bitWidth(rhs.head.tpe) - 1, c.head), tpe)
case _ => expr // Not a candidate
}
@@ -49,26 +49,28 @@ object InlineBitExtractionsTransform {
*/
def onExpr(netlist: Netlist)(expr: Expression): Expression = {
expr.map(onExpr(netlist)) match {
- case e @ WRef(name, _,_,_) =>
- netlist.get(we(e))
- .filter(isBitExtract)
- .getOrElse(e)
+ case e @ WRef(name, _, _, _) =>
+ netlist
+ .get(we(e))
+ .filter(isBitExtract)
+ .getOrElse(e)
// replace back-to-back Bits Extractions
case lhs @ DoPrim(lop, ival, lc, ltpe) if isSimpleExpr(lhs) =>
ival.head match {
case of @ DoPrim(rop, rhs, rc, rtpe) if isSimpleExpr(of) =>
(lop, rop) match {
- case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head min rc.head), ltpe)
+ case (Head, Head) => DoPrim(Head, rhs, Seq(lc.head.min(rc.head)), ltpe)
case (Tail, Tail) => DoPrim(Tail, rhs, Seq(lc.head + rc.head), ltpe)
- case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe)
- case (_,_) => (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match {
- case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) =>
- DoPrim(Bits, rhs, Seq(lmsb+rlsb,llsb+rlsb), ltpe)
- case (_,_) => lhs // Not a candidate
- }
+ case (Shr, Shr) => DoPrim(Shr, rhs, Seq(lc.head + rc.head), ltpe)
+ case (_, _) =>
+ (lowerToDoPrimOpBits(lhs), lowerToDoPrimOpBits(of)) match {
+ case (DoPrim(Bits, _, Seq(lmsb, llsb), _), DoPrim(Bits, _, Seq(rmsb, rlsb), _)) =>
+ DoPrim(Bits, rhs, Seq(lmsb + rlsb, llsb + rlsb), ltpe)
+ case (_, _) => lhs // Not a candidate
+ }
}
- case _ => lhs // Not a candidate
- }
+ case _ => lhs // Not a candidate
+ }
case other => other // Not a candidate
}
}
@@ -97,9 +99,11 @@ object InlineBitExtractionsTransform {
class InlineBitExtractionsTransform extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/InlineCasts.scala b/src/main/scala/firrtl/transforms/InlineCasts.scala
index 3dac938e..0efc0727 100644
--- a/src/main/scala/firrtl/transforms/InlineCasts.scala
+++ b/src/main/scala/firrtl/transforms/InlineCasts.scala
@@ -8,7 +8,7 @@ import firrtl.Mappers._
import firrtl.PrimOps.Pad
import firrtl.options.Dependency
-import firrtl.Utils.{isCast, isBitExtract, NodeMap}
+import firrtl.Utils.{isBitExtract, isCast, NodeMap}
object InlineCastsTransform {
@@ -17,8 +17,8 @@ object InlineCastsTransform {
// Note that this can have false negatives but MUST NOT have false positives
private def isSimpleCast(castSeen: Boolean)(expr: Expression): Boolean = expr match {
case _: WRef | _: Literal | _: WSubField => castSeen
- case DoPrim(op, args, _,_) if isCast(op) => args.forall(isSimpleCast(true))
- case _ => false
+ case DoPrim(op, args, _, _) if isCast(op) => args.forall(isSimpleCast(true))
+ case _ => false
}
/** Recursively replace [[WRef]]s with new [[firrtl.ir.Expression Expression]]s
@@ -31,17 +31,20 @@ object InlineCastsTransform {
def onExpr(replace: NodeMap)(expr: Expression): Expression = expr match {
// Anything that may generate a part-select should not be inlined!
case DoPrim(op, _, _, _) if (isBitExtract(op) || op == Pad) => expr
- case e => e.map(onExpr(replace)) match {
- case e @ WRef(name, _,_,_) =>
- replace.get(name)
- .filter(isSimpleCast(castSeen=false))
- .getOrElse(e)
- case e @ DoPrim(op, Seq(WRef(name, _,_,_)), _,_) if isCast(op) =>
- replace.get(name)
- .map(value => e.copy(args = Seq(value)))
- .getOrElse(e)
- case other => other // Not a candidate
- }
+ case e =>
+ e.map(onExpr(replace)) match {
+ case e @ WRef(name, _, _, _) =>
+ replace
+ .get(name)
+ .filter(isSimpleCast(castSeen = false))
+ .getOrElse(e)
+ case e @ DoPrim(op, Seq(WRef(name, _, _, _)), _, _) if isCast(op) =>
+ replace
+ .get(name)
+ .map(value => e.copy(args = Seq(value)))
+ .getOrElse(e)
+ case other => other // Not a candidate
+ }
}
/** Inline casts in a Statement
@@ -69,11 +72,13 @@ object InlineCastsTransform {
class InlineCastsTransform extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[PropagatePresetAnnotations] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[PropagatePresetAnnotations]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/LegalizeClocks.scala b/src/main/scala/firrtl/transforms/LegalizeClocks.scala
index f439fdc9..248775d9 100644
--- a/src/main/scala/firrtl/transforms/LegalizeClocks.scala
+++ b/src/main/scala/firrtl/transforms/LegalizeClocks.scala
@@ -18,8 +18,8 @@ object LegalizeClocksTransform {
// Currently only looks for literals nested within casts
private def illegalClockExpr(expr: Expression): Boolean = expr match {
case _: Literal => true
- case DoPrim(op, args, _,_) if isCast(op) => args.exists(illegalClockExpr)
- case _ => false
+ case DoPrim(op, args, _, _) if isCast(op) => args.exists(illegalClockExpr)
+ case _ => false
}
/** Legalize Clocks in a Statement
@@ -66,11 +66,13 @@ object LegalizeClocksTransform {
class LegalizeClocksTransform extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[InlineCastsTransform] )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[InlineCastsTransform]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/LegalizeReductions.scala b/src/main/scala/firrtl/transforms/LegalizeReductions.scala
index 2e60aae7..33a10349 100644
--- a/src/main/scala/firrtl/transforms/LegalizeReductions.scala
+++ b/src/main/scala/firrtl/transforms/LegalizeReductions.scala
@@ -6,17 +6,16 @@ import firrtl.Mappers._
import firrtl.options.Dependency
import firrtl.Utils.BoolType
-
object LegalizeAndReductionsTransform {
private def allOnesOfType(tpe: Type): Literal = tpe match {
case UIntType(width @ IntWidth(x)) => UIntLiteral((BigInt(1) << x.toInt) - 1, width)
- case SIntType(width) => SIntLiteral(-1, width)
+ case SIntType(width) => SIntLiteral(-1, width)
}
def onExpr(expr: Expression): Expression = expr.map(onExpr) match {
- case DoPrim(PrimOps.Andr, Seq(arg), _,_) if bitWidth(arg.tpe) > 64 =>
+ case DoPrim(PrimOps.Andr, Seq(arg), _, _) if bitWidth(arg.tpe) > 64 =>
DoPrim(PrimOps.Eq, Seq(arg, allOnesOfType(arg.tpe)), Seq(), BoolType)
case other => other
}
@@ -35,8 +34,7 @@ class LegalizeAndReductionsTransform extends Transform with DependencyAPIMigrati
override def prerequisites =
firrtl.stage.Forms.WorkingIR ++
- Seq( Dependency(passes.CheckTypes),
- Dependency(passes.CheckWidths))
+ Seq(Dependency(passes.CheckTypes), Dependency(passes.CheckWidths))
override def optionalPrerequisites = Nil
diff --git a/src/main/scala/firrtl/transforms/ManipulateNames.scala b/src/main/scala/firrtl/transforms/ManipulateNames.scala
index f15b546f..d0b12e66 100644
--- a/src/main/scala/firrtl/transforms/ManipulateNames.scala
+++ b/src/main/scala/firrtl/transforms/ManipulateNames.scala
@@ -57,8 +57,9 @@ sealed trait ManipulateNamesListAnnotation[A <: ManipulateNames[_]] extends Mult
* @note $noteLocalTargets
*/
case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]](
- targets: Seq[Seq[Target]],
- transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] {
+ targets: Seq[Seq[Target]],
+ transform: Dependency[A])
+ extends ManipulateNamesListAnnotation[A] {
override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a)
@@ -77,8 +78,9 @@ case class ManipulateNamesBlocklistAnnotation[A <: ManipulateNames[_]](
* @note $noteLocalTargets
*/
case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]](
- targets: Seq[Seq[Target]],
- transform: Dependency[A]) extends ManipulateNamesListAnnotation[A] {
+ targets: Seq[Seq[Target]],
+ transform: Dependency[A])
+ extends ManipulateNamesListAnnotation[A] {
override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a)
@@ -94,19 +96,21 @@ case class ManipulateNamesAllowlistAnnotation[A <: ManipulateNames[_]](
* @param oldTargets the old targets
*/
case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]](
- targets: Seq[Seq[Target]],
- transform: Dependency[A],
- oldTargets: Seq[Seq[Target]]) extends MultiTargetAnnotation {
+ targets: Seq[Seq[Target]],
+ transform: Dependency[A],
+ oldTargets: Seq[Seq[Target]])
+ extends MultiTargetAnnotation {
override def duplicate(a: Seq[Seq[Target]]) = this.copy(targets = a)
override def update(renames: RenameMap) = {
val (targetsx, oldTargetsx) = targets.zip(oldTargets).foldLeft((Seq.empty[Seq[Target]], Seq.empty[Seq[Target]])) {
- case ((accT, accO), (t, o)) => t.flatMap(renames(_)) match {
- /* If the target was deleted, delete the old target */
- case tx if tx.isEmpty => (accT, accO)
- case tx => (Seq(tx) ++ accT, Seq(o) ++ accO)
- }
+ case ((accT, accO), (t, o)) =>
+ t.flatMap(renames(_)) match {
+ /* If the target was deleted, delete the old target */
+ case tx if tx.isEmpty => (accT, accO)
+ case tx => (Seq(tx) ++ accT, Seq(o) ++ accO)
+ }
}
targetsx match {
/* If all targets were deleted, delete the annotation */
@@ -117,9 +121,13 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]](
/** Return [[firrtl.RenameMap RenameMap]] from old targets to new targets */
def toRenameMap: RenameMap = {
- val m = oldTargets.zip(targets).flatMap {
- case (a, b) => a.map(_ -> b)
- }.toMap.asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]]
+ val m = oldTargets
+ .zip(targets)
+ .flatMap {
+ case (a, b) => a.map(_ -> b)
+ }
+ .toMap
+ .asInstanceOf[Map[CompleteTarget, Seq[CompleteTarget]]]
RenameMap.create(m)
}
@@ -132,25 +140,28 @@ case class ManipulateNamesAllowlistResultAnnotation[A <: ManipulateNames[_]](
* @param allow a function that returns true if a [[firrtl.annotations.Target Target]] should be renamed
*/
private class RenameDataStructure(
- circuit: ir.Circuit,
+ circuit: ir.Circuit,
val renames: RenameMap,
- val block: Target => Boolean,
- val allow: Target => Boolean) {
+ val block: Target => Boolean,
+ val allow: Target => Boolean) {
/** A mapping of targets to associated namespaces */
val namespaces: mutable.HashMap[CompleteTarget, Namespace] =
mutable.HashMap(CircuitTarget(circuit.main) -> Namespace(circuit))
- /** Wraps a HashMap to provide better error messages when accessing a non-existing element */
+ /** Wraps a HashMap to provide better error messages when accessing a non-existing element */
class InstanceHashMap {
type Key = ReferenceTarget
type Value = Either[ReferenceTarget, InstanceTarget]
private val m = mutable.HashMap[Key, Value]()
- def apply(key: ReferenceTarget): Value = m.getOrElse(key, {
- throw new FirrtlUserException(
- s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems.
- | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin)
- })
+ def apply(key: ReferenceTarget): Value = m.getOrElse(
+ key, {
+ throw new FirrtlUserException(
+ s"""|Reference target '${key.serialize}' did not exist in mapping of reference targets to insts/mems.
+ | This is indicative of a circuit that has not been run through LowerTypes.""".stripMargin
+ )
+ }
+ )
def update(key: Key, value: Value): Unit = m.update(key, value)
}
@@ -165,17 +176,17 @@ private class RenameDataStructure(
/** Transform for manipulate all the names in a FIRRTL circuit.
* @tparam A the type of the child transform
*/
-abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Transform with DependencyAPIMigration {
+abstract class ManipulateNames[A <: ManipulateNames[_]: ClassTag] extends Transform with DependencyAPIMigration {
/** A function used to manipulate a name in a FIRRTL circuit */
def manipulate: (String, Namespace) => Option[String]
- override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes))
- override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty
+ override def prerequisites: Seq[TransformDependency] = Seq(Dependency(firrtl.passes.LowerTypes))
+ override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisiteOf: Seq[TransformDependency] = Forms.LowEmitters
override def invalidates(a: Transform) = a match {
case _: analyses.GetNamespace => true
- case _ => false
+ case _ => false
}
/** Compute a new name for some target and record the rename if the new name differs. If the top module or the circuit
@@ -192,27 +203,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
case a if r.skip(a) =>
(name, None)
/* Circuit renaming */
- case a@ CircuitTarget(b) => manipulate(b, r.namespaces(a)) match {
- case Some(str) => (str, Some(a.copy(circuit = str)))
- case None => (b, None)
- }
+ case a @ CircuitTarget(b) =>
+ manipulate(b, r.namespaces(a)) match {
+ case Some(str) => (str, Some(a.copy(circuit = str)))
+ case None => (b, None)
+ }
/* Module renaming for non-top modules */
- case a@ ModuleTarget(_, b) => manipulate(b, r.namespaces(a.circuitTarget)) match {
- case Some(str) => (str, Some(a.copy(module = str)))
- case None => (b, None)
- }
+ case a @ ModuleTarget(_, b) =>
+ manipulate(b, r.namespaces(a.circuitTarget)) match {
+ case Some(str) => (str, Some(a.copy(module = str)))
+ case None => (b, None)
+ }
/* Instance renaming */
- case a@ InstanceTarget(_, _, Nil, b, c) => manipulate(b, r.namespaces(a.moduleTarget)) match {
- case Some(str) => (str, Some(a.copy(instance = str)))
- case None => (b, None)
- }
+ case a @ InstanceTarget(_, _, Nil, b, c) =>
+ manipulate(b, r.namespaces(a.moduleTarget)) match {
+ case Some(str) => (str, Some(a.copy(instance = str)))
+ case None => (b, None)
+ }
/* Rename either a module component or a memory */
- case a@ ReferenceTarget(_, _, _, b, Nil) => manipulate(b, r.namespaces(a.moduleTarget)) match {
- case Some(str) => (str, Some(a.copy(ref = str)))
- case None => (b, None)
- }
+ case a @ ReferenceTarget(_, _, _, b, Nil) =>
+ manipulate(b, r.namespaces(a.moduleTarget)) match {
+ case Some(str) => (str, Some(a.copy(ref = str)))
+ case None => (b, None)
+ }
/* Rename an instance port or a memory reader/writer/readwriter */
- case a@ ReferenceTarget(_, _, _, b, (token@ TargetToken.Field(c)) :: Nil) =>
+ case a @ ReferenceTarget(_, _, _, b, (token @ TargetToken.Field(c)) :: Nil) =>
val ref = r.instanceMap(a.moduleTarget.ref(b)) match {
case Right(inst) => inst.ofModuleTarget
case Left(mem) => mem
@@ -224,8 +239,8 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
}
/* Record the optional rename. If the circuit was renamed, also rename the top module. If the top module was
* renamed, also rename the circuit. */
- ax.foreach(
- axx => target match {
+ ax.foreach(axx =>
+ target match {
case c: CircuitTarget =>
r.renames.rename(target, r.renames(axx))
r.renames.rename(c.module(c.circuit), CircuitTarget(namex).module(namex))
@@ -252,21 +267,26 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
r.renames.underlying.get(t) match {
case Some(ax) if ax.size == 1 =>
ax match {
- case Seq(foo: CircuitTarget) => foo.name
- case Seq(foo: ModuleTarget) => foo.module
- case Seq(foo: InstanceTarget) => foo.instance
- case Seq(foo: ReferenceTarget) => foo.tokens.last match {
- case TargetToken.Ref(value) => value
- case TargetToken.Field(value) => value
- case _ => Utils.throwInternalError(
- s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field'
+ case Seq(foo: CircuitTarget) => foo.name
+ case Seq(foo: ModuleTarget) => foo.module
+ case Seq(foo: InstanceTarget) => foo.instance
+ case Seq(foo: ReferenceTarget) =>
+ foo.tokens.last match {
+ case TargetToken.Ref(value) => value
+ case TargetToken.Field(value) => value
+ case _ =>
+ Utils.throwInternalError(
+ s"""|Reference target '${t.serialize}'must end in 'Ref' or 'Field'
| This is indicative of a circuit that has not been run through LowerTypes.""",
- Some(new MatchError(foo.serialize)))
- }
+ Some(new MatchError(foo.serialize))
+ )
+ }
}
- case s@ Some(ax) => Utils.throwInternalError(
- s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""",
- Some(new MatchError(s)))
+ case s @ Some(ax) =>
+ Utils.throwInternalError(
+ s"""Found multiple renames '${t}' -> [${ax.map(_.serialize).mkString(",")}]. This should be impossible.""",
+ Some(new MatchError(s))
+ )
case None => name
}
@@ -280,27 +300,34 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
/* A reference to something inside this module */
case w: WRef => w.copy(name = maybeRename(w.name, r, Target.asTarget(t)(w)))
/* This is either the subfield of an instance or a subfield of a memory reader/writer/readwriter */
- case w@ WSubField(expr, ref, _, _) => expr match {
- /* This is an instance */
- case we@ WRef(inst, _, _, _) =>
- val tx = Target.asTarget(t)(we)
- val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match {
- case Right(a) => (a.ofModuleTarget.ref(ref), a)
- case a@ Left(ref) => throw new FirrtlUserException(
- s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'.
- | This is indicative of a circuit that has not been run through LowerTypes.""", new MatchError(a))
- }
- w.copy(we.copy(name=maybeRename(inst, r, iTarget)), name=maybeRename(ref, r, rTarget))
- /* This is a reader/writer/readwriter */
- case ws@ WSubField(expr, port, _, _) => expr match {
- /* This is the memory. */
- case wr@ WRef(mem, _, _, _) =>
- w.copy(
- expr=ws.copy(
- expr=wr.copy(name=maybeRename(mem, r, t.ref(mem))),
- name=maybeRename(port, r, t.ref(mem).field(port))))
+ case w @ WSubField(expr, ref, _, _) =>
+ expr match {
+ /* This is an instance */
+ case we @ WRef(inst, _, _, _) =>
+ val tx = Target.asTarget(t)(we)
+ val (rTarget: ReferenceTarget, iTarget: InstanceTarget) = r.instanceMap(tx) match {
+ case Right(a) => (a.ofModuleTarget.ref(ref), a)
+ case a @ Left(ref) =>
+ throw new FirrtlUserException(
+ s"""|Unexpected '${ref.serialize}' in instanceMap for key '${tx.serialize}' on expression '${w.serialize}'.
+ | This is indicative of a circuit that has not been run through LowerTypes.""",
+ new MatchError(a)
+ )
+ }
+ w.copy(we.copy(name = maybeRename(inst, r, iTarget)), name = maybeRename(ref, r, rTarget))
+ /* This is a reader/writer/readwriter */
+ case ws @ WSubField(expr, port, _, _) =>
+ expr match {
+ /* This is the memory. */
+ case wr @ WRef(mem, _, _, _) =>
+ w.copy(
+ expr = ws.copy(
+ expr = wr.copy(name = maybeRename(mem, r, t.ref(mem))),
+ name = maybeRename(port, r, t.ref(mem).field(port))
+ )
+ )
+ }
}
- }
case e => e.map(onExpression(_: ir.Expression, r, t))
}
@@ -310,30 +337,31 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
* and readwriters.
*/
private def onStatement(s: ir.Statement, r: RenameDataStructure, t: ModuleTarget): ir.Statement = s match {
- case decl: ir.IsDeclaration => decl match {
- case decl@ WDefInstance(_, inst, mod, _) =>
- val modx = maybeRename(mod, r, t.circuitTarget.module(mod))
- val instx = doRename(inst, r, t.instOf(inst, mod))
- r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod))
- decl.copy(name = instx, module = modx)
- case decl: ir.DefMemory =>
- val namex = doRename(decl.name, r, t.ref(decl.name))
- val tx = t.ref(decl.name)
- r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters)
- r.instanceMap(tx) = Left(tx)
- decl
- .copy(
- name = namex,
- readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))),
- writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))),
- readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw)))
- )
- .map(onExpression(_: ir.Expression, r, t))
- case decl =>
- decl
- .map(doRename(_: String, r, t.ref(decl.name)))
- .map(onExpression(_: ir.Expression, r, t))
- }
+ case decl: ir.IsDeclaration =>
+ decl match {
+ case decl @ WDefInstance(_, inst, mod, _) =>
+ val modx = maybeRename(mod, r, t.circuitTarget.module(mod))
+ val instx = doRename(inst, r, t.instOf(inst, mod))
+ r.instanceMap(t.ref(inst)) = Right(t.instOf(inst, mod))
+ decl.copy(name = instx, module = modx)
+ case decl: ir.DefMemory =>
+ val namex = doRename(decl.name, r, t.ref(decl.name))
+ val tx = t.ref(decl.name)
+ r.namespaces(tx) = Namespace(decl.readers ++ decl.writers ++ decl.readwriters)
+ r.instanceMap(tx) = Left(tx)
+ decl
+ .copy(
+ name = namex,
+ readers = decl.readers.map(_r => doRename(_r, r, tx.field(_r))),
+ writers = decl.writers.map(_w => doRename(_w, r, tx.field(_w))),
+ readwriters = decl.readwriters.map(_rw => doRename(_rw, r, tx.field(_rw)))
+ )
+ .map(onExpression(_: ir.Expression, r, t))
+ case decl =>
+ decl
+ .map(doRename(_: String, r, t.ref(decl.name)))
+ .map(onExpression(_: ir.Expression, r, t))
+ }
case s =>
s
.map(onStatement(_: ir.Statement, r, t))
@@ -362,7 +390,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
*/
val onName: String => String = t.circuit match {
case `main` => maybeRename(_, r, moduleTarget)
- case _ => doRename(_, r, moduleTarget)
+ case _ => doRename(_, r, moduleTarget)
}
m
@@ -380,11 +408,11 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
* @return the circuit with manipulated names
*/
def run(
- c: ir.Circuit,
+ c: ir.Circuit,
renames: RenameMap,
- block: Target => Boolean,
- allow: Target => Boolean)
- : ir.Circuit = {
+ block: Target => Boolean,
+ allow: Target => Boolean
+ ): ir.Circuit = {
val t = CircuitTarget(c.main)
/* If the circuit is a skip, return the original circuit. Otherwise, walk all the modules and rename them. Rename the
@@ -427,8 +455,7 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
.toMap
/* Replace the old modules making sure that they are still in the same order */
- c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))),
- main = mainx)
+ c.copy(modules = c.modules.map(m => modulesx(t.module(m.name))), main = mainx)
}
}
@@ -436,18 +463,20 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
def execute(state: CircuitState): CircuitState = {
val block = state.annotations.collect {
- case ManipulateNamesBlocklistAnnotation(targetSeq, t) => t.getObject match {
- case _: A => targetSeq
- case _ => Nil
- }
+ case ManipulateNamesBlocklistAnnotation(targetSeq, t) =>
+ t.getObject match {
+ case _: A => targetSeq
+ case _ => Nil
+ }
}.flatten.flatten.toSet
val allow = {
val allowx = state.annotations.collect {
- case ManipulateNamesAllowlistAnnotation(targetSeq, t) => t.getObject match {
- case _: A => targetSeq
- case _ => Nil
- }
+ case ManipulateNamesAllowlistAnnotation(targetSeq, t) =>
+ t.getObject match {
+ case _: A => targetSeq
+ case _ => Nil
+ }
}.flatten.flatten
allowx match {
@@ -461,17 +490,19 @@ abstract class ManipulateNames[A <: ManipulateNames[_] : ClassTag] extends Trans
val annotationsx = state.annotations.flatMap {
/* Consume blocklist annotations */
- case foo@ ManipulateNamesBlocklistAnnotation(_, t) => t.getObject match {
- case _: A => None
- case _ => Some(foo)
- }
+ case foo @ ManipulateNamesBlocklistAnnotation(_, t) =>
+ t.getObject match {
+ case _: A => None
+ case _ => Some(foo)
+ }
/* Convert allowlist annotations to result annotations */
- case foo@ ManipulateNamesAllowlistAnnotation(a, t) =>
+ case foo @ ManipulateNamesAllowlistAnnotation(a, t) =>
t.getObject match {
- case _: A => (a, a.map(_.map(renames(_)).flatten)) match {
- case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a))
- }
- case _ => Some(foo)
+ case _: A =>
+ (a, a.map(_.map(renames(_)).flatten)) match {
+ case (a, b) => Some(ManipulateNamesAllowlistResultAnnotation(b, t, a))
+ }
+ case _ => Some(foo)
}
case a => Some(a)
}
diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
index ff44afec..5532d0f0 100644
--- a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
+++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
@@ -1,4 +1,3 @@
-
package firrtl
package transforms
@@ -34,17 +33,19 @@ trait DontTouchAllTargets extends HasDontTouches { self: Annotation =>
* DCE treats the component as a top-level sink of the circuit
*/
case class DontTouchAnnotation(target: ReferenceTarget)
- extends SingleTargetAnnotation[ReferenceTarget] with DontTouchAllTargets {
+ extends SingleTargetAnnotation[ReferenceTarget]
+ with DontTouchAllTargets {
def targets = Seq(target)
def duplicate(n: ReferenceTarget) = this.copy(n)
}
object DontTouchAnnotation {
- class DontTouchNotFoundException(module: String, component: String) extends PassException(
- s"""|Target marked dontTouch ($module.$component) not found!
- |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then
- |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin
- )
+ class DontTouchNotFoundException(module: String, component: String)
+ extends PassException(
+ s"""|Target marked dontTouch ($module.$component) not found!
+ |It was probably accidentally deleted. Please check that your custom transforms are not responsible and then
+ |file an issue on GitHub: https://github.com/freechipsproject/firrtl/issues/new""".stripMargin
+ )
def errorNotFound(module: String, component: String) =
throw new DontTouchNotFoundException(module, component)
@@ -58,7 +59,6 @@ object DontTouchAnnotation {
*
* @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted
*/
-case class OptimizableExtModuleAnnotation(target: ModuleName) extends
- SingleTargetAnnotation[ModuleName] {
+case class OptimizableExtModuleAnnotation(target: ModuleName) extends SingleTargetAnnotation[ModuleName] {
def duplicate(n: ModuleName) = this.copy(n)
}
diff --git a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala
index da803837..97db0219 100644
--- a/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala
+++ b/src/main/scala/firrtl/transforms/PropagatePresetAnnotations.scala
@@ -11,8 +11,10 @@ import firrtl.options.Dependency
import scala.collection.mutable
object PropagatePresetAnnotations {
- val advice = "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset."
- case class TreeCleanUpOrphanException(message: String) extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice")
+ val advice =
+ "Please Note that a Preset-annotated AsyncReset shall NOT be casted to other types with any of the following functions: asInterval, asUInt, asSInt, asClock, asFixedPoint, asAsyncReset."
+ case class TreeCleanUpOrphanException(message: String)
+ extends FirrtlUserException(s"Node left an orphan during tree cleanup: $message $advice")
}
/** Propagate PresetAnnotations to all children of targeted AsyncResets
@@ -39,9 +41,11 @@ object PropagatePresetAnnotations {
class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic])
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic]
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
@@ -52,7 +56,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
import PropagatePresetAnnotations._
private type TargetSet = mutable.HashSet[ReferenceTarget]
- private type TargetMap = mutable.HashMap[ReferenceTarget,String]
+ private type TargetMap = mutable.HashMap[ReferenceTarget, String]
private type TargetSetMap = mutable.HashMap[ReferenceTarget, TargetSet]
private val toCleanUp = new TargetSet()
@@ -71,7 +75,11 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
* @param presetAnnos all the annotations
* @return updated annotations
*/
- private def propagate(cs: CircuitState, presetAnnos: Seq[PresetAnnotation], otherAnnos: Seq[Annotation]): AnnotationSeq = {
+ private def propagate(
+ cs: CircuitState,
+ presetAnnos: Seq[PresetAnnotation],
+ otherAnnos: Seq[Annotation]
+ ): AnnotationSeq = {
val presets = presetAnnos.groupBy(_.target)
// store all annotated asyncreset references
val asyncToAnnotate = new TargetSet()
@@ -85,34 +93,34 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
val circuitTarget = CircuitTarget(cs.circuit.main)
/*
- * WALK I PHASE 1 FUNCTIONS
- */
+ * WALK I PHASE 1 FUNCTIONS
+ */
/* Walk current module
- * - process ports
- * - store connections & entry points for PHASE 2
- * - process statements
- * - Instances => record local instances for cross module AsyncReset Tree Buidling
- * - Registers => store AsyncReset bound registers for PHASE 2
- * - Wire => store AsyncReset Connections & entry points for PHASE 2
- * - Connect => store AsyncReset Connections & entry points for PHASE 2
- *
- * @param m module
- */
+ * - process ports
+ * - store connections & entry points for PHASE 2
+ * - process statements
+ * - Instances => record local instances for cross module AsyncReset Tree Buidling
+ * - Registers => store AsyncReset bound registers for PHASE 2
+ * - Wire => store AsyncReset Connections & entry points for PHASE 2
+ * - Connect => store AsyncReset Connections & entry points for PHASE 2
+ *
+ * @param m module
+ */
def processModule(m: DefModule): Unit = {
val moduleTarget = circuitTarget.module(m.name)
val localInstances = new TargetMap()
/* Recursively process a given type
- * Recursive on Bundle and Vector Type only
- * Store Register and Connections for AsyncResetType
- * @param tpe [[Type]] to be processed
- * @param target [[ReferenceTarget]] associated to the tpe
- * @param all Boolean indicating whether all subelements of the current
- * tpe should also be stored as Annotated AsyncReset entry points
- */
+ * Recursive on Bundle and Vector Type only
+ * Store Register and Connections for AsyncResetType
+ * @param tpe [[Type]] to be processed
+ * @param target [[ReferenceTarget]] associated to the tpe
+ * @param all Boolean indicating whether all subelements of the current
+ * tpe should also be stored as Annotated AsyncReset entry points
+ */
def processType(tpe: Type, target: ReferenceTarget, all: Boolean): Unit = {
- if(tpe == AsyncResetType){
+ if (tpe == AsyncResetType) {
asyncRegMap(target) = new TargetSet()
asyncCoMap(target) = new TargetSet()
if (presets.contains(target) || all) {
@@ -121,14 +129,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
} else {
tpe match {
case b: BundleType =>
- b.fields.foreach{
- (x: Field) =>
- val tar = target.field(x.name)
- processType(x.tpe, tar, (presets.contains(tar) || all))
+ b.fields.foreach { (x: Field) =>
+ val tar = target.field(x.name)
+ processType(x.tpe, tar, (presets.contains(tar) || all))
}
case v: VectorType =>
- for(i <- 0 until v.size) {
+ for (i <- 0 until v.size) {
val tar = target.index(i)
processType(v.tpe, tar, (presets.contains(tar) || all))
}
@@ -143,19 +150,19 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
/* Recursively search for the ReferenceTarget of a given Expression
- * @param e Targeted Expression
- * @param ta Local ReferenceTarget of the Targeted Expression
- * @return a ReferenceTarget in case of success, a GenericTarget otherwise
- * @throws [[InternalError]] on unexpected recursive path return results
- */
- def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false) : Target = {
+ * @param e Targeted Expression
+ * @param ta Local ReferenceTarget of the Targeted Expression
+ * @return a ReferenceTarget in case of success, a GenericTarget otherwise
+ * @throws [[InternalError]] on unexpected recursive path return results
+ */
+ def getRef(e: Expression, ta: ReferenceTarget, annoCo: Boolean = false): Target = {
e match {
case w: WRef => moduleTarget.ref(w.name)
case w: WSubField =>
getRef(w.expr, ta, annoCo) match {
case rt: ReferenceTarget =>
- if(localInstances.contains(rt)){
- val remote_ref = circuitTarget.module(localInstances(rt))
+ if (localInstances.contains(rt)) {
+ val remote_ref = circuitTarget.module(localInstances(rt))
if (annoCo)
asyncCoMap(ta) += rt.field(w.name)
remote_ref.ref(w.name)
@@ -163,7 +170,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
rt.field(w.name)
}
case remote_target => remote_target
- }
+ }
case w: WSubIndex =>
getRef(w.expr, ta, annoCo) match {
case remote_target: ReferenceTarget =>
@@ -179,7 +186,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def processRegister(r: DefRegister): Unit = {
getRef(r.reset, moduleTarget.ref(r.name), false) match {
- case rt : ReferenceTarget =>
+ case rt: ReferenceTarget =>
if (asyncRegMap.contains(rt)) {
asyncRegMap(rt) += moduleTarget.ref(r.name)
}
@@ -189,12 +196,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
def processConnect(c: Connect): Unit = {
- getRef(c.expr, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match {
+ getRef(c.expr, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match {
case rhs: ReferenceTarget =>
if (presets.contains(rhs) || asyncRegMap.contains(rhs)) {
getRef(c.loc, rhs, true) match {
- case lhs : ReferenceTarget =>
- if(asyncRegMap.contains(rhs)){
+ case lhs: ReferenceTarget =>
+ if (asyncRegMap.contains(rhs)) {
asyncRegMap(rhs) += lhs
} else {
asyncToAnnotate += lhs
@@ -211,10 +218,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
val target = moduleTarget.ref(n.name)
processType(n.value.tpe, target, presets.contains(target))
- getRef(n.value, ReferenceTarget("","", Seq.empty, "", Seq.empty)) match {
+ getRef(n.value, ReferenceTarget("", "", Seq.empty, "", Seq.empty)) match {
case rhs: ReferenceTarget =>
if (presets.contains(rhs) || asyncRegMap.contains(rhs)) {
- if(asyncRegMap.contains(rhs)){
+ if (asyncRegMap.contains(rhs)) {
asyncRegMap(rhs) += target
} else {
asyncToAnnotate += target
@@ -227,18 +234,18 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def processStatements(statement: Statement): Unit = {
statement match {
- case i : WDefInstance =>
+ case i: WDefInstance =>
localInstances(moduleTarget.ref(i.name)) = i.module
- case r : DefRegister => processRegister(r)
- case w : DefWire => processWire(w)
- case n : DefNode => processNode(n)
- case c : Connect => processConnect(c)
- case s => s.foreachStmt(processStatements)
+ case r: DefRegister => processRegister(r)
+ case w: DefWire => processWire(w)
+ case n: DefNode => processNode(n)
+ case c: Connect => processConnect(c)
+ case s => s.foreachStmt(processStatements)
}
}
def processPorts(port: Port): Unit = {
- if(port.tpe == AsyncResetType){
+ if (port.tpe == AsyncResetType) {
val target = moduleTarget.ref(port.name)
asyncRegMap(target) = new TargetSet()
asyncCoMap(target) = new TargetSet()
@@ -263,17 +270,17 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
/** Annotate a given target and all its children according to the asyncCoMap */
def annotateCo(ta: ReferenceTarget): Unit = {
- if (asyncCoMap.contains(ta)){
+ if (asyncCoMap.contains(ta)) {
toCleanUp += ta
- asyncCoMap(ta) foreach( (t: ReferenceTarget) => {
+ asyncCoMap(ta).foreach((t: ReferenceTarget) => {
toCleanUp += t
})
}
}
/** Annotate all registers somehow connected to the orignal annotated async reset */
- def annotateRegSet(set: TargetSet) : Unit = {
- set foreach ( (ta: ReferenceTarget) => {
+ def annotateRegSet(set: TargetSet): Unit = {
+ set.foreach((ta: ReferenceTarget) => {
annotateCo(ta)
if (asyncRegMap.contains(ta)) {
annotateRegSet(asyncRegMap(ta))
@@ -287,8 +294,8 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
* Walk AsyncReset Trees with all Annotated AsyncReset as entry points
* Annotate all leaf registers and intermediate wires, nodes, connectors along the way
*/
- def annotateAsyncSet(set: TargetSet) : Unit = {
- set foreach ((t: ReferenceTarget) => {
+ def annotateAsyncSet(set: TargetSet): Unit = {
+ set.foreach((t: ReferenceTarget) => {
annotateCo(t)
if (asyncRegMap.contains(t))
annotateRegSet(asyncRegMap(t))
@@ -300,7 +307,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
*/
cs.circuit.foreachModule(processModule) // PHASE 1 : Initialize
- annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate
+ annotateAsyncSet(asyncToAnnotate) // PHASE 2 : Annotate
otherAnnos ++ newAnnos
}
@@ -312,21 +319,21 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
* Clean-up useless reset tree (not relying on DCE)
* Disconnect preset registers from their reset tree
*/
- private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq) : Circuit = {
- val presetRegs = annos.collect {case a : PresetRegAnnotation => a}.groupBy(_.target)
+ private def cleanUpPresetTree(circuit: Circuit, annos: AnnotationSeq): Circuit = {
+ val presetRegs = annos.collect { case a: PresetRegAnnotation => a }.groupBy(_.target)
val circuitTarget = CircuitTarget(circuit.main)
def processModule(m: DefModule): DefModule = {
val moduleTarget = circuitTarget.module(m.name)
val localInstances = new TargetMap()
- def getRef(e: Expression) : Target = {
+ def getRef(e: Expression): Target = {
e match {
case w: WRef => moduleTarget.ref(w.name)
case w: WSubField =>
getRef(w.expr) match {
case rt: ReferenceTarget =>
- if(localInstances.contains(rt)){
+ if (localInstances.contains(rt)) {
circuitTarget.module(localInstances(rt)).ref(w.name)
} else {
rt.field(w.name)
@@ -341,14 +348,13 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
case DoPrim(op, args, _, _) =>
op match {
case AsInterval | AsUInt | AsSInt | AsClock | AsFixedPoint | AsAsyncReset => getRef(args.head)
- case _ => Target(None, None, Seq.empty)
+ case _ => Target(None, None, Seq.empty)
}
case _ => Target(None, None, Seq.empty)
}
}
-
- def processRegister(r: DefRegister) : DefRegister = {
+ def processRegister(r: DefRegister): DefRegister = {
if (presetRegs.contains(moduleTarget.ref(r.name))) {
r.copy(reset = UIntLiteral(0))
} else {
@@ -356,7 +362,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
}
- def processWire(w: DefWire) : Statement = {
+ def processWire(w: DefWire): Statement = {
if (toCleanUp.contains(moduleTarget.ref(w.name))) {
EmptyStmt
} else {
@@ -364,12 +370,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
}
- def processNode(n: DefNode) : Statement = {
+ def processNode(n: DefNode): Statement = {
if (toCleanUp.contains(moduleTarget.ref(n.name))) {
EmptyStmt
} else {
getRef(n.value) match {
- case rt : ReferenceTarget if(toCleanUp.contains(rt)) =>
+ case rt: ReferenceTarget if (toCleanUp.contains(rt)) =>
throw TreeCleanUpOrphanException(s"Orphan (${moduleTarget.ref(n.name)}) the way.")
case _ => n
}
@@ -380,7 +386,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
getRef(c.expr) match {
case rhs: ReferenceTarget if (toCleanUp.contains(rhs)) =>
getRef(c.loc) match {
- case lhs : ReferenceTarget if(!toCleanUp.contains(lhs)) =>
+ case lhs: ReferenceTarget if (!toCleanUp.contains(lhs)) =>
throw TreeCleanUpOrphanException(s"Orphan ${lhs} connected deleted node $rhs.")
case _ => EmptyStmt
}
@@ -388,7 +394,7 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
}
}
- def processInstance(i: WDefInstance) : WDefInstance = {
+ def processInstance(i: WDefInstance): WDefInstance = {
localInstances(moduleTarget.ref(i.name)) = i.module
val tpe = i.tpe match {
case b: BundleType =>
@@ -401,12 +407,12 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def processStatements(statement: Statement): Statement = {
statement match {
- case i : WDefInstance => processInstance(i)
- case r : DefRegister => processRegister(r)
- case w : DefWire => processWire(w)
- case n : DefNode => processNode(n)
- case c : Connect => processConnect(c)
- case s => s.mapStmt(processStatements)
+ case i: WDefInstance => processInstance(i)
+ case r: DefRegister => processRegister(r)
+ case w: DefWire => processWire(w)
+ case n: DefNode => processNode(n)
+ case c: Connect => processConnect(c)
+ case s => s.mapStmt(processStatements)
}
}
@@ -422,10 +428,10 @@ class PropagatePresetAnnotations extends Transform with DependencyAPIMigration {
def execute(state: CircuitState): CircuitState = {
// Collect all user-defined PresetAnnotation
- val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true ; case _ => false }
+ val (presets, otherAnnos) = state.annotations.partition { case _: PresetAnnotation => true; case _ => false }
// No PresetAnnotation => no need to walk the IR
- if (presets.isEmpty){
+ if (presets.isEmpty) {
state
} else {
// PHASE I - Propagate
diff --git a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala
index 840a3d99..ae3bc693 100644
--- a/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala
+++ b/src/main/scala/firrtl/transforms/RemoveKeywordCollisions.scala
@@ -21,10 +21,11 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames {
* @return Some name if a rename occurred, None otherwise
* @note prefix uniqueness is not respected
*/
- override def manipulate = (n: String, ns: Namespace) => keywords.contains(n) match {
- case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords))
- case false => None
- }
+ override def manipulate = (n: String, ns: Namespace) =>
+ keywords.contains(n) match {
+ case true => Some(Uniquify.findValidPrefix(n + inlineDelim, Seq(""), ns.cloneUnderlying ++ keywords))
+ case false => None
+ }
}
@@ -32,14 +33,16 @@ class RemoveKeywordCollisions(keywords: Set[String]) extends ManipulateNames {
class VerilogRename extends RemoveKeywordCollisions(v_keywords) {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals],
- Dependency[ReplaceTruncatingArithmetic],
- Dependency[InlineBitExtractionsTransform],
- Dependency[InlineCastsTransform],
- Dependency[LegalizeClocksTransform],
- Dependency[FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup) )
+ Seq(
+ Dependency[BlackBoxSourceHelper],
+ Dependency[FixAddingNegativeLiterals],
+ Dependency[ReplaceTruncatingArithmetic],
+ Dependency[InlineBitExtractionsTransform],
+ Dependency[InlineCastsTransform],
+ Dependency[LegalizeClocksTransform],
+ Dependency[FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup)
+ )
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/RemoveReset.scala b/src/main/scala/firrtl/transforms/RemoveReset.scala
index 6b3a9d07..8736e21b 100644
--- a/src/main/scala/firrtl/transforms/RemoveReset.scala
+++ b/src/main/scala/firrtl/transforms/RemoveReset.scala
@@ -18,8 +18,7 @@ import scala.collection.{immutable, mutable}
object RemoveReset extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize) )
+ Seq(Dependency(passes.LowerTypes), Dependency(passes.Legalize))
override def optionalPrerequisites = Seq.empty
@@ -58,7 +57,7 @@ object RemoveReset extends Transform with DependencyAPIMigration {
reg.copy(reset = Utils.zero, init = WRef(reg))
case reg @ DefRegister(_, rname, _, _, Utils.zero, _) =>
reg.copy(init = WRef(reg)) // canonicalize
- case reg @ DefRegister(info , rname, _, _, reset, init) if reset.tpe != AsyncResetType =>
+ case reg @ DefRegister(info, rname, _, _, reset, init) if reset.tpe != AsyncResetType =>
// Add register reset to map
resets(rname) = Reset(reset, init, info)
reg.copy(reset = Utils.zero, init = WRef(reg))
@@ -68,7 +67,7 @@ object RemoveReset extends Transform with DependencyAPIMigration {
// Use reg source locator for mux enable and true value since that's where they're defined
val infox = MultiInfo(reset.info, reset.info, info)
Connect(infox, ref, Mux(reset.cond, reset.value, expr, muxType))
- case other => other map onStmt
+ case other => other.map(onStmt)
}
}
m.map(onStmt)
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index f692e513..31fa3b6f 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -8,11 +8,11 @@ import firrtl.Utils._
import firrtl.Mappers._
import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._
-import firrtl.graph.{MutableDiGraph, CyclicException}
+import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
import scala.collection.mutable
-import scala.util.{Try, Success, Failure}
+import scala.util.{Failure, Success, Try}
/** Replace wires with nodes in a legal, flow-forward order
*
@@ -23,11 +23,13 @@ import scala.util.{Try, Success, Failure}
class RemoveWires extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.MidForm ++
- Seq( Dependency(passes.LowerTypes),
- Dependency(passes.Legalize),
- Dependency(passes.ResolveKinds),
- Dependency(transforms.RemoveReset),
- Dependency[transforms.CheckCombLoops] )
+ Seq(
+ Dependency(passes.LowerTypes),
+ Dependency(passes.Legalize),
+ Dependency(passes.ResolveKinds),
+ Dependency(transforms.RemoveReset),
+ Dependency[transforms.CheckCombLoops]
+ )
override def optionalPrerequisites = Seq(Dependency[checks.CheckResets])
@@ -35,7 +37,7 @@ class RemoveWires extends Transform with DependencyAPIMigration {
override def invalidates(a: Transform) = a match {
case passes.ResolveKinds => true
- case _ => false
+ case _ => false
}
// Extract all expressions that are references to a Node, Wire, or Reg
@@ -44,7 +46,7 @@ class RemoveWires extends Transform with DependencyAPIMigration {
val refs = mutable.ArrayBuffer.empty[WRef]
def rec(e: Expression): Expression = {
e match {
- case ref @ WRef(_,_, WireKind | NodeKind | RegKind, _) => refs += ref
+ case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref
case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec)
case _ => // Do nothing
}
@@ -57,7 +59,8 @@ class RemoveWires extends Transform with DependencyAPIMigration {
// Transform netlist into DefNodes
private def getOrderedNodes(
netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)],
- regInfo: mutable.Map[WrappedExpression, DefRegister]): Try[Seq[Statement]] = {
+ regInfo: mutable.Map[WrappedExpression, DefRegister]
+ ): Try[Seq[Statement]] = {
val digraph = new MutableDiGraph[WrappedExpression]
for ((sink, (exprs, _)) <- netlist) {
digraph.addVertex(sink)
@@ -106,21 +109,22 @@ class RemoveWires extends Transform with DependencyAPIMigration {
case reg: DefRegister =>
val resetDep = reg.reset.tpe match {
case AsyncResetType => Some(reg.reset)
- case _ => None
+ case _ => None
}
val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself
regInfo(we(WRef(reg))) = reg
netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info)
case decl: IsDeclaration => // Keep all declarations except for nodes and non-Analog wires
decls += decl
- case con @ Connect(cinfo, lhs, rhs) => kind(lhs) match {
- case WireKind =>
- // Be sure to pad the rhs since nodes get their type from the rhs
- val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
- val dinfo = wireInfo(lhs)
- netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
- case _ => otherStmts += con // Other connections just pass through
- }
+ case con @ Connect(cinfo, lhs, rhs) =>
+ kind(lhs) match {
+ case WireKind =>
+ // Be sure to pad the rhs since nodes get their type from the rhs
+ val paddedRhs = ConstantPropagation.pad(rhs, lhs.tpe)
+ val dinfo = wireInfo(lhs)
+ netlist(we(lhs)) = (Seq(paddedRhs), MultiInfo(dinfo, cinfo))
+ case _ => otherStmts += con // Other connections just pass through
+ }
case invalid @ IsInvalid(info, expr) =>
kind(expr) match {
case WireKind =>
@@ -146,8 +150,10 @@ class RemoveWires extends Transform with DependencyAPIMigration {
// If we hit a CyclicException, just abort removing wires
case Failure(c: CyclicException) =>
val problematicNode = c.node
- logger.warn(s"Cycle found in module $name, " +
- s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode")
+ logger.warn(
+ s"Cycle found in module $name, " +
+ s"wires will not be removed which can prevent optimizations! Problem node: $problematicNode"
+ )
mod
case Failure(other) => throw other
}
@@ -155,7 +161,6 @@ class RemoveWires extends Transform with DependencyAPIMigration {
}
}
-
def execute(state: CircuitState): CircuitState =
state.copy(circuit = state.circuit.map(onModule))
}
diff --git a/src/main/scala/firrtl/transforms/RenameModules.scala b/src/main/scala/firrtl/transforms/RenameModules.scala
index d37f8c39..16fd655a 100644
--- a/src/main/scala/firrtl/transforms/RenameModules.scala
+++ b/src/main/scala/firrtl/transforms/RenameModules.scala
@@ -44,7 +44,7 @@ class RenameModules extends Transform with DependencyAPIMigration {
moduleOrder.foreach(collectNameMapping(namespace.get, nameMappings))
val modulesx = state.circuit.modules.map {
- case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings)
+ case mod: Module => mod.mapStmt(onStmt(nameMappings)).mapString(nameMappings)
case ext: ExtModule => ext
}
diff --git a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala
index a93087b9..14c84b91 100644
--- a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala
+++ b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala
@@ -80,8 +80,7 @@ object ReplaceTruncatingArithmetic {
class ReplaceTruncatingArithmetic extends Transform with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[BlackBoxSourceHelper],
- Dependency[FixAddingNegativeLiterals] )
+ Seq(Dependency[BlackBoxSourceHelper], Dependency[FixAddingNegativeLiterals])
override def optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala
index a056c7da..7790d060 100644
--- a/src/main/scala/firrtl/transforms/SimplifyMems.scala
+++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala
@@ -33,12 +33,13 @@ class SimplifyMems extends Transform with DependencyAPIMigration {
def onExpr(e: Expression): Expression = e.map(onExpr) match {
case wr @ WRef(name, _, MemKind, _) if memAdapters.contains(name) => wr.copy(kind = WireKind)
- case e => e
+ case e => e
}
def simplifyMem(mem: DefMemory): Statement = {
val adapterDecl = DefWire(mem.info, mem.name, memType(mem))
- val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType))
+ val simpleMemDecl =
+ mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType))
val oldRT = mTarget.ref(mem.name)
val adapterConnects = memType(simpleMemDecl).fields.flatMap {
case Field(pName, Flip, pType: BundleType) =>
@@ -63,8 +64,10 @@ class SimplifyMems extends Transform with DependencyAPIMigration {
def canSimplify(mem: DefMemory) = mem.dataType match {
case at: AggregateType =>
- val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask")))
- val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask")))
+ val wMasks =
+ mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask")))
+ val rwMasks =
+ mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask")))
(wMasks ++ rwMasks).flatten.isEmpty
case _ => false
}
diff --git a/src/main/scala/firrtl/transforms/TopWiring.scala b/src/main/scala/firrtl/transforms/TopWiring.scala
index f5a5e2a3..b35fed22 100644
--- a/src/main/scala/firrtl/transforms/TopWiring.scala
+++ b/src/main/scala/firrtl/transforms/TopWiring.scala
@@ -4,7 +4,7 @@ package TopWiring
import firrtl._
import firrtl.ir._
-import firrtl.passes.{InferTypes, LowerTypes, ResolveKinds, ResolveFlows, ExpandConnects}
+import firrtl.passes.{ExpandConnects, InferTypes, LowerTypes, ResolveFlows, ResolveKinds}
import firrtl.annotations._
import firrtl.Mappers._
import firrtl.analyses.InstanceKeyGraph
@@ -13,22 +13,21 @@ import firrtl.options.Dependency
import collection.mutable
-/** Annotation for optional output files, and what directory to put those files in (absolute path) **/
-case class TopWiringOutputFilesAnnotation(dirName: String,
- outputFunction: (String,Seq[((ComponentName, Type, Boolean,
- Seq[String],String), Int)],
- CircuitState) => CircuitState) extends NoTargetAnnotation
+/** Annotation for optional output files, and what directory to put those files in (absolute path) * */
+case class TopWiringOutputFilesAnnotation(
+ dirName: String,
+ outputFunction: (String, Seq[((ComponentName, Type, Boolean, Seq[String], String), Int)],
+ CircuitState) => CircuitState)
+ extends NoTargetAnnotation
/** Annotation for indicating component to be wired, and what prefix to add to the ports that are generated */
-case class TopWiringAnnotation(target: ComponentName, prefix: String) extends
- SingleTargetAnnotation[ComponentName] {
+case class TopWiringAnnotation(target: ComponentName, prefix: String) extends SingleTargetAnnotation[ComponentName] {
def duplicate(n: ComponentName) = this.copy(target = n)
}
-
/** Punch out annotated ports out to the toplevel of the circuit.
- This also has an option to pass a function as a parmeter to generate
- custom output files as a result of the additional ports
+ * This also has an option to pass a function as a parmeter to generate
+ * custom output files as a result of the additional ports
* @note This *does* work for deduped modules
*/
class TopWiringTransform extends Transform with DependencyAPIMigration {
@@ -39,116 +38,133 @@ class TopWiringTransform extends Transform with DependencyAPIMigration {
override def invalidates(a: Transform): Boolean = a match {
case InferTypes | ResolveKinds | ResolveFlows | ExpandConnects => true
- case _ => false
+ case _ => false
}
type InstPath = Seq[String]
/** Get the names of the targets that need to be wired */
private def getSourceNames(state: CircuitState): Map[ComponentName, String] = {
- state.annotations.collect { case TopWiringAnnotation(srcname,prefix) =>
- (srcname -> prefix) }.toMap.withDefaultValue("")
+ state.annotations.collect {
+ case TopWiringAnnotation(srcname, prefix) =>
+ (srcname -> prefix)
+ }.toMap.withDefaultValue("")
}
-
/** Get the names of the modules which include the targets that need to be wired */
private def getSourceModNames(state: CircuitState): Seq[String] = {
- state.annotations.collect { case TopWiringAnnotation(ComponentName(_,ModuleName(srcmodname, _)),_) => srcmodname }
+ state.annotations.collect { case TopWiringAnnotation(ComponentName(_, ModuleName(srcmodname, _)), _) => srcmodname }
}
-
-
/** Get the Type of each wire to be connected
*
* Find the definition of each wire in sourceList, and get the type and whether or not it's a port
* Update the results in sourceMap
*/
- private def getSourceTypes(sourceList: Map[ComponentName, String],
- sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
- currentmodule: ModuleName, state: CircuitState)(s: Statement): Statement = s match {
+ private def getSourceTypes(
+ sourceList: Map[ComponentName, String],
+ sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
+ currentmodule: ModuleName,
+ state: CircuitState
+ )(s: Statement
+ ): Statement = s match {
// If target wire, add name and size to to sourceMap
case w: IsDeclaration =>
if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) {
- val (isport, tpe, prefix) = w match {
- case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
- }
- sourceMap.get(currentmodule.name) match {
- case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
- sourceMap.update(currentmodule.name, xs :+(
- (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) ))
- case None =>
- sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule),
- tpe, isport ,Seq[String](w.name), prefix))
- }
+ val (isport, tpe, prefix) = w match {
+ case d: DefWire => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case d: DefNode => (false, d.value.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case d: DefRegister => (false, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
+ }
+ sourceMap.get(currentmodule.name) match {
+ case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
+ sourceMap.update(
+ currentmodule.name,
+ xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix))
+ )
+ case None =>
+ sourceMap(currentmodule.name) = Seq(
+ (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)
+ )
+ }
}
w // Return argument unchanged (ok because DefWire has no Statement children)
// If not, apply to all children Statement
- case _ => s map getSourceTypes(sourceList, sourceMap, currentmodule, state)
+ case _ => s.map(getSourceTypes(sourceList, sourceMap, currentmodule, state))
}
-
-
/** Get the Type of each port to be connected
*
* Similar to getSourceTypes, but specifically for ports since they are not found in statements.
* Find the definition of each port in sourceList, and get the type and whether or not it's a port
* Update the results in sourceMap
*/
- private def getSourceTypesPorts(sourceList: Map[ComponentName, String], sourceMap: mutable.Map[String,
- Seq[(ComponentName, Type, Boolean, InstPath, String)]],
- currentmodule: ModuleName, state: CircuitState)(s: Port): CircuitState = s match {
+ private def getSourceTypesPorts(
+ sourceList: Map[ComponentName, String],
+ sourceMap: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
+ currentmodule: ModuleName,
+ state: CircuitState
+ )(s: Port
+ ): CircuitState = s match {
// If target port, add name and size to to sourceMap
case w: IsDeclaration =>
if (sourceList.keys.toSeq.contains(ComponentName(w.name, currentmodule))) {
- val (isport, tpe, prefix) = w match {
- case d: Port => (true, d.tpe, sourceList(ComponentName(w.name,currentmodule)))
- case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
- }
- sourceMap.get(currentmodule.name) match {
- case Some(xs:Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
- sourceMap.update(currentmodule.name, xs :+(
- (ComponentName(w.name,currentmodule), tpe, isport ,Seq[String](w.name), prefix) ))
- case None =>
- sourceMap(currentmodule.name) = Seq((ComponentName(w.name,currentmodule),
- tpe, isport ,Seq[String](w.name), prefix))
- }
+ val (isport, tpe, prefix) = w match {
+ case d: Port => (true, d.tpe, sourceList(ComponentName(w.name, currentmodule)))
+ case _ => throw new Exception(s"Cannot wire this type of declaration! ${w.serialize}")
+ }
+ sourceMap.get(currentmodule.name) match {
+ case Some(xs: Seq[(ComponentName, Type, Boolean, InstPath, String)]) =>
+ sourceMap.update(
+ currentmodule.name,
+ xs :+ ((ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix))
+ )
+ case None =>
+ sourceMap(currentmodule.name) = Seq(
+ (ComponentName(w.name, currentmodule), tpe, isport, Seq[String](w.name), prefix)
+ )
+ }
}
state // Return argument unchanged (ok because DefWire has no Statement children)
// If not, apply to all children Statement
case _ => state
}
-
/** Create a map of Module name to target wires under this module
*
* These paths are relative but cross module (they refer down through instance hierarchy)
*/
- private def getSourcesMap(state: CircuitState): Map[String,Seq[(ComponentName, Type, Boolean, InstPath, String)]] = {
+ private def getSourcesMap(state: CircuitState): Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] = {
val sSourcesModNames = getSourceModNames(state)
val sSourcesNames = getSourceNames(state)
val instGraph = firrtl.analyses.InstanceKeyGraph(state.circuit)
- val cMap = instGraph.getChildInstances.map{ case (m, wdis) =>
- (m -> wdis.map{ case wdi => (wdi.name, wdi.module) }.toSeq) }.toMap
+ val cMap = instGraph.getChildInstances.map {
+ case (m, wdis) =>
+ (m -> wdis.map { case wdi => (wdi.name, wdi.module) }.toSeq)
+ }.toMap
val topSort = instGraph.moduleOrder.reverse
// Map of component name to relative instance paths that result in a debug wire
val sourcemods: mutable.Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]] =
mutable.Map(sSourcesModNames.map(_ -> Seq()): _*)
- state.circuit.modules.foreach { m => m map
- getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) }
- state.circuit.modules.foreach { m => m.ports.foreach {
- p => Seq(p) map
- getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)) , state) }}
+ state.circuit.modules.foreach { m =>
+ m.map(getSourceTypes(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state))
+ }
+ state.circuit.modules.foreach { m =>
+ m.ports.foreach { p =>
+ Seq(p).map(
+ getSourceTypesPorts(sSourcesNames, sourcemods, ModuleName(m.name, CircuitName(state.circuit.main)), state)
+ )
+ }
+ }
for (mod <- topSort) {
- val seqChildren: Seq[(ComponentName,Type,Boolean,InstPath,String)] = cMap(mod.name).flatMap {
+ val seqChildren: Seq[(ComponentName, Type, Boolean, InstPath, String)] = cMap(mod.name).flatMap {
case (inst, module) =>
- sourcemods.get(module).map( _.map { case (a,b,c,path,p) => (a,b,c, inst +: path, p)})
+ sourcemods.get(module).map(_.map { case (a, b, c, path, p) => (a, b, c, inst +: path, p) })
}.flatten
if (seqChildren.nonEmpty) {
sourcemods(mod.name) = sourcemods.getOrElse(mod.name, Seq()) ++ seqChildren
@@ -158,108 +174,113 @@ class TopWiringTransform extends Transform with DependencyAPIMigration {
sourcemods.toMap
}
-
-
/** Process a given DefModule
*
* For Modules that contain or are in the parent hierarchy to modules containing target wires
* 1. Add ports for each target wire this module is parent to
* 2. Connect these ports to ports of instances that are parents to some number of target wires
*/
- private def onModule(sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
- portnamesmap : mutable.Map[String,String],
- instgraph : firrtl.analyses.InstanceKeyGraph,
- namespacemap : Map[String, Namespace])
- (module: DefModule): DefModule = {
+ private def onModule(
+ sources: Map[String, Seq[(ComponentName, Type, Boolean, InstPath, String)]],
+ portnamesmap: mutable.Map[String, String],
+ instgraph: firrtl.analyses.InstanceKeyGraph,
+ namespacemap: Map[String, Namespace]
+ )(module: DefModule
+ ): DefModule = {
val namespace = namespacemap(module.name)
sources.get(module.name) match {
case Some(p) =>
- val newPorts = p.map{ case (ComponentName(cname,_), tpe, _ , path, prefix) => {
- val newportname = portnamesmap.get(prefix + path.mkString("_")) match {
- case Some(pn) => pn
- case None => {
- val npn = namespace.newName(prefix + path.mkString("_"))
- portnamesmap(prefix + path.mkString("_")) = npn
- npn
- }
+ val newPorts = p.map {
+ case (ComponentName(cname, _), tpe, _, path, prefix) => {
+ val newportname = portnamesmap.get(prefix + path.mkString("_")) match {
+ case Some(pn) => pn
+ case None => {
+ val npn = namespace.newName(prefix + path.mkString("_"))
+ portnamesmap(prefix + path.mkString("_")) = npn
+ npn
}
- Port(NoInfo, newportname, Output, tpe)
- } }
+ }
+ Port(NoInfo, newportname, Output, tpe)
+ }
+ }
// Add connections to Module
val childInstances = instgraph.getChildInstances.toMap
module match {
case m: Module =>
- val connections: Seq[Connect] = p.map { case (ComponentName(cname,_), _, _ , path, prefix) =>
+ val connections: Seq[Connect] = p.map {
+ case (ComponentName(cname, _), _, _, path, prefix) =>
val modRef = portnamesmap.get(prefix + path.mkString("_")) match {
- case Some(pn) => WRef(pn)
- case None => {
- portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_"))
- WRef(portnamesmap(prefix + path.mkString("_")))
- }
+ case Some(pn) => WRef(pn)
+ case None => {
+ portnamesmap(prefix + path.mkString("_")) = namespace.newName(prefix + path.mkString("_"))
+ WRef(portnamesmap(prefix + path.mkString("_")))
+ }
}
path.size match {
- case 1 => {
- val leafRef = WRef(path.head.mkString(""))
- Connect(NoInfo, modRef, leafRef)
- }
- case _ => {
- val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match {
- case Some(ipn) => ipn
- case None => {
- val instmod = childInstances(module.name).collectFirst {
- case wdi if wdi.name == path.head => wdi.module}.get
- val instnamespace = namespacemap(instmod)
- portnamesmap(prefix + path.tail.mkString("_")) =
- instnamespace.newName(prefix + path.tail.mkString("_"))
- portnamesmap(prefix + path.tail.mkString("_"))
- }
- }
- val instRef = WSubField(WRef(path.head), instportname)
- Connect(NoInfo, modRef, instRef)
+ case 1 => {
+ val leafRef = WRef(path.head.mkString(""))
+ Connect(NoInfo, modRef, leafRef)
+ }
+ case _ => {
+ val instportname = portnamesmap.get(prefix + path.tail.mkString("_")) match {
+ case Some(ipn) => ipn
+ case None => {
+ val instmod = childInstances(module.name).collectFirst {
+ case wdi if wdi.name == path.head => wdi.module
+ }.get
+ val instnamespace = namespacemap(instmod)
+ portnamesmap(prefix + path.tail.mkString("_")) =
+ instnamespace.newName(prefix + path.tail.mkString("_"))
+ portnamesmap(prefix + path.tail.mkString("_"))
+ }
+ }
+ val instRef = WSubField(WRef(path.head), instportname)
+ Connect(NoInfo, modRef, instRef)
}
}
}
- m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections ))
+ m.copy(ports = m.ports ++ newPorts, body = Block(Seq(m.body) ++ connections))
case e: ExtModule =>
e.copy(ports = e.ports ++ newPorts)
- }
+ }
case None => module // unchanged if no paths
}
}
- /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */
- def topWiringDummyOutputFilesFunction(dir: String,
- mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)],
- state: CircuitState): CircuitState = {
- state
+ /** Dummy function that is currently unused. Can be used to fill an outputFunction requirment in the future */
+ def topWiringDummyOutputFilesFunction(
+ dir: String,
+ mapping: Seq[((ComponentName, Type, Boolean, InstPath, String), Int)],
+ state: CircuitState
+ ): CircuitState = {
+ state
}
-
def execute(state: CircuitState): CircuitState = {
- val outputTuples: Seq[(String,
- (String,Seq[((ComponentName, Type, Boolean, InstPath, String), Int)],
- CircuitState) => CircuitState)] = state.annotations.collect {
- case TopWiringOutputFilesAnnotation(td,of) => (td, of) }
+ val outputTuples: Seq[
+ (String, (String, Seq[((ComponentName, Type, Boolean, InstPath, String), Int)], CircuitState) => CircuitState)
+ ] = state.annotations.collect {
+ case TopWiringOutputFilesAnnotation(td, of) => (td, of)
+ }
// Do actual work of this transform
val sources = getSourcesMap(state)
val (nstate, nmappings) = if (sources.nonEmpty) {
- val portnamesmap: mutable.Map[String,String] = mutable.Map()
+ val portnamesmap: mutable.Map[String, String] = mutable.Map()
val instgraph = InstanceKeyGraph(state.circuit)
- val namespacemap = state.circuit.modules.map{ case m => (m.name -> Namespace(m)) }.toMap
- val modulesx = state.circuit.modules map onModule(sources, portnamesmap, instgraph, namespacemap)
+ val namespacemap = state.circuit.modules.map { case m => (m.name -> Namespace(m)) }.toMap
+ val modulesx = state.circuit.modules.map(onModule(sources, portnamesmap, instgraph, namespacemap))
val newCircuit = state.circuit.copy(modules = modulesx)
val mappings = sources(state.circuit.main).zipWithIndex
val annosx = state.annotations.filter {
case _: TopWiringAnnotation => false
- case _ => true
+ case _ => true
}
(state.copy(circuit = newCircuit, annotations = annosx), mappings)
- }
- else { (state, List.empty) }
+ } else { (state, List.empty) }
//Generate output files based on the mapping.
outputTuples.map { case (dir, outputfunction) => outputfunction(dir, nmappings, nstate) }
nstate
diff --git a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala
index 7370fcfb..cdbee495 100644
--- a/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala
+++ b/src/main/scala/firrtl/transforms/formal/AssertSubmoduleAssumptions.scala
@@ -1,4 +1,3 @@
-
package firrtl.transforms.formal
import firrtl.ir.{Circuit, Formal, Statement, Verification}
@@ -7,7 +6,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, Transform}
import firrtl.annotations.NoTargetAnnotation
import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption}
-
/**
* Assert Submodule Assumptions
*
@@ -16,12 +14,13 @@ import firrtl.options.{PreservesAll, RegisteredTransform, ShellOption}
* overly restrictive assume in a child module can prevent the model checker
* from searching valid inputs and states in the parent module.
*/
-class AssertSubmoduleAssumptions extends Transform
- with RegisteredTransform
- with DependencyAPIMigration
- with PreservesAll[Transform] {
+class AssertSubmoduleAssumptions
+ extends Transform
+ with RegisteredTransform
+ with DependencyAPIMigration
+ with PreservesAll[Transform] {
- override def prerequisites: Seq[TransformDependency] = Seq.empty
+ override def prerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisiteOf: Seq[TransformDependency] =
firrtl.stage.Forms.MidEmitters
@@ -29,9 +28,10 @@ class AssertSubmoduleAssumptions extends Transform
val options = Seq(
new ShellOption[Unit](
longOption = "no-asa",
- toAnnotationSeq = (_: Unit) => Seq(
- DontAssertSubmoduleAssumptionsAnnotation),
- helpText = "Disable assert submodule assumptions" ) )
+ toAnnotationSeq = (_: Unit) => Seq(DontAssertSubmoduleAssumptionsAnnotation),
+ helpText = "Disable assert submodule assumptions"
+ )
+ )
def assertAssumption(s: Statement): Statement = s match {
case Verification(Formal.Assume, info, clk, cond, en, msg) =>
@@ -50,8 +50,7 @@ class AssertSubmoduleAssumptions extends Transform
}
def execute(state: CircuitState): CircuitState = {
- val noASA = state.annotations.contains(
- DontAssertSubmoduleAssumptionsAnnotation)
+ val noASA = state.annotations.contains(DontAssertSubmoduleAssumptionsAnnotation)
if (noASA) {
logger.info("Skipping assert submodule assumptions")
state
diff --git a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala
index ddead331..5928c79c 100644
--- a/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala
+++ b/src/main/scala/firrtl/transforms/formal/ConvertAsserts.scala
@@ -14,10 +14,8 @@ import firrtl.options.Dependency
object ConvertAsserts extends Transform with DependencyAPIMigration {
override def prerequisites = Nil
override def optionalPrerequisites = Nil
- override def optionalPrerequisiteOf = Seq(
- Dependency[VerilogEmitter],
- Dependency[MinimumVerilogEmitter],
- Dependency[RemoveVerificationStatements])
+ override def optionalPrerequisiteOf =
+ Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter], Dependency[RemoveVerificationStatements])
override def invalidates(a: Transform): Boolean = false
@@ -28,7 +26,7 @@ object ConvertAsserts extends Transform with DependencyAPIMigration {
val stop = Stop(i, 1, clk, gatedNPred)
msg match {
case StringLit("") => stop
- case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop)
+ case _ => Block(Print(i, msg, Nil, clk, gatedNPred), stop)
}
case s => s.mapStmt(convertAsserts)
}
diff --git a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala
index 72890c07..1e6d2c72 100644
--- a/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala
+++ b/src/main/scala/firrtl/transforms/formal/RemoveVerificationStatements.scala
@@ -1,4 +1,3 @@
-
package firrtl.transforms.formal
import firrtl.ir.{Circuit, EmptyStmt, Statement, Verification}
@@ -6,7 +5,6 @@ import firrtl.{CircuitState, DependencyAPIMigration, MinimumVerilogEmitter, Tran
import firrtl.options.{Dependency, PreservesAll, StageUtils}
import firrtl.stage.TransformManager.TransformDependency
-
/**
* Remove Verification Statements
*
@@ -14,15 +12,12 @@ import firrtl.stage.TransformManager.TransformDependency
* This is intended to be required by the Verilog emitter to ensure compatibility
* with the Verilog 2001 standard.
*/
-class RemoveVerificationStatements extends Transform
- with DependencyAPIMigration
- with PreservesAll[Transform] {
+class RemoveVerificationStatements extends Transform with DependencyAPIMigration with PreservesAll[Transform] {
- override def prerequisites: Seq[TransformDependency] = Seq.empty
+ override def prerequisites: Seq[TransformDependency] = Seq.empty
override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency(ConvertAsserts))
override def optionalPrerequisiteOf: Seq[TransformDependency] =
- Seq( Dependency[VerilogEmitter],
- Dependency[MinimumVerilogEmitter])
+ Seq(Dependency[VerilogEmitter], Dependency[MinimumVerilogEmitter])
private var removedCounter = 0
@@ -43,11 +38,13 @@ class RemoveVerificationStatements extends Transform
def execute(state: CircuitState): CircuitState = {
val newState = state.copy(circuit = run(state.circuit))
if (removedCounter > 0) {
- StageUtils.dramaticWarning(s"$removedCounter verification statements " +
- "(assert, assume or cover) " +
- "were removed when compiling to Verilog because the basic Verilog " +
- "standard does not support them. If this was not intended, compile " +
- "to System Verilog instead using the `-X sverilog` compiler flag.")
+ StageUtils.dramaticWarning(
+ s"$removedCounter verification statements " +
+ "(assert, assume or cover) " +
+ "were removed when compiling to Verilog because the basic Verilog " +
+ "standard does not support them. If this was not intended, compile " +
+ "to System Verilog instead using the `-X sverilog` compiler flag."
+ )
}
newState
}