aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/transforms/DeadCodeElimination.scala
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/transforms/DeadCodeElimination.scala')
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala130
1 files changed, 69 insertions, 61 deletions
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index c883bdfb..fb1bd1f6 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -1,4 +1,3 @@
-
package firrtl.transforms
import firrtl._
@@ -8,7 +7,7 @@ import firrtl.annotations._
import firrtl.graph._
import firrtl.analyses.InstanceKeyGraph
import firrtl.Mappers._
-import firrtl.Utils.{throwInternalError, kind}
+import firrtl.Utils.{kind, throwInternalError}
import firrtl.MemoizedHash._
import firrtl.options.{Dependency, RegisteredTransform, ShellOption}
@@ -29,29 +28,34 @@ import collection.mutable
* circumstances of their instantiation in their parent module, they will still not be removed. To
* remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication.
*/
-class DeadCodeElimination extends Transform
+class DeadCodeElimination
+ extends Transform
with ResolvedAnnotationPaths
with RegisteredTransform
with DependencyAPIMigration {
override def prerequisites = firrtl.stage.Forms.LowForm ++
- Seq( Dependency(firrtl.passes.RemoveValidIf),
- Dependency[firrtl.transforms.ConstantPropagation],
- Dependency(firrtl.passes.memlib.VerilogMemDelays),
- Dependency(firrtl.passes.SplitExpressions),
- Dependency[firrtl.transforms.CombineCats],
- Dependency(passes.CommonSubexpressionElimination) )
+ Seq(
+ Dependency(firrtl.passes.RemoveValidIf),
+ Dependency[firrtl.transforms.ConstantPropagation],
+ Dependency(firrtl.passes.memlib.VerilogMemDelays),
+ Dependency(firrtl.passes.SplitExpressions),
+ Dependency[firrtl.transforms.CombineCats],
+ Dependency(passes.CommonSubexpressionElimination)
+ )
override def optionalPrerequisites = Seq.empty
override def optionalPrerequisiteOf =
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename],
- Dependency(passes.VerilogPrep),
- Dependency[firrtl.AddDescriptionNodes] )
+ Seq(
+ Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename],
+ Dependency(passes.VerilogPrep),
+ Dependency[firrtl.AddDescriptionNodes]
+ )
override def invalidates(a: Transform) = false
@@ -59,7 +63,9 @@ class DeadCodeElimination extends Transform
new ShellOption[Unit](
longOption = "no-dce",
toAnnotationSeq = (_: Unit) => Seq(NoDCEAnnotation),
- helpText = "Disable dead code elimination" ) )
+ helpText = "Disable dead code elimination"
+ )
+ )
/** Based on LogicNode ins CheckCombLoops, currently kind of faking it */
private type LogicNode = MemoizedHash[WrappedExpression]
@@ -72,6 +78,7 @@ class DeadCodeElimination extends Transform
val loweredName = LowerTypes.loweredName(component.name.split('.'))
apply(component.module.name, WRef(loweredName))
}
+
/** External Modules are representated as a single node driven by all inputs and driving all
* outputs
*/
@@ -87,7 +94,7 @@ class DeadCodeElimination extends Transform
def rec(e: Expression): Expression = {
e match {
case ref @ (_: WRef | _: WSubField) => refs += ref
- case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.map(rec)
case ignore @ (_: Literal) => // Do nothing
case unexpected => throwInternalError()
}
@@ -98,9 +105,7 @@ class DeadCodeElimination extends Transform
}
// Gets all dependencies and constructs LogicNodes from them
- private def getDepsImpl(mname: String,
- instMap: collection.Map[String, String])
- (expr: Expression): Seq[LogicNode] =
+ private def getDepsImpl(mname: String, instMap: collection.Map[String, String])(expr: Expression): Seq[LogicNode] =
extractRefs(expr).map { e =>
if (kind(e) == InstanceKind) {
val (inst, tail) = Utils.splitRef(e)
@@ -110,11 +115,12 @@ class DeadCodeElimination extends Transform
}
}
-
/** Construct the dependency graph within this module */
- private def setupDepGraph(depGraph: MutableDiGraph[LogicNode],
- instMap: collection.Map[String, String])
- (mod: Module): Unit = {
+ private def setupDepGraph(
+ depGraph: MutableDiGraph[LogicNode],
+ instMap: collection.Map[String, String]
+ )(mod: Module
+ ): Unit = {
def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
def onStmt(stmt: Statement): Unit = stmt match {
@@ -150,7 +156,7 @@ class DeadCodeElimination extends Transform
val node = getDeps(loc) match { case Seq(elt) => elt }
getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref))
// Simulation constructs are treated as top-level outputs
- case Stop(_,_, clk, en) =>
+ case Stop(_, _, clk, en) =>
Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
case Print(_, _, args, clk, en) =>
(args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
@@ -172,12 +178,14 @@ class DeadCodeElimination extends Transform
}
// TODO Make immutable?
- private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]],
- doTouchExtMods: Set[String],
- c: Circuit): MutableDiGraph[LogicNode] = {
+ private def createDependencyGraph(
+ instMaps: collection.Map[String, collection.Map[String, String]],
+ doTouchExtMods: Set[String],
+ c: Circuit
+ ): MutableDiGraph[LogicNode] = {
val depGraph = new MutableDiGraph[LogicNode]
c.modules.foreach {
- case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
+ case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
case ext: ExtModule =>
// Connect all inputs to all outputs
val node = LogicNode(ext)
@@ -205,23 +213,25 @@ class DeadCodeElimination extends Transform
depGraph
}
- private def deleteDeadCode(instMap: collection.Map[String, String],
- deadNodes: collection.Set[LogicNode],
- moduleMap: collection.Map[String, DefModule],
- renames: RenameMap,
- topName: String,
- doTouchExtMods: Set[String])
- (mod: DefModule): Option[DefModule] = {
+ private def deleteDeadCode(
+ instMap: collection.Map[String, String],
+ deadNodes: collection.Set[LogicNode],
+ moduleMap: collection.Map[String, DefModule],
+ renames: RenameMap,
+ topName: String,
+ doTouchExtMods: Set[String]
+ )(mod: DefModule
+ ): Option[DefModule] = {
// For log-level debug
def deleteMsg(decl: IsDeclaration): String = {
val tpe = decl match {
- case _: DefNode => "node"
+ case _: DefNode => "node"
case _: DefRegister => "reg"
- case _: DefWire => "wire"
- case _: Port => "port"
- case _: DefMemory => "mem"
+ case _: DefWire => "wire"
+ case _: Port => "port"
+ case _: DefMemory => "mem"
case (_: DefInstance | _: WDefInstance) => "inst"
- case _: Module => "module"
+ case _: Module => "module"
case _: ExtModule => "extmodule"
}
val ref = decl match {
@@ -237,7 +247,7 @@ class DeadCodeElimination extends Transform
def deleteIfNotEnabled(stmt: Statement, en: Expression): Statement = en match {
case UIntLiteral(v, _) if v == BigInt(0) => EmptyStmt
- case _ => stmt
+ case _ => stmt
}
def onStmt(stmt: Statement): Statement = {
@@ -256,12 +266,11 @@ class DeadCodeElimination extends Transform
logger.debug(deleteMsg(decl))
renames.delete(decl.name)
EmptyStmt
- }
- else decl
- case print: Print => deleteIfNotEnabled(print, print.en)
- case stop: Stop => deleteIfNotEnabled(stop, stop.en)
+ } else decl
+ case print: Print => deleteIfNotEnabled(print, print.en)
+ case stop: Stop => deleteIfNotEnabled(stop, stop.en)
case formal: Verification => deleteIfNotEnabled(formal, formal.en)
- case con: Connect =>
+ case con: Connect =>
val node = getDeps(con.loc) match { case Seq(elt) => elt }
if (deadNodes.contains(node)) EmptyStmt else con
case Attach(info, exprs) => // If any exprs are dead then all are
@@ -270,7 +279,7 @@ class DeadCodeElimination extends Transform
case IsInvalid(info, expr) =>
val node = getDeps(expr) match { case Seq(elt) => elt }
if (deadNodes.contains(node)) EmptyStmt else IsInvalid(info, expr)
- case block: Block => block map onStmt
+ case block: Block => block.map(onStmt)
case other => other
}
stmtx match { // Check if module empty
@@ -300,8 +309,7 @@ class DeadCodeElimination extends Transform
if (portsx.isEmpty && doTouchExtMods.contains(ext.name)) {
logger.debug(deleteMsg(mod))
None
- }
- else {
+ } else {
if (ext.ports != portsx) throwInternalError() // Sanity check
Some(ext.copy(ports = portsx))
}
@@ -309,14 +317,13 @@ class DeadCodeElimination extends Transform
}
- def run(state: CircuitState,
- dontTouches: Seq[LogicNode],
- doTouchExtMods: Set[String]): CircuitState = {
+ def run(state: CircuitState, dontTouches: Seq[LogicNode], doTouchExtMods: Set[String]): CircuitState = {
val c = state.circuit
val moduleMap = c.modules.map(m => m.name -> m).toMap
val iGraph = InstanceKeyGraph(c)
- val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) =>
- k.module -> v.map(i => i.name -> i.module).toMap
+ val moduleDeps = iGraph.graph.getEdgeMap.map({
+ case (k, v) =>
+ k.module -> v.map(i => i.name -> i.module).toMap
})
val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
@@ -347,11 +354,12 @@ class DeadCodeElimination extends Transform
// themselves. We iterate over the modules in a topological order from leaves to the top. The
// current status of the modulesxMap is used to either delete instances or update their types
val modulesxMap = mutable.HashMap.empty[String, DefModule]
- topoSortedModules.foreach { case mod =>
- deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match {
- case Some(m) => modulesxMap += m.name -> m
- case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
- }
+ topoSortedModules.foreach {
+ case mod =>
+ deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match {
+ case Some(m) => modulesxMap += m.name -> m
+ case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
+ }
}
// Preserve original module order