aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2017-05-11 18:22:08 -0700
committerGitHub2017-05-11 18:22:08 -0700
commitcf226360a7681354609779743895d015c3415451 (patch)
treeef3764642a2fb9083f12a6dea6188225ff1e3786 /src
parentfba12e01fda28a72b3c00116b52f8aee8bce0677 (diff)
Improved Global Dead Code Elimination (#549)
Performs DCE by constructing a global dependency graph starting with top-level outputs, external module ports, and simulation constructs as circuit sinks. External modules can optionally be eligible for DCE via the OptimizableExtModuleAnnotation. Dead code is eliminated across module boundaries. Wires, ports, registers, and memories are all eligible for removal. Components marked with a DontTouchAnnotation will be treated as a circuit sink and thus anything that drives such a marked component will NOT be removed. This transform preserves deduplication. All instances of a given DefModule are treated as the same individual module. Thus, while certain instances may have dead code due to the circumstances of their instantiation in their parent module, they will still not be removed. To remove such modules, use the NoDedupAnnotation to prevent deduplication.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala2
-rw-r--r--src/main/scala/firrtl/WIR.scala1
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala296
-rw-r--r--src/main/scala/firrtl/transforms/OptimizationAnnotations.scala48
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala81
-rw-r--r--src/test/scala/firrtlTests/AttachSpec.scala11
-rw-r--r--src/test/scala/firrtlTests/DCETests.scala366
-rw-r--r--src/test/scala/firrtlTests/FirrtlSpec.scala22
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala11
9 files changed, 813 insertions, 25 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 84f237a3..b28dd1b2 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -105,7 +105,7 @@ class LowFirrtlOptimization extends CoreTransform {
passes.ConstProp,
passes.SplitExpressions,
passes.CommonSubexpressionElimination,
- passes.DeadCodeElimination)
+ new firrtl.transforms.DeadCodeElimination)
}
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index d2e1457c..47e0f321 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -45,6 +45,7 @@ case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender)
}
object WSubField {
def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UNKNOWNGENDER)
+ def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UNKNOWNGENDER)
}
case class WSubIndex(expr: Expression, value: Int, tpe: Type, gender: Gender) extends Expression {
def serialize: String = s"${expr.serialize}[$value]"
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
new file mode 100644
index 00000000..5199276c
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -0,0 +1,296 @@
+
+package firrtl.transforms
+
+import firrtl._
+import firrtl.ir._
+import firrtl.passes._
+import firrtl.annotations._
+import firrtl.graph._
+import firrtl.analyses.InstanceGraph
+import firrtl.Mappers._
+import firrtl.WrappedExpression._
+import firrtl.Utils.{throwInternalError, toWrappedExpression, kind}
+import firrtl.MemoizedHash._
+import wiring.WiringUtils.getChildrenMap
+
+import collection.mutable
+import java.io.{File, FileWriter}
+
+/** Dead Code Elimination (DCE)
+ *
+ * Performs DCE by constructing a global dependency graph starting with top-level outputs, external
+ * module ports, and simulation constructs as circuit sinks. External modules can optionally be
+ * eligible for DCE via the [[OptimizableExtModuleAnnotation]].
+ *
+ * Dead code is eliminated across module boundaries. Wires, ports, registers, and memories are all
+ * eligible for removal. Components marked with a [[DontTouchAnnotation]] will be treated as a
+ * circuit sink and thus anything that drives such a marked component will NOT be removed.
+ *
+ * This transform preserves deduplication. All instances of a given [[DefModule]] are treated as
+ * the same individual module. Thus, while certain instances may have dead code due to the
+ * circumstances of their instantiation in their parent module, they will still not be removed. To
+ * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication.
+ */
+class DeadCodeElimination extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ /** Based on LogicNode ins CheckCombLoops, currently kind of faking it */
+ private type LogicNode = MemoizedHash[WrappedExpression]
+ private object LogicNode {
+ def apply(moduleName: String, expr: Expression): LogicNode =
+ WrappedExpression(Utils.mergeRef(WRef(moduleName), expr))
+ def apply(moduleName: String, name: String): LogicNode = apply(moduleName, WRef(name))
+ def apply(component: ComponentName): LogicNode = {
+ // Currently only leaf nodes are supported TODO implement
+ val loweredName = LowerTypes.loweredName(component.name.split('.'))
+ apply(component.module.name, WRef(loweredName))
+ }
+ /** External Modules are representated as a single node driven by all inputs and driving all
+ * outputs
+ */
+ def apply(ext: ExtModule): LogicNode = LogicNode(ext.name, ext.name)
+ }
+
+ /** Expression used to represent outputs in the circuit (# is illegal in names) */
+ private val circuitSink = LogicNode("#Top", "#Sink")
+
+ /** Extract all References and SubFields from a possibly nested Expression */
+ def extractRefs(expr: Expression): Seq[Expression] = {
+ val refs = mutable.ArrayBuffer.empty[Expression]
+ def rec(e: Expression): Expression = {
+ e match {
+ case ref @ (_: WRef | _: WSubField) => refs += ref
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case ignore @ (_: Literal) => // Do nothing
+ case unexpected => throwInternalError
+ }
+ e
+ }
+ rec(expr)
+ refs
+ }
+
+ // Gets all dependencies and constructs LogicNodes from them
+ private def getDepsImpl(mname: String,
+ instMap: collection.Map[String, String])
+ (expr: Expression): Seq[LogicNode] =
+ extractRefs(expr).map { e =>
+ if (kind(e) == InstanceKind) {
+ val (inst, tail) = Utils.splitRef(e)
+ LogicNode(instMap(inst.name), tail)
+ } else {
+ LogicNode(mname, e)
+ }
+ }
+
+
+ /** Construct the dependency graph within this module */
+ private def setupDepGraph(depGraph: MutableDiGraph[LogicNode],
+ instMap: collection.Map[String, String])
+ (mod: Module): Unit = {
+ def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
+
+ def onStmt(stmt: Statement): Unit = stmt match {
+ case DefRegister(_, name, _, clock, reset, init) =>
+ val node = LogicNode(mod.name, name)
+ depGraph.addVertex(node)
+ Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref))
+ case DefNode(_, name, value) =>
+ val node = LogicNode(mod.name, name)
+ depGraph.addVertex(node)
+ getDeps(value).foreach(ref => depGraph.addEdge(node, ref))
+ case DefWire(_, name, _) =>
+ depGraph.addVertex(LogicNode(mod.name, name))
+ case mem: DefMemory =>
+ // Treat DefMems as a node with outputs depending on the node and node depending on inputs
+ // From perpsective of the module or instance, MALE expressions are inputs, FEMALE are outputs
+ val memRef = WRef(mem.name, MemPortUtils.memType(mem), ExpKind, FEMALE)
+ val exprs = Utils.create_exps(memRef).groupBy(Utils.gender(_))
+ val sources = exprs.getOrElse(MALE, List.empty).flatMap(getDeps(_))
+ val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_))
+ val memNode = getDeps(memRef) match { case Seq(node) => node }
+ depGraph.addVertex(memNode)
+ sinks.foreach(sink => depGraph.addEdge(sink, memNode))
+ sources.foreach(source => depGraph.addEdge(memNode, source))
+ case Attach(_, exprs) => // Add edge between each expression
+ exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach {
+ case Seq(a, b) =>
+ depGraph.addEdge(a, b)
+ depGraph.addEdge(b, a)
+ }
+ case Connect(_, loc, expr) =>
+ // This match enforces the low Firrtl requirement of expanded connections
+ val node = getDeps(loc) match { case Seq(elt) => elt }
+ getDeps(expr).foreach(ref => depGraph.addEdge(node, ref))
+ // Simulation constructs are treated as top-level outputs
+ case Stop(_,_, clk, en) =>
+ Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ case Print(_, _, args, clk, en) =>
+ (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ case Block(stmts) => stmts.foreach(onStmt(_))
+ case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing
+ case other => throw new Exception(s"Unexpected Statement $other")
+ }
+
+ // Add all ports as vertices
+ mod.ports.foreach {
+ case Port(_, name, _, _: GroundType) => depGraph.addVertex(LogicNode(mod.name, name))
+ case other => throwInternalError
+ }
+ onStmt(mod.body)
+ }
+
+ // TODO Make immutable?
+ private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]],
+ doTouchExtMods: Set[String],
+ c: Circuit): MutableDiGraph[LogicNode] = {
+ val depGraph = new MutableDiGraph[LogicNode]
+ c.modules.foreach {
+ case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
+ case ext: ExtModule =>
+ // Connect all inputs to all outputs
+ val node = LogicNode(ext)
+ ext.ports.foreach {
+ case Port(_, pname, _, AnalogType(_)) =>
+ depGraph.addEdge(LogicNode(ext.name, pname), node)
+ depGraph.addEdge(node, LogicNode(ext.name, pname))
+ case Port(_, pname, Output, _) =>
+ val portNode = LogicNode(ext.name, pname)
+ depGraph.addEdge(portNode, node)
+ // Don't touch external modules *unless* they are specifically marked as doTouch
+ if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode)
+ case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname))
+ }
+ }
+ // Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface)
+ val topModule = c.modules.find(_.name == c.main).get
+ val topOutputs = topModule.ports.foreach { port =>
+ depGraph.addEdge(circuitSink, LogicNode(c.main, port.name))
+ }
+
+ depGraph
+ }
+
+ private def deleteDeadCode(instMap: collection.Map[String, String],
+ deadNodes: Set[LogicNode],
+ moduleMap: collection.Map[String, DefModule],
+ renames: RenameMap)
+ (mod: DefModule): Option[DefModule] = {
+ def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
+
+ var emptyBody = true
+ renames.setModule(mod.name)
+
+ def onStmt(stmt: Statement): Statement = {
+ val stmtx = stmt match {
+ case inst: WDefInstance =>
+ moduleMap.get(inst.module) match {
+ case Some(instMod) => inst.copy(tpe = Utils.module_type(instMod))
+ case None =>
+ renames.delete(inst.name)
+ EmptyStmt
+ }
+ case decl: IsDeclaration =>
+ val node = LogicNode(mod.name, decl.name)
+ if (deadNodes.contains(node)) {
+ renames.delete(decl.name)
+ EmptyStmt
+ }
+ else decl
+ case con: Connect =>
+ val node = getDeps(con.loc) match { case Seq(elt) => elt }
+ if (deadNodes.contains(node)) EmptyStmt else con
+ case Attach(info, exprs) => // If any exprs are dead then all are
+ val dead = exprs.flatMap(getDeps(_)).forall(deadNodes.contains(_))
+ if (dead) EmptyStmt else Attach(info, exprs)
+ case block: Block => block map onStmt
+ case other => other
+ }
+ stmtx match { // Check if module empty
+ case EmptyStmt | _: Block =>
+ case other => emptyBody = false
+ }
+ stmtx
+ }
+
+ val (deadPorts, portsx) = mod.ports.partition(p => deadNodes.contains(LogicNode(mod.name, p.name)))
+ deadPorts.foreach(p => renames.delete(p.name))
+
+ mod match {
+ case Module(info, name, _, body) =>
+ val bodyx = onStmt(body)
+ if (emptyBody && portsx.isEmpty) None else Some(Module(info, name, portsx, bodyx))
+ case ext: ExtModule =>
+ if (portsx.isEmpty) None
+ else {
+ if (ext.ports != portsx) throwInternalError // Sanity check
+ Some(ext.copy(ports = portsx))
+ }
+ }
+
+ }
+
+ def run(state: CircuitState,
+ dontTouches: Seq[LogicNode],
+ doTouchExtMods: Set[String]): CircuitState = {
+ val c = state.circuit
+ val moduleMap = c.modules.map(m => m.name -> m).toMap
+ val iGraph = new InstanceGraph(c)
+ val moduleDeps = iGraph.graph.edges.map { case (k,v) =>
+ k.module -> v.map(i => i.name -> i.module).toMap
+ }
+ val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
+
+ val depGraph = {
+ val dGraph = createDependencyGraph(moduleDeps, doTouchExtMods, c)
+ for (dontTouch <- dontTouches) {
+ dGraph.getVertices.find(_ == dontTouch) match {
+ case Some(node) => dGraph.addEdge(circuitSink, node)
+ case None =>
+ val (root, tail) = Utils.splitRef(dontTouch.e1)
+ DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize)
+ }
+ }
+ DiGraph(dGraph)
+ }
+
+ val liveNodes = depGraph.reachableFrom(circuitSink) + circuitSink
+ val deadNodes = depGraph.getVertices -- liveNodes
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+
+ // As we delete deadCode, we will delete ports from Modules and somtimes complete modules
+ // themselves. We iterate over the modules in a topological order from leaves to the top. The
+ // current status of the modulesxMap is used to either delete instances or update their types
+ val modulesxMap = mutable.HashMap.empty[String, DefModule]
+ topoSortedModules.foreach { case mod =>
+ deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames)(mod) match {
+ case Some(m) => modulesxMap += m.name -> m
+ case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
+ }
+ }
+
+ // Preserve original module order
+ val newCircuit = c.copy(modules = c.modules.flatMap(m => modulesxMap.get(m.name)))
+
+ state.copy(circuit = newCircuit, renames = Some(renames))
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val (dontTouches: Seq[LogicNode], doTouchExtMods: Seq[String]) =
+ state.annotations match {
+ case Some(aMap) =>
+ // TODO Do with single walk over annotations
+ val dontTouches = aMap.annotations.collect {
+ case DontTouchAnnotation(component) => LogicNode(component)
+ }
+ val optExtMods = aMap.annotations.collect {
+ case OptimizableExtModuleAnnotation(ModuleName(name, _)) => name
+ }
+ (dontTouches, optExtMods)
+ case None => (Seq.empty, Seq.empty)
+ }
+ run(state, dontTouches, doTouchExtMods.toSet)
+ }
+}
diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
new file mode 100644
index 00000000..23723a60
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
@@ -0,0 +1,48 @@
+
+package firrtl
+package transforms
+
+import firrtl.annotations._
+import firrtl.passes.PassException
+
+/** A component that should be preserved
+ *
+ * DCE treats the component as a top-level sink of the circuit
+ */
+object DontTouchAnnotation {
+ private val marker = "DONTtouch!"
+ def apply(target: ComponentName): Annotation = Annotation(target, classOf[Transform], marker)
+
+ def unapply(a: Annotation): Option[ComponentName] = a match {
+ case Annotation(component: ComponentName, _, value) if value == marker => Some(component)
+ case _ => None
+ }
+
+ class DontTouchNotFoundException(module: String, component: String) extends PassException(
+ s"Component marked DONT Touch ($module.$component) not found!\n" +
+ "Perhaps it is an aggregate type? Currently only leaf components are supported.\n" +
+ "Otherwise it was probably accidentally deleted. Please check that your custom passes are not" +
+ "responsible and then file an issue on Github."
+ )
+
+ def errorNotFound(module: String, component: String) =
+ throw new DontTouchNotFoundException(module, component)
+}
+
+/** An [[firrtl.ir.ExtModule]] that can be optimized
+ *
+ * Firrtl does not know the semantics of an external module. This annotation provides some
+ * "greybox" information that the external module does not have any side effects. In particular,
+ * this means that the external module can be Dead Code Eliminated.
+ *
+ * @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted
+ */
+object OptimizableExtModuleAnnotation {
+ private val marker = "optimizableExtModule!"
+ def apply(target: ModuleName): Annotation = Annotation(target, classOf[Transform], marker)
+
+ def unapply(a: Annotation): Option[ModuleName] = a match {
+ case Annotation(component: ModuleName, _, value) if value == marker => Some(component)
+ case _ => None
+ }
+}
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index 81c394c1..44964eba 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -7,8 +7,10 @@ import java.io.{File, FileWriter, Writer}
import firrtl.annotations.AnnotationYamlProtocol._
import firrtl.annotations._
import firrtl._
+import firrtl.transforms.OptimizableExtModuleAnnotation
import firrtl.passes.InlineAnnotation
import firrtl.passes.memlib.PinAnnotation
+import firrtl.transforms.DontTouchAnnotation
import net.jcazevedo.moultingyaml._
import org.scalatest.Matchers
import logger._
@@ -43,8 +45,17 @@ trait AnnotationSpec extends LowTransformSpec {
class AnnotationTests extends AnnotationSpec with Matchers {
def getAMap(a: Annotation): Option[AnnotationMap] = Some(AnnotationMap(Seq(a)))
def getAMap(as: Seq[Annotation]): Option[AnnotationMap] = Some(AnnotationMap(as))
- def anno(s: String, value: String ="this is a value"): Annotation =
- Annotation(ComponentName(s, ModuleName("Top", CircuitName("Top"))), classOf[Transform], value)
+ def anno(s: String, value: String ="this is a value", mod: String = "Top"): Annotation =
+ Annotation(ComponentName(s, ModuleName(mod, CircuitName("Top"))), classOf[Transform], value)
+ def manno(mod: String): Annotation =
+ Annotation(ModuleName(mod, CircuitName("Top")), classOf[Transform], "some value")
+ // TODO unify with FirrtlMatchers, problems with multiple definitions of parse
+ def dontTouch(path: String): Annotation = {
+ val parts = path.split('.')
+ require(parts.size >= 2, "Must specify both module and component!")
+ val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top")))
+ DontTouchAnnotation(name)
+ }
"Loose and Sticky annotation on a node" should "pass through" in {
val input: String =
@@ -145,7 +156,6 @@ class AnnotationTests extends AnnotationSpec with Matchers {
val deleted = result.deletedAnnotations
exception.str should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted")
}
-
"Renaming" should "propagate in Lowering of memories" in {
val compiler = new VerilogCompiler
// Uncomment to help debugging failing tests
@@ -165,7 +175,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| m.r.en <= UInt(1)
| m.r.addr <= in
|""".stripMargin
- val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"))
+ val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"),
+ dontTouch("Top.m"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("m_a", "mem"))
@@ -179,7 +190,6 @@ class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should not contain (anno("m"))
resultAnno should not contain (anno("r"))
}
-
"Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in {
val compiler = new VerilogCompiler
Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug))
@@ -191,7 +201,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| cmem m: {a: UInt<4>, b: UInt<4>[2]}[8]
| read mport r = m[in], clk
|""".stripMargin
- val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"))
+ val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"), dontTouch("Top.m"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("m_a", "mem"))
@@ -220,7 +230,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| x.a <= zero
| x.b <= zero
|""".stripMargin
- val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]"))
+ val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"),
+ anno("y[2]"), dontTouch("Top.x"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("x_a"))
@@ -260,7 +271,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
anno("w.a"), anno("w.b[0]"), anno("w.b[1]"),
anno("n.a"), anno("n.b[0]"), anno("n.b[1]"),
anno("r.a"), anno("r.b[0]"), anno("r.b[1]"),
- anno("write.a"), anno("write.b[0]"), anno("write.b[1]")
+ anno("write.a"), anno("write.b[0]"), anno("write.b[1]"),
+ dontTouch("Top.r")
)
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
@@ -314,7 +326,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"))
+ val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_a"))
@@ -349,7 +361,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"))
+ val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"),
+ dontTouch("Top.r"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_b_0"))
@@ -364,8 +377,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should contain (anno("r_b_1"))
}
-
- "Renaming" should "track dce" in {
+ "Renaming" should "track constprop + dce" in {
val compiler = new VerilogCompiler
val input =
"""circuit Top :
@@ -403,4 +415,49 @@ class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should contain (anno("out_b_0"))
resultAnno should contain (anno("out_b_1"))
}
+
+ ignore should "track deleted modules AND instances in dce" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Dead :
+ | input foo : UInt<8>
+ | output bar : UInt<8>
+ | bar <= foo
+ | extmodule DeadExt :
+ | input foo : UInt<8>
+ | output bar : UInt<8>
+ | module Top :
+ | input foo : UInt<8>
+ | output bar : UInt<8>
+ | inst d of Dead
+ | d.foo <= foo
+ | inst d2 of DeadExt
+ | d2.foo <= foo
+ | bar <= foo
+ |""".stripMargin
+ val annos = Seq(
+ OptimizableExtModuleAnnotation(ModuleName("DeadExt", CircuitName("Top"))),
+ manno("Dead"), manno("DeadExt"), manno("Top"),
+ anno("d"), anno("d2"),
+ anno("foo", mod = "Top"), anno("bar", mod = "Top"),
+ anno("foo", mod = "Dead"), anno("bar", mod = "Dead"),
+ anno("foo", mod = "DeadExt"), anno("bar", mod = "DeadExt")
+ )
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+
+ resultAnno should contain (manno("Top"))
+ resultAnno should contain (anno("foo", mod = "Top"))
+ resultAnno should contain (anno("bar", mod = "Top"))
+
+ resultAnno should not contain (manno("Dead"))
+ resultAnno should not contain (manno("DeadExt"))
+ resultAnno should not contain (anno("d"))
+ resultAnno should not contain (anno("d2"))
+ resultAnno should not contain (anno("foo", mod = "Dead"))
+ resultAnno should not contain (anno("bar", mod = "Dead"))
+ resultAnno should not contain (anno("foo", mod = "DeadExt"))
+ resultAnno should not contain (anno("bar", mod = "DeadExt"))
+ }
}
diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala
index c29a7e43..93a36f70 100644
--- a/src/test/scala/firrtlTests/AttachSpec.scala
+++ b/src/test/scala/firrtlTests/AttachSpec.scala
@@ -62,7 +62,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec {
| module A:
| input an: Analog<3>
| module B:
- | input an: Analog<3> """.stripMargin
+ | input an: Analog<3>""".stripMargin
val check =
"""module Attaching(
|);
@@ -70,16 +70,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec {
| A a (
| .an(_GEN_0)
| );
- | A b (
+ | B b (
| .an(_GEN_0)
| );
|endmodule
|module A(
| inout [2:0] an
|);
+ |module B(
+ | inout [2:0] an
+ |);
|endmodule
|""".stripMargin.split("\n") map normalized
- executeTest(input, check, compiler)
+ executeTest(input, check, compiler, Seq(dontTouch("A.an"), dontDedup("A")))
}
it should "attach a wire source" in {
@@ -101,7 +104,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec {
| );
|endmodule
|""".stripMargin.split("\n") map normalized
- executeTest(input, check, compiler)
+ executeTest(input, check, compiler, Seq(dontTouch("Attaching.x")))
}
it should "attach multiple sources" in {
diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala
new file mode 100644
index 00000000..deb73b3b
--- /dev/null
+++ b/src/test/scala/firrtlTests/DCETests.scala
@@ -0,0 +1,366 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl.ir.Circuit
+import firrtl._
+import firrtl.passes._
+import firrtl.transforms._
+import firrtl.annotations._
+import firrtl.passes.memlib.SimpleTransform
+
+class DCETests extends FirrtlFlatSpec {
+ // Not using executeTest because it is for positive testing, we need to check that stuff got
+ // deleted
+ private val customTransforms = Seq(
+ new LowFirrtlOptimization,
+ new SimpleTransform(RemoveEmpty, LowForm)
+ )
+ private def exec(input: String, check: String, annos: Seq[Annotation] = List.empty): Unit = {
+ val state = CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annos)))
+ val finalState = (new LowFirrtlCompiler).compileAndEmit(state, customTransforms)
+ val res = finalState.getEmittedCircuit.value
+ // Convert to sets for comparison
+ val resSet = Set(parse(res).serialize.split("\n"):_*)
+ val checkSet = Set(parse(check).serialize.split("\n"):_*)
+ resSet should be (checkSet)
+ }
+
+ "Unread wire" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | z <= x
+ | a <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Unread wire marked dont touch" should "NOT be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | z <= x
+ | a <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | z <= x
+ | a <= x""".stripMargin
+ exec(input, check, Seq(dontTouch("Top.a")))
+ }
+ "Unread register" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | reg a : UInt<1>, clk
+ | a <= x
+ | node y = asUInt(clk)
+ | z <= or(x, y)""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node y = asUInt(clk)
+ | z <= or(x, y)""".stripMargin
+ exec(input, check)
+ }
+ "Unread node" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Unused ports" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Sub :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | inst sub of Sub
+ | sub.x <= x
+ | z <= sub.z""".stripMargin
+ val check =
+ """circuit Top :
+ | module Sub :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | inst sub of Sub
+ | sub.x <= x
+ | z <= sub.z""".stripMargin
+ exec(input, check)
+ }
+ "Chain of unread nodes" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | node b = or(a, a)
+ | node c = add(b, x)
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Chain of unread wires and their connections" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | a <= x
+ | wire b : UInt<1>
+ | b <= a
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Read register" should "not be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | reg r : UInt<1>, clk
+ | r <= x
+ | z <= r""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | reg r : UInt<1>, clk with : (reset => (UInt<1>("h0"), r))
+ | r <= x
+ | z <= r""".stripMargin
+ exec(input, check)
+ }
+ "Logic that feeds into simulation constructs" should "not be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | stop(clk, a, 0)
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | z <= x
+ | stop(clk, a, 0)""".stripMargin
+ exec(input, check)
+ }
+ "Globally dead module" should "should be deleted" in {
+ val input =
+ """circuit Top :
+ | module Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Globally dead extmodule" should "NOT be deleted by default" in {
+ val input =
+ """circuit Top :
+ | extmodule Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | extmodule Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Globally dead extmodule marked optimizable" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | extmodule Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ val doTouchAnno = OptimizableExtModuleAnnotation(ModuleName("Dead", CircuitName("Top")))
+ exec(input, check, Seq(doTouchAnno))
+ }
+ "Analog ports of extmodules" should "count as both inputs and outputs" in {
+ val input =
+ """circuit Top :
+ | extmodule BB1 :
+ | output bus : Analog<1>
+ | extmodule BB2 :
+ | output bus : Analog<1>
+ | output out : UInt<1>
+ | module Top :
+ | output out : UInt<1>
+ | inst bb1 of BB1
+ | inst bb2 of BB2
+ | attach (bb1.bus, bb2.bus)
+ | out <= bb2.out
+ """.stripMargin
+ exec(input, input)
+ }
+ // bar.z is not used and thus is dead code, but foo.z is used so this code isn't eliminated
+ "Module deduplication" should "should be preserved despite unused output of ONE instance" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | output z : UInt<1>
+ | y <= not(x)
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | inst bar of Child
+ | foo.x <= x
+ | bar.x <= x
+ | node t0 = or(foo.y, foo.z)
+ | z <= or(t0, bar.y)""".stripMargin
+ val check =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | output z : UInt<1>
+ | y <= not(x)
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | inst bar of Child
+ | foo.x <= x
+ | bar.x <= x
+ | node t0 = or(foo.y, foo.z)
+ | z <= or(t0, bar.y)""".stripMargin
+ exec(input, check)
+ }
+ // This currently does NOT work
+ behavior of "Single dead instances"
+ ignore should "should be deleted" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | inst bar of Child
+ | foo.x <= x
+ | bar.x <= x
+ | z <= foo.z""".stripMargin
+ val check =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | skip
+ | foo.x <= x
+ | skip
+ | z <= foo.z""".stripMargin
+ exec(input, check)
+ }
+}
diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala
index f77b47f3..a45af8c7 100644
--- a/src/test/scala/firrtlTests/FirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/FirrtlSpec.scala
@@ -12,7 +12,8 @@ import scala.io.Source
import firrtl._
import firrtl.Parser.IgnoreInfo
-import firrtl.annotations
+import firrtl.annotations._
+import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation}
import firrtl.util.BackendCompilationUtilities
trait FirrtlRunners extends BackendCompilationUtilities {
@@ -82,6 +83,16 @@ trait FirrtlRunners extends BackendCompilationUtilities {
}
trait FirrtlMatchers extends Matchers {
+ def dontTouch(path: String): Annotation = {
+ val parts = path.split('.')
+ require(parts.size >= 2, "Must specify both module and component!")
+ val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top")))
+ DontTouchAnnotation(name)
+ }
+ def dontDedup(mod: String): Annotation = {
+ require(mod.split('.').size == 1, "Can only specify a Module, not a component or instance")
+ NoDedupAnnotation(ModuleName(mod, CircuitName("Top")))
+ }
// Replace all whitespace with a single space and remove leading and
// trailing whitespace
// Note this is intended for single-line strings, no newlines
@@ -94,8 +105,13 @@ trait FirrtlMatchers extends Matchers {
* compiler will be run on input then emitted result will each be split into
* lines and normalized.
*/
- def executeTest(input: String, expected: Seq[String], compiler: Compiler) = {
- val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm))
+ def executeTest(
+ input: String,
+ expected: Seq[String],
+ compiler: Compiler,
+ annotations: Seq[Annotation] = Seq.empty) = {
+ val annoMap = AnnotationMap(annotations)
+ val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(annoMap)))
val lines = finalState.getEmittedCircuit.value split "\n" map normalized
for (e <- expected) {
lines should contain (e)
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 8367f152..25f845bc 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -5,6 +5,7 @@ package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.passes._
+import firrtl.transforms._
import firrtl.passes.memlib._
import annotations._
@@ -21,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty)
+ def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty)
}
)
@@ -199,7 +200,7 @@ circuit CustomMemory :
smem mem_1 : UInt<16>[7]
read mport _T_17 = mem_0[io.rAddr], clock
read mport _T_19 = mem_1[io.rAddr], clock
- io.dO <= _T_17
+ io.dO <= and(_T_17, _T_19)
when io.wEn :
write mport _T_18 = mem_0[io.wAddr], clock
write mport _T_20 = mem_1[io.wAddr], clock
@@ -218,7 +219,7 @@ circuit CustomMemory :
case e: ExtModule => true
case _ => false
}
- require(numExtMods == 2)
+ numExtMods should be (2)
(new java.io.File(confLoc)).delete()
}
@@ -237,7 +238,7 @@ circuit CustomMemory :
read mport _T_17 = mem_0[io.rAddr], clock
read mport _T_19 = mem_1[io.rAddr], clock
read mport _T_21 = mem_2[io.rAddr], clock
- io.dO <= _T_17
+ io.dO <= and(_T_17, and(_T_19, _T_21))
when io.wEn :
write mport _T_18 = mem_0[io.wAddr], clock
write mport _T_20 = mem_1[io.wAddr], clock
@@ -258,7 +259,7 @@ circuit CustomMemory :
case e: ExtModule => true
case _ => false
}
- require(numExtMods == 2)
+ numExtMods should be (2)
(new java.io.File(confLoc)).delete()
}