aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms/DeadCodeElimination.scala')
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala296
1 files changed, 296 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
new file mode 100644
index 00000000..5199276c
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -0,0 +1,296 @@
+
+package firrtl.transforms
+
+import firrtl._
+import firrtl.ir._
+import firrtl.passes._
+import firrtl.annotations._
+import firrtl.graph._
+import firrtl.analyses.InstanceGraph
+import firrtl.Mappers._
+import firrtl.WrappedExpression._
+import firrtl.Utils.{throwInternalError, toWrappedExpression, kind}
+import firrtl.MemoizedHash._
+import wiring.WiringUtils.getChildrenMap
+
+import collection.mutable
+import java.io.{File, FileWriter}
+
+/** Dead Code Elimination (DCE)
+ *
+ * Performs DCE by constructing a global dependency graph starting with top-level outputs, external
+ * module ports, and simulation constructs as circuit sinks. External modules can optionally be
+ * eligible for DCE via the [[OptimizableExtModuleAnnotation]].
+ *
+ * Dead code is eliminated across module boundaries. Wires, ports, registers, and memories are all
+ * eligible for removal. Components marked with a [[DontTouchAnnotation]] will be treated as a
+ * circuit sink and thus anything that drives such a marked component will NOT be removed.
+ *
+ * This transform preserves deduplication. All instances of a given [[DefModule]] are treated as
+ * the same individual module. Thus, while certain instances may have dead code due to the
+ * circumstances of their instantiation in their parent module, they will still not be removed. To
+ * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication.
+ */
+class DeadCodeElimination extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ /** Based on LogicNode ins CheckCombLoops, currently kind of faking it */
+ private type LogicNode = MemoizedHash[WrappedExpression]
+ private object LogicNode {
+ def apply(moduleName: String, expr: Expression): LogicNode =
+ WrappedExpression(Utils.mergeRef(WRef(moduleName), expr))
+ def apply(moduleName: String, name: String): LogicNode = apply(moduleName, WRef(name))
+ def apply(component: ComponentName): LogicNode = {
+ // Currently only leaf nodes are supported TODO implement
+ val loweredName = LowerTypes.loweredName(component.name.split('.'))
+ apply(component.module.name, WRef(loweredName))
+ }
+ /** External Modules are representated as a single node driven by all inputs and driving all
+ * outputs
+ */
+ def apply(ext: ExtModule): LogicNode = LogicNode(ext.name, ext.name)
+ }
+
+ /** Expression used to represent outputs in the circuit (# is illegal in names) */
+ private val circuitSink = LogicNode("#Top", "#Sink")
+
+ /** Extract all References and SubFields from a possibly nested Expression */
+ def extractRefs(expr: Expression): Seq[Expression] = {
+ val refs = mutable.ArrayBuffer.empty[Expression]
+ def rec(e: Expression): Expression = {
+ e match {
+ case ref @ (_: WRef | _: WSubField) => refs += ref
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case ignore @ (_: Literal) => // Do nothing
+ case unexpected => throwInternalError
+ }
+ e
+ }
+ rec(expr)
+ refs
+ }
+
+ // Gets all dependencies and constructs LogicNodes from them
+ private def getDepsImpl(mname: String,
+ instMap: collection.Map[String, String])
+ (expr: Expression): Seq[LogicNode] =
+ extractRefs(expr).map { e =>
+ if (kind(e) == InstanceKind) {
+ val (inst, tail) = Utils.splitRef(e)
+ LogicNode(instMap(inst.name), tail)
+ } else {
+ LogicNode(mname, e)
+ }
+ }
+
+
+ /** Construct the dependency graph within this module */
+ private def setupDepGraph(depGraph: MutableDiGraph[LogicNode],
+ instMap: collection.Map[String, String])
+ (mod: Module): Unit = {
+ def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
+
+ def onStmt(stmt: Statement): Unit = stmt match {
+ case DefRegister(_, name, _, clock, reset, init) =>
+ val node = LogicNode(mod.name, name)
+ depGraph.addVertex(node)
+ Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref))
+ case DefNode(_, name, value) =>
+ val node = LogicNode(mod.name, name)
+ depGraph.addVertex(node)
+ getDeps(value).foreach(ref => depGraph.addEdge(node, ref))
+ case DefWire(_, name, _) =>
+ depGraph.addVertex(LogicNode(mod.name, name))
+ case mem: DefMemory =>
+ // Treat DefMems as a node with outputs depending on the node and node depending on inputs
+ // From perpsective of the module or instance, MALE expressions are inputs, FEMALE are outputs
+ val memRef = WRef(mem.name, MemPortUtils.memType(mem), ExpKind, FEMALE)
+ val exprs = Utils.create_exps(memRef).groupBy(Utils.gender(_))
+ val sources = exprs.getOrElse(MALE, List.empty).flatMap(getDeps(_))
+ val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_))
+ val memNode = getDeps(memRef) match { case Seq(node) => node }
+ depGraph.addVertex(memNode)
+ sinks.foreach(sink => depGraph.addEdge(sink, memNode))
+ sources.foreach(source => depGraph.addEdge(memNode, source))
+ case Attach(_, exprs) => // Add edge between each expression
+ exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach {
+ case Seq(a, b) =>
+ depGraph.addEdge(a, b)
+ depGraph.addEdge(b, a)
+ }
+ case Connect(_, loc, expr) =>
+ // This match enforces the low Firrtl requirement of expanded connections
+ val node = getDeps(loc) match { case Seq(elt) => elt }
+ getDeps(expr).foreach(ref => depGraph.addEdge(node, ref))
+ // Simulation constructs are treated as top-level outputs
+ case Stop(_,_, clk, en) =>
+ Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ case Print(_, _, args, clk, en) =>
+ (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ case Block(stmts) => stmts.foreach(onStmt(_))
+ case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing
+ case other => throw new Exception(s"Unexpected Statement $other")
+ }
+
+ // Add all ports as vertices
+ mod.ports.foreach {
+ case Port(_, name, _, _: GroundType) => depGraph.addVertex(LogicNode(mod.name, name))
+ case other => throwInternalError
+ }
+ onStmt(mod.body)
+ }
+
+ // TODO Make immutable?
+ private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]],
+ doTouchExtMods: Set[String],
+ c: Circuit): MutableDiGraph[LogicNode] = {
+ val depGraph = new MutableDiGraph[LogicNode]
+ c.modules.foreach {
+ case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
+ case ext: ExtModule =>
+ // Connect all inputs to all outputs
+ val node = LogicNode(ext)
+ ext.ports.foreach {
+ case Port(_, pname, _, AnalogType(_)) =>
+ depGraph.addEdge(LogicNode(ext.name, pname), node)
+ depGraph.addEdge(node, LogicNode(ext.name, pname))
+ case Port(_, pname, Output, _) =>
+ val portNode = LogicNode(ext.name, pname)
+ depGraph.addEdge(portNode, node)
+ // Don't touch external modules *unless* they are specifically marked as doTouch
+ if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode)
+ case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname))
+ }
+ }
+ // Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface)
+ val topModule = c.modules.find(_.name == c.main).get
+ val topOutputs = topModule.ports.foreach { port =>
+ depGraph.addEdge(circuitSink, LogicNode(c.main, port.name))
+ }
+
+ depGraph
+ }
+
+ private def deleteDeadCode(instMap: collection.Map[String, String],
+ deadNodes: Set[LogicNode],
+ moduleMap: collection.Map[String, DefModule],
+ renames: RenameMap)
+ (mod: DefModule): Option[DefModule] = {
+ def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
+
+ var emptyBody = true
+ renames.setModule(mod.name)
+
+ def onStmt(stmt: Statement): Statement = {
+ val stmtx = stmt match {
+ case inst: WDefInstance =>
+ moduleMap.get(inst.module) match {
+ case Some(instMod) => inst.copy(tpe = Utils.module_type(instMod))
+ case None =>
+ renames.delete(inst.name)
+ EmptyStmt
+ }
+ case decl: IsDeclaration =>
+ val node = LogicNode(mod.name, decl.name)
+ if (deadNodes.contains(node)) {
+ renames.delete(decl.name)
+ EmptyStmt
+ }
+ else decl
+ case con: Connect =>
+ val node = getDeps(con.loc) match { case Seq(elt) => elt }
+ if (deadNodes.contains(node)) EmptyStmt else con
+ case Attach(info, exprs) => // If any exprs are dead then all are
+ val dead = exprs.flatMap(getDeps(_)).forall(deadNodes.contains(_))
+ if (dead) EmptyStmt else Attach(info, exprs)
+ case block: Block => block map onStmt
+ case other => other
+ }
+ stmtx match { // Check if module empty
+ case EmptyStmt | _: Block =>
+ case other => emptyBody = false
+ }
+ stmtx
+ }
+
+ val (deadPorts, portsx) = mod.ports.partition(p => deadNodes.contains(LogicNode(mod.name, p.name)))
+ deadPorts.foreach(p => renames.delete(p.name))
+
+ mod match {
+ case Module(info, name, _, body) =>
+ val bodyx = onStmt(body)
+ if (emptyBody && portsx.isEmpty) None else Some(Module(info, name, portsx, bodyx))
+ case ext: ExtModule =>
+ if (portsx.isEmpty) None
+ else {
+ if (ext.ports != portsx) throwInternalError // Sanity check
+ Some(ext.copy(ports = portsx))
+ }
+ }
+
+ }
+
+ def run(state: CircuitState,
+ dontTouches: Seq[LogicNode],
+ doTouchExtMods: Set[String]): CircuitState = {
+ val c = state.circuit
+ val moduleMap = c.modules.map(m => m.name -> m).toMap
+ val iGraph = new InstanceGraph(c)
+ val moduleDeps = iGraph.graph.edges.map { case (k,v) =>
+ k.module -> v.map(i => i.name -> i.module).toMap
+ }
+ val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
+
+ val depGraph = {
+ val dGraph = createDependencyGraph(moduleDeps, doTouchExtMods, c)
+ for (dontTouch <- dontTouches) {
+ dGraph.getVertices.find(_ == dontTouch) match {
+ case Some(node) => dGraph.addEdge(circuitSink, node)
+ case None =>
+ val (root, tail) = Utils.splitRef(dontTouch.e1)
+ DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize)
+ }
+ }
+ DiGraph(dGraph)
+ }
+
+ val liveNodes = depGraph.reachableFrom(circuitSink) + circuitSink
+ val deadNodes = depGraph.getVertices -- liveNodes
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+
+ // As we delete deadCode, we will delete ports from Modules and somtimes complete modules
+ // themselves. We iterate over the modules in a topological order from leaves to the top. The
+ // current status of the modulesxMap is used to either delete instances or update their types
+ val modulesxMap = mutable.HashMap.empty[String, DefModule]
+ topoSortedModules.foreach { case mod =>
+ deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames)(mod) match {
+ case Some(m) => modulesxMap += m.name -> m
+ case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
+ }
+ }
+
+ // Preserve original module order
+ val newCircuit = c.copy(modules = c.modules.flatMap(m => modulesxMap.get(m.name)))
+
+ state.copy(circuit = newCircuit, renames = Some(renames))
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val (dontTouches: Seq[LogicNode], doTouchExtMods: Seq[String]) =
+ state.annotations match {
+ case Some(aMap) =>
+ // TODO Do with single walk over annotations
+ val dontTouches = aMap.annotations.collect {
+ case DontTouchAnnotation(component) => LogicNode(component)
+ }
+ val optExtMods = aMap.annotations.collect {
+ case OptimizableExtModuleAnnotation(ModuleName(name, _)) => name
+ }
+ (dontTouches, optExtMods)
+ case None => (Seq.empty, Seq.empty)
+ }
+ run(state, dontTouches, doTouchExtMods.toSet)
+ }
+}