aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala2
-rw-r--r--src/main/scala/firrtl/WIR.scala1
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala296
-rw-r--r--src/main/scala/firrtl/transforms/OptimizationAnnotations.scala48
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala81
-rw-r--r--src/test/scala/firrtlTests/AttachSpec.scala11
-rw-r--r--src/test/scala/firrtlTests/DCETests.scala366
-rw-r--r--src/test/scala/firrtlTests/FirrtlSpec.scala22
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala11
9 files changed, 813 insertions, 25 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 84f237a3..b28dd1b2 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -105,7 +105,7 @@ class LowFirrtlOptimization extends CoreTransform {
passes.ConstProp,
passes.SplitExpressions,
passes.CommonSubexpressionElimination,
- passes.DeadCodeElimination)
+ new firrtl.transforms.DeadCodeElimination)
}
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index d2e1457c..47e0f321 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -45,6 +45,7 @@ case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender)
}
object WSubField {
def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UNKNOWNGENDER)
+ def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UNKNOWNGENDER)
}
case class WSubIndex(expr: Expression, value: Int, tpe: Type, gender: Gender) extends Expression {
def serialize: String = s"${expr.serialize}[$value]"
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
new file mode 100644
index 00000000..5199276c
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -0,0 +1,296 @@
+
+package firrtl.transforms
+
+import firrtl._
+import firrtl.ir._
+import firrtl.passes._
+import firrtl.annotations._
+import firrtl.graph._
+import firrtl.analyses.InstanceGraph
+import firrtl.Mappers._
+import firrtl.WrappedExpression._
+import firrtl.Utils.{throwInternalError, toWrappedExpression, kind}
+import firrtl.MemoizedHash._
+import wiring.WiringUtils.getChildrenMap
+
+import collection.mutable
+import java.io.{File, FileWriter}
+
+/** Dead Code Elimination (DCE)
+ *
+ * Performs DCE by constructing a global dependency graph starting with top-level outputs, external
+ * module ports, and simulation constructs as circuit sinks. External modules can optionally be
+ * eligible for DCE via the [[OptimizableExtModuleAnnotation]].
+ *
+ * Dead code is eliminated across module boundaries. Wires, ports, registers, and memories are all
+ * eligible for removal. Components marked with a [[DontTouchAnnotation]] will be treated as a
+ * circuit sink and thus anything that drives such a marked component will NOT be removed.
+ *
+ * This transform preserves deduplication. All instances of a given [[DefModule]] are treated as
+ * the same individual module. Thus, while certain instances may have dead code due to the
+ * circumstances of their instantiation in their parent module, they will still not be removed. To
+ * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication.
+ */
+class DeadCodeElimination extends Transform {
+ def inputForm = LowForm
+ def outputForm = LowForm
+
+ /** Based on LogicNode ins CheckCombLoops, currently kind of faking it */
+ private type LogicNode = MemoizedHash[WrappedExpression]
+ private object LogicNode {
+ def apply(moduleName: String, expr: Expression): LogicNode =
+ WrappedExpression(Utils.mergeRef(WRef(moduleName), expr))
+ def apply(moduleName: String, name: String): LogicNode = apply(moduleName, WRef(name))
+ def apply(component: ComponentName): LogicNode = {
+ // Currently only leaf nodes are supported TODO implement
+ val loweredName = LowerTypes.loweredName(component.name.split('.'))
+ apply(component.module.name, WRef(loweredName))
+ }
+ /** External Modules are representated as a single node driven by all inputs and driving all
+ * outputs
+ */
+ def apply(ext: ExtModule): LogicNode = LogicNode(ext.name, ext.name)
+ }
+
+ /** Expression used to represent outputs in the circuit (# is illegal in names) */
+ private val circuitSink = LogicNode("#Top", "#Sink")
+
+ /** Extract all References and SubFields from a possibly nested Expression */
+ def extractRefs(expr: Expression): Seq[Expression] = {
+ val refs = mutable.ArrayBuffer.empty[Expression]
+ def rec(e: Expression): Expression = {
+ e match {
+ case ref @ (_: WRef | _: WSubField) => refs += ref
+ case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec
+ case ignore @ (_: Literal) => // Do nothing
+ case unexpected => throwInternalError
+ }
+ e
+ }
+ rec(expr)
+ refs
+ }
+
+ // Gets all dependencies and constructs LogicNodes from them
+ private def getDepsImpl(mname: String,
+ instMap: collection.Map[String, String])
+ (expr: Expression): Seq[LogicNode] =
+ extractRefs(expr).map { e =>
+ if (kind(e) == InstanceKind) {
+ val (inst, tail) = Utils.splitRef(e)
+ LogicNode(instMap(inst.name), tail)
+ } else {
+ LogicNode(mname, e)
+ }
+ }
+
+
+ /** Construct the dependency graph within this module */
+ private def setupDepGraph(depGraph: MutableDiGraph[LogicNode],
+ instMap: collection.Map[String, String])
+ (mod: Module): Unit = {
+ def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
+
+ def onStmt(stmt: Statement): Unit = stmt match {
+ case DefRegister(_, name, _, clock, reset, init) =>
+ val node = LogicNode(mod.name, name)
+ depGraph.addVertex(node)
+ Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref))
+ case DefNode(_, name, value) =>
+ val node = LogicNode(mod.name, name)
+ depGraph.addVertex(node)
+ getDeps(value).foreach(ref => depGraph.addEdge(node, ref))
+ case DefWire(_, name, _) =>
+ depGraph.addVertex(LogicNode(mod.name, name))
+ case mem: DefMemory =>
+ // Treat DefMems as a node with outputs depending on the node and node depending on inputs
+ // From perpsective of the module or instance, MALE expressions are inputs, FEMALE are outputs
+ val memRef = WRef(mem.name, MemPortUtils.memType(mem), ExpKind, FEMALE)
+ val exprs = Utils.create_exps(memRef).groupBy(Utils.gender(_))
+ val sources = exprs.getOrElse(MALE, List.empty).flatMap(getDeps(_))
+ val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_))
+ val memNode = getDeps(memRef) match { case Seq(node) => node }
+ depGraph.addVertex(memNode)
+ sinks.foreach(sink => depGraph.addEdge(sink, memNode))
+ sources.foreach(source => depGraph.addEdge(memNode, source))
+ case Attach(_, exprs) => // Add edge between each expression
+ exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach {
+ case Seq(a, b) =>
+ depGraph.addEdge(a, b)
+ depGraph.addEdge(b, a)
+ }
+ case Connect(_, loc, expr) =>
+ // This match enforces the low Firrtl requirement of expanded connections
+ val node = getDeps(loc) match { case Seq(elt) => elt }
+ getDeps(expr).foreach(ref => depGraph.addEdge(node, ref))
+ // Simulation constructs are treated as top-level outputs
+ case Stop(_,_, clk, en) =>
+ Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ case Print(_, _, args, clk, en) =>
+ (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ case Block(stmts) => stmts.foreach(onStmt(_))
+ case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing
+ case other => throw new Exception(s"Unexpected Statement $other")
+ }
+
+ // Add all ports as vertices
+ mod.ports.foreach {
+ case Port(_, name, _, _: GroundType) => depGraph.addVertex(LogicNode(mod.name, name))
+ case other => throwInternalError
+ }
+ onStmt(mod.body)
+ }
+
+ // TODO Make immutable?
+ private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]],
+ doTouchExtMods: Set[String],
+ c: Circuit): MutableDiGraph[LogicNode] = {
+ val depGraph = new MutableDiGraph[LogicNode]
+ c.modules.foreach {
+ case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod)
+ case ext: ExtModule =>
+ // Connect all inputs to all outputs
+ val node = LogicNode(ext)
+ ext.ports.foreach {
+ case Port(_, pname, _, AnalogType(_)) =>
+ depGraph.addEdge(LogicNode(ext.name, pname), node)
+ depGraph.addEdge(node, LogicNode(ext.name, pname))
+ case Port(_, pname, Output, _) =>
+ val portNode = LogicNode(ext.name, pname)
+ depGraph.addEdge(portNode, node)
+ // Don't touch external modules *unless* they are specifically marked as doTouch
+ if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode)
+ case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname))
+ }
+ }
+ // Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface)
+ val topModule = c.modules.find(_.name == c.main).get
+ val topOutputs = topModule.ports.foreach { port =>
+ depGraph.addEdge(circuitSink, LogicNode(c.main, port.name))
+ }
+
+ depGraph
+ }
+
+ private def deleteDeadCode(instMap: collection.Map[String, String],
+ deadNodes: Set[LogicNode],
+ moduleMap: collection.Map[String, DefModule],
+ renames: RenameMap)
+ (mod: DefModule): Option[DefModule] = {
+ def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr)
+
+ var emptyBody = true
+ renames.setModule(mod.name)
+
+ def onStmt(stmt: Statement): Statement = {
+ val stmtx = stmt match {
+ case inst: WDefInstance =>
+ moduleMap.get(inst.module) match {
+ case Some(instMod) => inst.copy(tpe = Utils.module_type(instMod))
+ case None =>
+ renames.delete(inst.name)
+ EmptyStmt
+ }
+ case decl: IsDeclaration =>
+ val node = LogicNode(mod.name, decl.name)
+ if (deadNodes.contains(node)) {
+ renames.delete(decl.name)
+ EmptyStmt
+ }
+ else decl
+ case con: Connect =>
+ val node = getDeps(con.loc) match { case Seq(elt) => elt }
+ if (deadNodes.contains(node)) EmptyStmt else con
+ case Attach(info, exprs) => // If any exprs are dead then all are
+ val dead = exprs.flatMap(getDeps(_)).forall(deadNodes.contains(_))
+ if (dead) EmptyStmt else Attach(info, exprs)
+ case block: Block => block map onStmt
+ case other => other
+ }
+ stmtx match { // Check if module empty
+ case EmptyStmt | _: Block =>
+ case other => emptyBody = false
+ }
+ stmtx
+ }
+
+ val (deadPorts, portsx) = mod.ports.partition(p => deadNodes.contains(LogicNode(mod.name, p.name)))
+ deadPorts.foreach(p => renames.delete(p.name))
+
+ mod match {
+ case Module(info, name, _, body) =>
+ val bodyx = onStmt(body)
+ if (emptyBody && portsx.isEmpty) None else Some(Module(info, name, portsx, bodyx))
+ case ext: ExtModule =>
+ if (portsx.isEmpty) None
+ else {
+ if (ext.ports != portsx) throwInternalError // Sanity check
+ Some(ext.copy(ports = portsx))
+ }
+ }
+
+ }
+
+ def run(state: CircuitState,
+ dontTouches: Seq[LogicNode],
+ doTouchExtMods: Set[String]): CircuitState = {
+ val c = state.circuit
+ val moduleMap = c.modules.map(m => m.name -> m).toMap
+ val iGraph = new InstanceGraph(c)
+ val moduleDeps = iGraph.graph.edges.map { case (k,v) =>
+ k.module -> v.map(i => i.name -> i.module).toMap
+ }
+ val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
+
+ val depGraph = {
+ val dGraph = createDependencyGraph(moduleDeps, doTouchExtMods, c)
+ for (dontTouch <- dontTouches) {
+ dGraph.getVertices.find(_ == dontTouch) match {
+ case Some(node) => dGraph.addEdge(circuitSink, node)
+ case None =>
+ val (root, tail) = Utils.splitRef(dontTouch.e1)
+ DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize)
+ }
+ }
+ DiGraph(dGraph)
+ }
+
+ val liveNodes = depGraph.reachableFrom(circuitSink) + circuitSink
+ val deadNodes = depGraph.getVertices -- liveNodes
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+
+ // As we delete deadCode, we will delete ports from Modules and somtimes complete modules
+ // themselves. We iterate over the modules in a topological order from leaves to the top. The
+ // current status of the modulesxMap is used to either delete instances or update their types
+ val modulesxMap = mutable.HashMap.empty[String, DefModule]
+ topoSortedModules.foreach { case mod =>
+ deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames)(mod) match {
+ case Some(m) => modulesxMap += m.name -> m
+ case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
+ }
+ }
+
+ // Preserve original module order
+ val newCircuit = c.copy(modules = c.modules.flatMap(m => modulesxMap.get(m.name)))
+
+ state.copy(circuit = newCircuit, renames = Some(renames))
+ }
+
+ def execute(state: CircuitState): CircuitState = {
+ val (dontTouches: Seq[LogicNode], doTouchExtMods: Seq[String]) =
+ state.annotations match {
+ case Some(aMap) =>
+ // TODO Do with single walk over annotations
+ val dontTouches = aMap.annotations.collect {
+ case DontTouchAnnotation(component) => LogicNode(component)
+ }
+ val optExtMods = aMap.annotations.collect {
+ case OptimizableExtModuleAnnotation(ModuleName(name, _)) => name
+ }
+ (dontTouches, optExtMods)
+ case None => (Seq.empty, Seq.empty)
+ }
+ run(state, dontTouches, doTouchExtMods.toSet)
+ }
+}
diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
new file mode 100644
index 00000000..23723a60
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala
@@ -0,0 +1,48 @@
+
+package firrtl
+package transforms
+
+import firrtl.annotations._
+import firrtl.passes.PassException
+
+/** A component that should be preserved
+ *
+ * DCE treats the component as a top-level sink of the circuit
+ */
+object DontTouchAnnotation {
+ private val marker = "DONTtouch!"
+ def apply(target: ComponentName): Annotation = Annotation(target, classOf[Transform], marker)
+
+ def unapply(a: Annotation): Option[ComponentName] = a match {
+ case Annotation(component: ComponentName, _, value) if value == marker => Some(component)
+ case _ => None
+ }
+
+ class DontTouchNotFoundException(module: String, component: String) extends PassException(
+ s"Component marked DONT Touch ($module.$component) not found!\n" +
+ "Perhaps it is an aggregate type? Currently only leaf components are supported.\n" +
+ "Otherwise it was probably accidentally deleted. Please check that your custom passes are not" +
+ "responsible and then file an issue on Github."
+ )
+
+ def errorNotFound(module: String, component: String) =
+ throw new DontTouchNotFoundException(module, component)
+}
+
+/** An [[firrtl.ir.ExtModule]] that can be optimized
+ *
+ * Firrtl does not know the semantics of an external module. This annotation provides some
+ * "greybox" information that the external module does not have any side effects. In particular,
+ * this means that the external module can be Dead Code Eliminated.
+ *
+ * @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted
+ */
+object OptimizableExtModuleAnnotation {
+ private val marker = "optimizableExtModule!"
+ def apply(target: ModuleName): Annotation = Annotation(target, classOf[Transform], marker)
+
+ def unapply(a: Annotation): Option[ModuleName] = a match {
+ case Annotation(component: ModuleName, _, value) if value == marker => Some(component)
+ case _ => None
+ }
+}
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index 81c394c1..44964eba 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -7,8 +7,10 @@ import java.io.{File, FileWriter, Writer}
import firrtl.annotations.AnnotationYamlProtocol._
import firrtl.annotations._
import firrtl._
+import firrtl.transforms.OptimizableExtModuleAnnotation
import firrtl.passes.InlineAnnotation
import firrtl.passes.memlib.PinAnnotation
+import firrtl.transforms.DontTouchAnnotation
import net.jcazevedo.moultingyaml._
import org.scalatest.Matchers
import logger._
@@ -43,8 +45,17 @@ trait AnnotationSpec extends LowTransformSpec {
class AnnotationTests extends AnnotationSpec with Matchers {
def getAMap(a: Annotation): Option[AnnotationMap] = Some(AnnotationMap(Seq(a)))
def getAMap(as: Seq[Annotation]): Option[AnnotationMap] = Some(AnnotationMap(as))
- def anno(s: String, value: String ="this is a value"): Annotation =
- Annotation(ComponentName(s, ModuleName("Top", CircuitName("Top"))), classOf[Transform], value)
+ def anno(s: String, value: String ="this is a value", mod: String = "Top"): Annotation =
+ Annotation(ComponentName(s, ModuleName(mod, CircuitName("Top"))), classOf[Transform], value)
+ def manno(mod: String): Annotation =
+ Annotation(ModuleName(mod, CircuitName("Top")), classOf[Transform], "some value")
+ // TODO unify with FirrtlMatchers, problems with multiple definitions of parse
+ def dontTouch(path: String): Annotation = {
+ val parts = path.split('.')
+ require(parts.size >= 2, "Must specify both module and component!")
+ val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top")))
+ DontTouchAnnotation(name)
+ }
"Loose and Sticky annotation on a node" should "pass through" in {
val input: String =
@@ -145,7 +156,6 @@ class AnnotationTests extends AnnotationSpec with Matchers {
val deleted = result.deletedAnnotations
exception.str should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted")
}
-
"Renaming" should "propagate in Lowering of memories" in {
val compiler = new VerilogCompiler
// Uncomment to help debugging failing tests
@@ -165,7 +175,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| m.r.en <= UInt(1)
| m.r.addr <= in
|""".stripMargin
- val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"))
+ val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"),
+ dontTouch("Top.m"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("m_a", "mem"))
@@ -179,7 +190,6 @@ class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should not contain (anno("m"))
resultAnno should not contain (anno("r"))
}
-
"Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in {
val compiler = new VerilogCompiler
Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug))
@@ -191,7 +201,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| cmem m: {a: UInt<4>, b: UInt<4>[2]}[8]
| read mport r = m[in], clk
|""".stripMargin
- val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"))
+ val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"), dontTouch("Top.m"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("m_a", "mem"))
@@ -220,7 +230,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| x.a <= zero
| x.b <= zero
|""".stripMargin
- val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]"))
+ val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"),
+ anno("y[2]"), dontTouch("Top.x"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("x_a"))
@@ -260,7 +271,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
anno("w.a"), anno("w.b[0]"), anno("w.b[1]"),
anno("n.a"), anno("n.b[0]"), anno("n.b[1]"),
anno("r.a"), anno("r.b[0]"), anno("r.b[1]"),
- anno("write.a"), anno("write.b[0]"), anno("write.b[1]")
+ anno("write.a"), anno("write.b[0]"), anno("write.b[1]"),
+ dontTouch("Top.r")
)
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
@@ -314,7 +326,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"))
+ val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_a"))
@@ -349,7 +361,8 @@ class AnnotationTests extends AnnotationSpec with Matchers {
| out <= n
| reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
|""".stripMargin
- val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"))
+ val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"),
+ dontTouch("Top.r"))
val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
val resultAnno = result.annotations.get.annotations
resultAnno should contain (anno("in_b_0"))
@@ -364,8 +377,7 @@ class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should contain (anno("r_b_1"))
}
-
- "Renaming" should "track dce" in {
+ "Renaming" should "track constprop + dce" in {
val compiler = new VerilogCompiler
val input =
"""circuit Top :
@@ -403,4 +415,49 @@ class AnnotationTests extends AnnotationSpec with Matchers {
resultAnno should contain (anno("out_b_0"))
resultAnno should contain (anno("out_b_1"))
}
+
+ ignore should "track deleted modules AND instances in dce" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Dead :
+ | input foo : UInt<8>
+ | output bar : UInt<8>
+ | bar <= foo
+ | extmodule DeadExt :
+ | input foo : UInt<8>
+ | output bar : UInt<8>
+ | module Top :
+ | input foo : UInt<8>
+ | output bar : UInt<8>
+ | inst d of Dead
+ | d.foo <= foo
+ | inst d2 of DeadExt
+ | d2.foo <= foo
+ | bar <= foo
+ |""".stripMargin
+ val annos = Seq(
+ OptimizableExtModuleAnnotation(ModuleName("DeadExt", CircuitName("Top"))),
+ manno("Dead"), manno("DeadExt"), manno("Top"),
+ anno("d"), anno("d2"),
+ anno("foo", mod = "Top"), anno("bar", mod = "Top"),
+ anno("foo", mod = "Dead"), anno("bar", mod = "Dead"),
+ anno("foo", mod = "DeadExt"), anno("bar", mod = "DeadExt")
+ )
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+
+ resultAnno should contain (manno("Top"))
+ resultAnno should contain (anno("foo", mod = "Top"))
+ resultAnno should contain (anno("bar", mod = "Top"))
+
+ resultAnno should not contain (manno("Dead"))
+ resultAnno should not contain (manno("DeadExt"))
+ resultAnno should not contain (anno("d"))
+ resultAnno should not contain (anno("d2"))
+ resultAnno should not contain (anno("foo", mod = "Dead"))
+ resultAnno should not contain (anno("bar", mod = "Dead"))
+ resultAnno should not contain (anno("foo", mod = "DeadExt"))
+ resultAnno should not contain (anno("bar", mod = "DeadExt"))
+ }
}
diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala
index c29a7e43..93a36f70 100644
--- a/src/test/scala/firrtlTests/AttachSpec.scala
+++ b/src/test/scala/firrtlTests/AttachSpec.scala
@@ -62,7 +62,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec {
| module A:
| input an: Analog<3>
| module B:
- | input an: Analog<3> """.stripMargin
+ | input an: Analog<3>""".stripMargin
val check =
"""module Attaching(
|);
@@ -70,16 +70,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec {
| A a (
| .an(_GEN_0)
| );
- | A b (
+ | B b (
| .an(_GEN_0)
| );
|endmodule
|module A(
| inout [2:0] an
|);
+ |module B(
+ | inout [2:0] an
+ |);
|endmodule
|""".stripMargin.split("\n") map normalized
- executeTest(input, check, compiler)
+ executeTest(input, check, compiler, Seq(dontTouch("A.an"), dontDedup("A")))
}
it should "attach a wire source" in {
@@ -101,7 +104,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec {
| );
|endmodule
|""".stripMargin.split("\n") map normalized
- executeTest(input, check, compiler)
+ executeTest(input, check, compiler, Seq(dontTouch("Attaching.x")))
}
it should "attach multiple sources" in {
diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala
new file mode 100644
index 00000000..deb73b3b
--- /dev/null
+++ b/src/test/scala/firrtlTests/DCETests.scala
@@ -0,0 +1,366 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl.ir.Circuit
+import firrtl._
+import firrtl.passes._
+import firrtl.transforms._
+import firrtl.annotations._
+import firrtl.passes.memlib.SimpleTransform
+
+class DCETests extends FirrtlFlatSpec {
+ // Not using executeTest because it is for positive testing, we need to check that stuff got
+ // deleted
+ private val customTransforms = Seq(
+ new LowFirrtlOptimization,
+ new SimpleTransform(RemoveEmpty, LowForm)
+ )
+ private def exec(input: String, check: String, annos: Seq[Annotation] = List.empty): Unit = {
+ val state = CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annos)))
+ val finalState = (new LowFirrtlCompiler).compileAndEmit(state, customTransforms)
+ val res = finalState.getEmittedCircuit.value
+ // Convert to sets for comparison
+ val resSet = Set(parse(res).serialize.split("\n"):_*)
+ val checkSet = Set(parse(check).serialize.split("\n"):_*)
+ resSet should be (checkSet)
+ }
+
+ "Unread wire" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | z <= x
+ | a <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Unread wire marked dont touch" should "NOT be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | z <= x
+ | a <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | z <= x
+ | a <= x""".stripMargin
+ exec(input, check, Seq(dontTouch("Top.a")))
+ }
+ "Unread register" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | reg a : UInt<1>, clk
+ | a <= x
+ | node y = asUInt(clk)
+ | z <= or(x, y)""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node y = asUInt(clk)
+ | z <= or(x, y)""".stripMargin
+ exec(input, check)
+ }
+ "Unread node" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Unused ports" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Sub :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | inst sub of Sub
+ | sub.x <= x
+ | z <= sub.z""".stripMargin
+ val check =
+ """circuit Top :
+ | module Sub :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | input y : UInt<1>
+ | output z : UInt<1>
+ | inst sub of Sub
+ | sub.x <= x
+ | z <= sub.z""".stripMargin
+ exec(input, check)
+ }
+ "Chain of unread nodes" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | node b = or(a, a)
+ | node c = add(b, x)
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Chain of unread wires and their connections" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | wire a : UInt<1>
+ | a <= x
+ | wire b : UInt<1>
+ | b <= a
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Read register" should "not be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | reg r : UInt<1>, clk
+ | r <= x
+ | z <= r""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | reg r : UInt<1>, clk with : (reset => (UInt<1>("h0"), r))
+ | r <= x
+ | z <= r""".stripMargin
+ exec(input, check)
+ }
+ "Logic that feeds into simulation constructs" should "not be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | stop(clk, a, 0)
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk : Clock
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | node a = not(x)
+ | z <= x
+ | stop(clk, a, 0)""".stripMargin
+ exec(input, check)
+ }
+ "Globally dead module" should "should be deleted" in {
+ val input =
+ """circuit Top :
+ | module Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Globally dead extmodule" should "NOT be deleted by default" in {
+ val input =
+ """circuit Top :
+ | extmodule Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | extmodule Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ exec(input, check)
+ }
+ "Globally dead extmodule marked optimizable" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | extmodule Dead :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst dead of Dead
+ | dead.x <= x
+ | z <= x""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x""".stripMargin
+ val doTouchAnno = OptimizableExtModuleAnnotation(ModuleName("Dead", CircuitName("Top")))
+ exec(input, check, Seq(doTouchAnno))
+ }
+ "Analog ports of extmodules" should "count as both inputs and outputs" in {
+ val input =
+ """circuit Top :
+ | extmodule BB1 :
+ | output bus : Analog<1>
+ | extmodule BB2 :
+ | output bus : Analog<1>
+ | output out : UInt<1>
+ | module Top :
+ | output out : UInt<1>
+ | inst bb1 of BB1
+ | inst bb2 of BB2
+ | attach (bb1.bus, bb2.bus)
+ | out <= bb2.out
+ """.stripMargin
+ exec(input, input)
+ }
+ // bar.z is not used and thus is dead code, but foo.z is used so this code isn't eliminated
+ "Module deduplication" should "should be preserved despite unused output of ONE instance" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | output z : UInt<1>
+ | y <= not(x)
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | inst bar of Child
+ | foo.x <= x
+ | bar.x <= x
+ | node t0 = or(foo.y, foo.z)
+ | z <= or(t0, bar.y)""".stripMargin
+ val check =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output y : UInt<1>
+ | output z : UInt<1>
+ | y <= not(x)
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | inst bar of Child
+ | foo.x <= x
+ | bar.x <= x
+ | node t0 = or(foo.y, foo.z)
+ | z <= or(t0, bar.y)""".stripMargin
+ exec(input, check)
+ }
+ // This currently does NOT work
+ behavior of "Single dead instances"
+ ignore should "should be deleted" in {
+ val input =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | inst bar of Child
+ | foo.x <= x
+ | bar.x <= x
+ | z <= foo.z""".stripMargin
+ val check =
+ """circuit Top :
+ | module Child :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | z <= x
+ | module Top :
+ | input x : UInt<1>
+ | output z : UInt<1>
+ | inst foo of Child
+ | skip
+ | foo.x <= x
+ | skip
+ | z <= foo.z""".stripMargin
+ exec(input, check)
+ }
+}
diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala
index f77b47f3..a45af8c7 100644
--- a/src/test/scala/firrtlTests/FirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/FirrtlSpec.scala
@@ -12,7 +12,8 @@ import scala.io.Source
import firrtl._
import firrtl.Parser.IgnoreInfo
-import firrtl.annotations
+import firrtl.annotations._
+import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation}
import firrtl.util.BackendCompilationUtilities
trait FirrtlRunners extends BackendCompilationUtilities {
@@ -82,6 +83,16 @@ trait FirrtlRunners extends BackendCompilationUtilities {
}
trait FirrtlMatchers extends Matchers {
+ def dontTouch(path: String): Annotation = {
+ val parts = path.split('.')
+ require(parts.size >= 2, "Must specify both module and component!")
+ val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top")))
+ DontTouchAnnotation(name)
+ }
+ def dontDedup(mod: String): Annotation = {
+ require(mod.split('.').size == 1, "Can only specify a Module, not a component or instance")
+ NoDedupAnnotation(ModuleName(mod, CircuitName("Top")))
+ }
// Replace all whitespace with a single space and remove leading and
// trailing whitespace
// Note this is intended for single-line strings, no newlines
@@ -94,8 +105,13 @@ trait FirrtlMatchers extends Matchers {
* compiler will be run on input then emitted result will each be split into
* lines and normalized.
*/
- def executeTest(input: String, expected: Seq[String], compiler: Compiler) = {
- val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm))
+ def executeTest(
+ input: String,
+ expected: Seq[String],
+ compiler: Compiler,
+ annotations: Seq[Annotation] = Seq.empty) = {
+ val annoMap = AnnotationMap(annotations)
+ val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(annoMap)))
val lines = finalState.getEmittedCircuit.value split "\n" map normalized
for (e <- expected) {
lines should contain (e)
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 8367f152..25f845bc 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -5,6 +5,7 @@ package firrtlTests
import firrtl._
import firrtl.ir._
import firrtl.passes._
+import firrtl.transforms._
import firrtl.passes.memlib._
import annotations._
@@ -21,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new SeqTransform {
def inputForm = LowForm
def outputForm = LowForm
- def transforms = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty)
+ def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty)
}
)
@@ -199,7 +200,7 @@ circuit CustomMemory :
smem mem_1 : UInt<16>[7]
read mport _T_17 = mem_0[io.rAddr], clock
read mport _T_19 = mem_1[io.rAddr], clock
- io.dO <= _T_17
+ io.dO <= and(_T_17, _T_19)
when io.wEn :
write mport _T_18 = mem_0[io.wAddr], clock
write mport _T_20 = mem_1[io.wAddr], clock
@@ -218,7 +219,7 @@ circuit CustomMemory :
case e: ExtModule => true
case _ => false
}
- require(numExtMods == 2)
+ numExtMods should be (2)
(new java.io.File(confLoc)).delete()
}
@@ -237,7 +238,7 @@ circuit CustomMemory :
read mport _T_17 = mem_0[io.rAddr], clock
read mport _T_19 = mem_1[io.rAddr], clock
read mport _T_21 = mem_2[io.rAddr], clock
- io.dO <= _T_17
+ io.dO <= and(_T_17, and(_T_19, _T_21))
when io.wEn :
write mport _T_18 = mem_0[io.wAddr], clock
write mport _T_20 = mem_1[io.wAddr], clock
@@ -258,7 +259,7 @@ circuit CustomMemory :
case e: ExtModule => true
case _ => false
}
- require(numExtMods == 2)
+ numExtMods should be (2)
(new java.io.File(confLoc)).delete()
}