diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/transforms/DeadCodeElimination.scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/transforms/DeadCodeElimination.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/DeadCodeElimination.scala | 130 |
1 files changed, 69 insertions, 61 deletions
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index c883bdfb..fb1bd1f6 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -1,4 +1,3 @@ - package firrtl.transforms import firrtl._ @@ -8,7 +7,7 @@ import firrtl.annotations._ import firrtl.graph._ import firrtl.analyses.InstanceKeyGraph import firrtl.Mappers._ -import firrtl.Utils.{throwInternalError, kind} +import firrtl.Utils.{kind, throwInternalError} import firrtl.MemoizedHash._ import firrtl.options.{Dependency, RegisteredTransform, ShellOption} @@ -29,29 +28,34 @@ import collection.mutable * circumstances of their instantiation in their parent module, they will still not be removed. To * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. */ -class DeadCodeElimination extends Transform +class DeadCodeElimination + extends Transform with ResolvedAnnotationPaths with RegisteredTransform with DependencyAPIMigration { override def prerequisites = firrtl.stage.Forms.LowForm ++ - Seq( Dependency(firrtl.passes.RemoveValidIf), - Dependency[firrtl.transforms.ConstantPropagation], - Dependency(firrtl.passes.memlib.VerilogMemDelays), - Dependency(firrtl.passes.SplitExpressions), - Dependency[firrtl.transforms.CombineCats], - Dependency(passes.CommonSubexpressionElimination) ) + Seq( + Dependency(firrtl.passes.RemoveValidIf), + Dependency[firrtl.transforms.ConstantPropagation], + Dependency(firrtl.passes.memlib.VerilogMemDelays), + Dependency(firrtl.passes.SplitExpressions), + Dependency[firrtl.transforms.CombineCats], + Dependency(passes.CommonSubexpressionElimination) + ) override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename], - Dependency(passes.VerilogPrep), - Dependency[firrtl.AddDescriptionNodes] ) + Seq( + Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename], + Dependency(passes.VerilogPrep), + Dependency[firrtl.AddDescriptionNodes] + ) override def invalidates(a: Transform) = false @@ -59,7 +63,9 @@ class DeadCodeElimination extends Transform new ShellOption[Unit]( longOption = "no-dce", toAnnotationSeq = (_: Unit) => Seq(NoDCEAnnotation), - helpText = "Disable dead code elimination" ) ) + helpText = "Disable dead code elimination" + ) + ) /** Based on LogicNode ins CheckCombLoops, currently kind of faking it */ private type LogicNode = MemoizedHash[WrappedExpression] @@ -72,6 +78,7 @@ class DeadCodeElimination extends Transform val loweredName = LowerTypes.loweredName(component.name.split('.')) apply(component.module.name, WRef(loweredName)) } + /** External Modules are representated as a single node driven by all inputs and driving all * outputs */ @@ -87,7 +94,7 @@ class DeadCodeElimination extends Transform def rec(e: Expression): Expression = { e match { case ref @ (_: WRef | _: WSubField) => refs += ref - case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec + case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.map(rec) case ignore @ (_: Literal) => // Do nothing case unexpected => throwInternalError() } @@ -98,9 +105,7 @@ class DeadCodeElimination extends Transform } // Gets all dependencies and constructs LogicNodes from them - private def getDepsImpl(mname: String, - instMap: collection.Map[String, String]) - (expr: Expression): Seq[LogicNode] = + private def getDepsImpl(mname: String, instMap: collection.Map[String, String])(expr: Expression): Seq[LogicNode] = extractRefs(expr).map { e => if (kind(e) == InstanceKind) { val (inst, tail) = Utils.splitRef(e) @@ -110,11 +115,12 @@ class DeadCodeElimination extends Transform } } - /** Construct the dependency graph within this module */ - private def setupDepGraph(depGraph: MutableDiGraph[LogicNode], - instMap: collection.Map[String, String]) - (mod: Module): Unit = { + private def setupDepGraph( + depGraph: MutableDiGraph[LogicNode], + instMap: collection.Map[String, String] + )(mod: Module + ): Unit = { def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr) def onStmt(stmt: Statement): Unit = stmt match { @@ -150,7 +156,7 @@ class DeadCodeElimination extends Transform val node = getDeps(loc) match { case Seq(elt) => elt } getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref)) // Simulation constructs are treated as top-level outputs - case Stop(_,_, clk, en) => + case Stop(_, _, clk, en) => Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) case Print(_, _, args, clk, en) => (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) @@ -172,12 +178,14 @@ class DeadCodeElimination extends Transform } // TODO Make immutable? - private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]], - doTouchExtMods: Set[String], - c: Circuit): MutableDiGraph[LogicNode] = { + private def createDependencyGraph( + instMaps: collection.Map[String, collection.Map[String, String]], + doTouchExtMods: Set[String], + c: Circuit + ): MutableDiGraph[LogicNode] = { val depGraph = new MutableDiGraph[LogicNode] c.modules.foreach { - case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) + case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) case ext: ExtModule => // Connect all inputs to all outputs val node = LogicNode(ext) @@ -205,23 +213,25 @@ class DeadCodeElimination extends Transform depGraph } - private def deleteDeadCode(instMap: collection.Map[String, String], - deadNodes: collection.Set[LogicNode], - moduleMap: collection.Map[String, DefModule], - renames: RenameMap, - topName: String, - doTouchExtMods: Set[String]) - (mod: DefModule): Option[DefModule] = { + private def deleteDeadCode( + instMap: collection.Map[String, String], + deadNodes: collection.Set[LogicNode], + moduleMap: collection.Map[String, DefModule], + renames: RenameMap, + topName: String, + doTouchExtMods: Set[String] + )(mod: DefModule + ): Option[DefModule] = { // For log-level debug def deleteMsg(decl: IsDeclaration): String = { val tpe = decl match { - case _: DefNode => "node" + case _: DefNode => "node" case _: DefRegister => "reg" - case _: DefWire => "wire" - case _: Port => "port" - case _: DefMemory => "mem" + case _: DefWire => "wire" + case _: Port => "port" + case _: DefMemory => "mem" case (_: DefInstance | _: WDefInstance) => "inst" - case _: Module => "module" + case _: Module => "module" case _: ExtModule => "extmodule" } val ref = decl match { @@ -237,7 +247,7 @@ class DeadCodeElimination extends Transform def deleteIfNotEnabled(stmt: Statement, en: Expression): Statement = en match { case UIntLiteral(v, _) if v == BigInt(0) => EmptyStmt - case _ => stmt + case _ => stmt } def onStmt(stmt: Statement): Statement = { @@ -256,12 +266,11 @@ class DeadCodeElimination extends Transform logger.debug(deleteMsg(decl)) renames.delete(decl.name) EmptyStmt - } - else decl - case print: Print => deleteIfNotEnabled(print, print.en) - case stop: Stop => deleteIfNotEnabled(stop, stop.en) + } else decl + case print: Print => deleteIfNotEnabled(print, print.en) + case stop: Stop => deleteIfNotEnabled(stop, stop.en) case formal: Verification => deleteIfNotEnabled(formal, formal.en) - case con: Connect => + case con: Connect => val node = getDeps(con.loc) match { case Seq(elt) => elt } if (deadNodes.contains(node)) EmptyStmt else con case Attach(info, exprs) => // If any exprs are dead then all are @@ -270,7 +279,7 @@ class DeadCodeElimination extends Transform case IsInvalid(info, expr) => val node = getDeps(expr) match { case Seq(elt) => elt } if (deadNodes.contains(node)) EmptyStmt else IsInvalid(info, expr) - case block: Block => block map onStmt + case block: Block => block.map(onStmt) case other => other } stmtx match { // Check if module empty @@ -300,8 +309,7 @@ class DeadCodeElimination extends Transform if (portsx.isEmpty && doTouchExtMods.contains(ext.name)) { logger.debug(deleteMsg(mod)) None - } - else { + } else { if (ext.ports != portsx) throwInternalError() // Sanity check Some(ext.copy(ports = portsx)) } @@ -309,14 +317,13 @@ class DeadCodeElimination extends Transform } - def run(state: CircuitState, - dontTouches: Seq[LogicNode], - doTouchExtMods: Set[String]): CircuitState = { + def run(state: CircuitState, dontTouches: Seq[LogicNode], doTouchExtMods: Set[String]): CircuitState = { val c = state.circuit val moduleMap = c.modules.map(m => m.name -> m).toMap val iGraph = InstanceKeyGraph(c) - val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) => - k.module -> v.map(i => i.name -> i.module).toMap + val moduleDeps = iGraph.graph.getEdgeMap.map({ + case (k, v) => + k.module -> v.map(i => i.name -> i.module).toMap }) val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_)) @@ -347,11 +354,12 @@ class DeadCodeElimination extends Transform // themselves. We iterate over the modules in a topological order from leaves to the top. The // current status of the modulesxMap is used to either delete instances or update their types val modulesxMap = mutable.HashMap.empty[String, DefModule] - topoSortedModules.foreach { case mod => - deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match { - case Some(m) => modulesxMap += m.name -> m - case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) - } + topoSortedModules.foreach { + case mod => + deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main, doTouchExtMods)(mod) match { + case Some(m) => modulesxMap += m.name -> m + case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) + } } // Preserve original module order |
