diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/DeadCodeElimination.scala | 296 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/OptimizationAnnotations.scala | 48 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AnnotationTests.scala | 81 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AttachSpec.scala | 11 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/DCETests.scala | 366 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/FirrtlSpec.scala | 22 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 11 |
9 files changed, 813 insertions, 25 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 84f237a3..b28dd1b2 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -105,7 +105,7 @@ class LowFirrtlOptimization extends CoreTransform { passes.ConstProp, passes.SplitExpressions, passes.CommonSubexpressionElimination, - passes.DeadCodeElimination) + new firrtl.transforms.DeadCodeElimination) } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index d2e1457c..47e0f321 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -45,6 +45,7 @@ case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender) } object WSubField { def apply(expr: Expression, n: String): WSubField = new WSubField(expr, n, field_type(expr.tpe, n), UNKNOWNGENDER) + def apply(expr: Expression, name: String, tpe: Type): WSubField = new WSubField(expr, name, tpe, UNKNOWNGENDER) } case class WSubIndex(expr: Expression, value: Int, tpe: Type, gender: Gender) extends Expression { def serialize: String = s"${expr.serialize}[$value]" diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala new file mode 100644 index 00000000..5199276c --- /dev/null +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -0,0 +1,296 @@ + +package firrtl.transforms + +import firrtl._ +import firrtl.ir._ +import firrtl.passes._ +import firrtl.annotations._ +import firrtl.graph._ +import firrtl.analyses.InstanceGraph +import firrtl.Mappers._ +import firrtl.WrappedExpression._ +import firrtl.Utils.{throwInternalError, toWrappedExpression, kind} +import firrtl.MemoizedHash._ +import wiring.WiringUtils.getChildrenMap + +import collection.mutable +import java.io.{File, FileWriter} + +/** Dead Code Elimination (DCE) + * + * Performs DCE by constructing a global dependency graph starting with top-level outputs, external + * module ports, and simulation constructs as circuit sinks. External modules can optionally be + * eligible for DCE via the [[OptimizableExtModuleAnnotation]]. + * + * Dead code is eliminated across module boundaries. Wires, ports, registers, and memories are all + * eligible for removal. Components marked with a [[DontTouchAnnotation]] will be treated as a + * circuit sink and thus anything that drives such a marked component will NOT be removed. + * + * This transform preserves deduplication. All instances of a given [[DefModule]] are treated as + * the same individual module. Thus, while certain instances may have dead code due to the + * circumstances of their instantiation in their parent module, they will still not be removed. To + * remove such modules, use the [[NoDedupAnnotation]] to prevent deduplication. + */ +class DeadCodeElimination extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + /** Based on LogicNode ins CheckCombLoops, currently kind of faking it */ + private type LogicNode = MemoizedHash[WrappedExpression] + private object LogicNode { + def apply(moduleName: String, expr: Expression): LogicNode = + WrappedExpression(Utils.mergeRef(WRef(moduleName), expr)) + def apply(moduleName: String, name: String): LogicNode = apply(moduleName, WRef(name)) + def apply(component: ComponentName): LogicNode = { + // Currently only leaf nodes are supported TODO implement + val loweredName = LowerTypes.loweredName(component.name.split('.')) + apply(component.module.name, WRef(loweredName)) + } + /** External Modules are representated as a single node driven by all inputs and driving all + * outputs + */ + def apply(ext: ExtModule): LogicNode = LogicNode(ext.name, ext.name) + } + + /** Expression used to represent outputs in the circuit (# is illegal in names) */ + private val circuitSink = LogicNode("#Top", "#Sink") + + /** Extract all References and SubFields from a possibly nested Expression */ + def extractRefs(expr: Expression): Seq[Expression] = { + val refs = mutable.ArrayBuffer.empty[Expression] + def rec(e: Expression): Expression = { + e match { + case ref @ (_: WRef | _: WSubField) => refs += ref + case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested map rec + case ignore @ (_: Literal) => // Do nothing + case unexpected => throwInternalError + } + e + } + rec(expr) + refs + } + + // Gets all dependencies and constructs LogicNodes from them + private def getDepsImpl(mname: String, + instMap: collection.Map[String, String]) + (expr: Expression): Seq[LogicNode] = + extractRefs(expr).map { e => + if (kind(e) == InstanceKind) { + val (inst, tail) = Utils.splitRef(e) + LogicNode(instMap(inst.name), tail) + } else { + LogicNode(mname, e) + } + } + + + /** Construct the dependency graph within this module */ + private def setupDepGraph(depGraph: MutableDiGraph[LogicNode], + instMap: collection.Map[String, String]) + (mod: Module): Unit = { + def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr) + + def onStmt(stmt: Statement): Unit = stmt match { + case DefRegister(_, name, _, clock, reset, init) => + val node = LogicNode(mod.name, name) + depGraph.addVertex(node) + Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref)) + case DefNode(_, name, value) => + val node = LogicNode(mod.name, name) + depGraph.addVertex(node) + getDeps(value).foreach(ref => depGraph.addEdge(node, ref)) + case DefWire(_, name, _) => + depGraph.addVertex(LogicNode(mod.name, name)) + case mem: DefMemory => + // Treat DefMems as a node with outputs depending on the node and node depending on inputs + // From perpsective of the module or instance, MALE expressions are inputs, FEMALE are outputs + val memRef = WRef(mem.name, MemPortUtils.memType(mem), ExpKind, FEMALE) + val exprs = Utils.create_exps(memRef).groupBy(Utils.gender(_)) + val sources = exprs.getOrElse(MALE, List.empty).flatMap(getDeps(_)) + val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_)) + val memNode = getDeps(memRef) match { case Seq(node) => node } + depGraph.addVertex(memNode) + sinks.foreach(sink => depGraph.addEdge(sink, memNode)) + sources.foreach(source => depGraph.addEdge(memNode, source)) + case Attach(_, exprs) => // Add edge between each expression + exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach { + case Seq(a, b) => + depGraph.addEdge(a, b) + depGraph.addEdge(b, a) + } + case Connect(_, loc, expr) => + // This match enforces the low Firrtl requirement of expanded connections + val node = getDeps(loc) match { case Seq(elt) => elt } + getDeps(expr).foreach(ref => depGraph.addEdge(node, ref)) + // Simulation constructs are treated as top-level outputs + case Stop(_,_, clk, en) => + Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) + case Print(_, _, args, clk, en) => + (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) + case Block(stmts) => stmts.foreach(onStmt(_)) + case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing + case other => throw new Exception(s"Unexpected Statement $other") + } + + // Add all ports as vertices + mod.ports.foreach { + case Port(_, name, _, _: GroundType) => depGraph.addVertex(LogicNode(mod.name, name)) + case other => throwInternalError + } + onStmt(mod.body) + } + + // TODO Make immutable? + private def createDependencyGraph(instMaps: collection.Map[String, collection.Map[String, String]], + doTouchExtMods: Set[String], + c: Circuit): MutableDiGraph[LogicNode] = { + val depGraph = new MutableDiGraph[LogicNode] + c.modules.foreach { + case mod: Module => setupDepGraph(depGraph, instMaps(mod.name))(mod) + case ext: ExtModule => + // Connect all inputs to all outputs + val node = LogicNode(ext) + ext.ports.foreach { + case Port(_, pname, _, AnalogType(_)) => + depGraph.addEdge(LogicNode(ext.name, pname), node) + depGraph.addEdge(node, LogicNode(ext.name, pname)) + case Port(_, pname, Output, _) => + val portNode = LogicNode(ext.name, pname) + depGraph.addEdge(portNode, node) + // Don't touch external modules *unless* they are specifically marked as doTouch + if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode) + case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname)) + } + } + // Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface) + val topModule = c.modules.find(_.name == c.main).get + val topOutputs = topModule.ports.foreach { port => + depGraph.addEdge(circuitSink, LogicNode(c.main, port.name)) + } + + depGraph + } + + private def deleteDeadCode(instMap: collection.Map[String, String], + deadNodes: Set[LogicNode], + moduleMap: collection.Map[String, DefModule], + renames: RenameMap) + (mod: DefModule): Option[DefModule] = { + def getDeps(expr: Expression): Seq[LogicNode] = getDepsImpl(mod.name, instMap)(expr) + + var emptyBody = true + renames.setModule(mod.name) + + def onStmt(stmt: Statement): Statement = { + val stmtx = stmt match { + case inst: WDefInstance => + moduleMap.get(inst.module) match { + case Some(instMod) => inst.copy(tpe = Utils.module_type(instMod)) + case None => + renames.delete(inst.name) + EmptyStmt + } + case decl: IsDeclaration => + val node = LogicNode(mod.name, decl.name) + if (deadNodes.contains(node)) { + renames.delete(decl.name) + EmptyStmt + } + else decl + case con: Connect => + val node = getDeps(con.loc) match { case Seq(elt) => elt } + if (deadNodes.contains(node)) EmptyStmt else con + case Attach(info, exprs) => // If any exprs are dead then all are + val dead = exprs.flatMap(getDeps(_)).forall(deadNodes.contains(_)) + if (dead) EmptyStmt else Attach(info, exprs) + case block: Block => block map onStmt + case other => other + } + stmtx match { // Check if module empty + case EmptyStmt | _: Block => + case other => emptyBody = false + } + stmtx + } + + val (deadPorts, portsx) = mod.ports.partition(p => deadNodes.contains(LogicNode(mod.name, p.name))) + deadPorts.foreach(p => renames.delete(p.name)) + + mod match { + case Module(info, name, _, body) => + val bodyx = onStmt(body) + if (emptyBody && portsx.isEmpty) None else Some(Module(info, name, portsx, bodyx)) + case ext: ExtModule => + if (portsx.isEmpty) None + else { + if (ext.ports != portsx) throwInternalError // Sanity check + Some(ext.copy(ports = portsx)) + } + } + + } + + def run(state: CircuitState, + dontTouches: Seq[LogicNode], + doTouchExtMods: Set[String]): CircuitState = { + val c = state.circuit + val moduleMap = c.modules.map(m => m.name -> m).toMap + val iGraph = new InstanceGraph(c) + val moduleDeps = iGraph.graph.edges.map { case (k,v) => + k.module -> v.map(i => i.name -> i.module).toMap + } + val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_)) + + val depGraph = { + val dGraph = createDependencyGraph(moduleDeps, doTouchExtMods, c) + for (dontTouch <- dontTouches) { + dGraph.getVertices.find(_ == dontTouch) match { + case Some(node) => dGraph.addEdge(circuitSink, node) + case None => + val (root, tail) = Utils.splitRef(dontTouch.e1) + DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize) + } + } + DiGraph(dGraph) + } + + val liveNodes = depGraph.reachableFrom(circuitSink) + circuitSink + val deadNodes = depGraph.getVertices -- liveNodes + val renames = RenameMap() + renames.setCircuit(c.main) + + // As we delete deadCode, we will delete ports from Modules and somtimes complete modules + // themselves. We iterate over the modules in a topological order from leaves to the top. The + // current status of the modulesxMap is used to either delete instances or update their types + val modulesxMap = mutable.HashMap.empty[String, DefModule] + topoSortedModules.foreach { case mod => + deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames)(mod) match { + case Some(m) => modulesxMap += m.name -> m + case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) + } + } + + // Preserve original module order + val newCircuit = c.copy(modules = c.modules.flatMap(m => modulesxMap.get(m.name))) + + state.copy(circuit = newCircuit, renames = Some(renames)) + } + + def execute(state: CircuitState): CircuitState = { + val (dontTouches: Seq[LogicNode], doTouchExtMods: Seq[String]) = + state.annotations match { + case Some(aMap) => + // TODO Do with single walk over annotations + val dontTouches = aMap.annotations.collect { + case DontTouchAnnotation(component) => LogicNode(component) + } + val optExtMods = aMap.annotations.collect { + case OptimizableExtModuleAnnotation(ModuleName(name, _)) => name + } + (dontTouches, optExtMods) + case None => (Seq.empty, Seq.empty) + } + run(state, dontTouches, doTouchExtMods.toSet) + } +} diff --git a/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala new file mode 100644 index 00000000..23723a60 --- /dev/null +++ b/src/main/scala/firrtl/transforms/OptimizationAnnotations.scala @@ -0,0 +1,48 @@ + +package firrtl +package transforms + +import firrtl.annotations._ +import firrtl.passes.PassException + +/** A component that should be preserved + * + * DCE treats the component as a top-level sink of the circuit + */ +object DontTouchAnnotation { + private val marker = "DONTtouch!" + def apply(target: ComponentName): Annotation = Annotation(target, classOf[Transform], marker) + + def unapply(a: Annotation): Option[ComponentName] = a match { + case Annotation(component: ComponentName, _, value) if value == marker => Some(component) + case _ => None + } + + class DontTouchNotFoundException(module: String, component: String) extends PassException( + s"Component marked DONT Touch ($module.$component) not found!\n" + + "Perhaps it is an aggregate type? Currently only leaf components are supported.\n" + + "Otherwise it was probably accidentally deleted. Please check that your custom passes are not" + + "responsible and then file an issue on Github." + ) + + def errorNotFound(module: String, component: String) = + throw new DontTouchNotFoundException(module, component) +} + +/** An [[firrtl.ir.ExtModule]] that can be optimized + * + * Firrtl does not know the semantics of an external module. This annotation provides some + * "greybox" information that the external module does not have any side effects. In particular, + * this means that the external module can be Dead Code Eliminated. + * + * @note Unlike [[DontTouchAnnotation]], we don't care if the annotation is deleted + */ +object OptimizableExtModuleAnnotation { + private val marker = "optimizableExtModule!" + def apply(target: ModuleName): Annotation = Annotation(target, classOf[Transform], marker) + + def unapply(a: Annotation): Option[ModuleName] = a match { + case Annotation(component: ModuleName, _, value) if value == marker => Some(component) + case _ => None + } +} diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 81c394c1..44964eba 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -7,8 +7,10 @@ import java.io.{File, FileWriter, Writer} import firrtl.annotations.AnnotationYamlProtocol._ import firrtl.annotations._ import firrtl._ +import firrtl.transforms.OptimizableExtModuleAnnotation import firrtl.passes.InlineAnnotation import firrtl.passes.memlib.PinAnnotation +import firrtl.transforms.DontTouchAnnotation import net.jcazevedo.moultingyaml._ import org.scalatest.Matchers import logger._ @@ -43,8 +45,17 @@ trait AnnotationSpec extends LowTransformSpec { class AnnotationTests extends AnnotationSpec with Matchers { def getAMap(a: Annotation): Option[AnnotationMap] = Some(AnnotationMap(Seq(a))) def getAMap(as: Seq[Annotation]): Option[AnnotationMap] = Some(AnnotationMap(as)) - def anno(s: String, value: String ="this is a value"): Annotation = - Annotation(ComponentName(s, ModuleName("Top", CircuitName("Top"))), classOf[Transform], value) + def anno(s: String, value: String ="this is a value", mod: String = "Top"): Annotation = + Annotation(ComponentName(s, ModuleName(mod, CircuitName("Top"))), classOf[Transform], value) + def manno(mod: String): Annotation = + Annotation(ModuleName(mod, CircuitName("Top")), classOf[Transform], "some value") + // TODO unify with FirrtlMatchers, problems with multiple definitions of parse + def dontTouch(path: String): Annotation = { + val parts = path.split('.') + require(parts.size >= 2, "Must specify both module and component!") + val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top"))) + DontTouchAnnotation(name) + } "Loose and Sticky annotation on a node" should "pass through" in { val input: String = @@ -145,7 +156,6 @@ class AnnotationTests extends AnnotationSpec with Matchers { val deleted = result.deletedAnnotations exception.str should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted") } - "Renaming" should "propagate in Lowering of memories" in { val compiler = new VerilogCompiler // Uncomment to help debugging failing tests @@ -165,7 +175,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | m.r.en <= UInt(1) | m.r.addr <= in |""".stripMargin - val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem")) + val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"), + dontTouch("Top.m")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("m_a", "mem")) @@ -179,7 +190,6 @@ class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should not contain (anno("m")) resultAnno should not contain (anno("r")) } - "Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in { val compiler = new VerilogCompiler Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug)) @@ -191,7 +201,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | cmem m: {a: UInt<4>, b: UInt<4>[2]}[8] | read mport r = m[in], clk |""".stripMargin - val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem")) + val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"), dontTouch("Top.m")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("m_a", "mem")) @@ -220,7 +230,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | x.a <= zero | x.b <= zero |""".stripMargin - val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]")) + val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), + anno("y[2]"), dontTouch("Top.x")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("x_a")) @@ -260,7 +271,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { anno("w.a"), anno("w.b[0]"), anno("w.b[1]"), anno("n.a"), anno("n.b[0]"), anno("n.b[1]"), anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), - anno("write.a"), anno("write.b[0]"), anno("write.b[1]") + anno("write.a"), anno("write.b[0]"), anno("write.b[1]"), + dontTouch("Top.r") ) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations @@ -314,7 +326,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r")) + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"), dontTouch("Top.r")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_a")) @@ -349,7 +361,8 @@ class AnnotationTests extends AnnotationSpec with Matchers { | out <= n | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk |""".stripMargin - val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b")) + val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"), + dontTouch("Top.r")) val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) val resultAnno = result.annotations.get.annotations resultAnno should contain (anno("in_b_0")) @@ -364,8 +377,7 @@ class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should contain (anno("r_b_1")) } - - "Renaming" should "track dce" in { + "Renaming" should "track constprop + dce" in { val compiler = new VerilogCompiler val input = """circuit Top : @@ -403,4 +415,49 @@ class AnnotationTests extends AnnotationSpec with Matchers { resultAnno should contain (anno("out_b_0")) resultAnno should contain (anno("out_b_1")) } + + ignore should "track deleted modules AND instances in dce" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Dead : + | input foo : UInt<8> + | output bar : UInt<8> + | bar <= foo + | extmodule DeadExt : + | input foo : UInt<8> + | output bar : UInt<8> + | module Top : + | input foo : UInt<8> + | output bar : UInt<8> + | inst d of Dead + | d.foo <= foo + | inst d2 of DeadExt + | d2.foo <= foo + | bar <= foo + |""".stripMargin + val annos = Seq( + OptimizableExtModuleAnnotation(ModuleName("DeadExt", CircuitName("Top"))), + manno("Dead"), manno("DeadExt"), manno("Top"), + anno("d"), anno("d2"), + anno("foo", mod = "Top"), anno("bar", mod = "Top"), + anno("foo", mod = "Dead"), anno("bar", mod = "Dead"), + anno("foo", mod = "DeadExt"), anno("bar", mod = "DeadExt") + ) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + + resultAnno should contain (manno("Top")) + resultAnno should contain (anno("foo", mod = "Top")) + resultAnno should contain (anno("bar", mod = "Top")) + + resultAnno should not contain (manno("Dead")) + resultAnno should not contain (manno("DeadExt")) + resultAnno should not contain (anno("d")) + resultAnno should not contain (anno("d2")) + resultAnno should not contain (anno("foo", mod = "Dead")) + resultAnno should not contain (anno("bar", mod = "Dead")) + resultAnno should not contain (anno("foo", mod = "DeadExt")) + resultAnno should not contain (anno("bar", mod = "DeadExt")) + } } diff --git a/src/test/scala/firrtlTests/AttachSpec.scala b/src/test/scala/firrtlTests/AttachSpec.scala index c29a7e43..93a36f70 100644 --- a/src/test/scala/firrtlTests/AttachSpec.scala +++ b/src/test/scala/firrtlTests/AttachSpec.scala @@ -62,7 +62,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | module A: | input an: Analog<3> | module B: - | input an: Analog<3> """.stripMargin + | input an: Analog<3>""".stripMargin val check = """module Attaching( |); @@ -70,16 +70,19 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | A a ( | .an(_GEN_0) | ); - | A b ( + | B b ( | .an(_GEN_0) | ); |endmodule |module A( | inout [2:0] an |); + |module B( + | inout [2:0] an + |); |endmodule |""".stripMargin.split("\n") map normalized - executeTest(input, check, compiler) + executeTest(input, check, compiler, Seq(dontTouch("A.an"), dontDedup("A"))) } it should "attach a wire source" in { @@ -101,7 +104,7 @@ class InoutVerilogSpec extends FirrtlFlatSpec { | ); |endmodule |""".stripMargin.split("\n") map normalized - executeTest(input, check, compiler) + executeTest(input, check, compiler, Seq(dontTouch("Attaching.x"))) } it should "attach multiple sources" in { diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala new file mode 100644 index 00000000..deb73b3b --- /dev/null +++ b/src/test/scala/firrtlTests/DCETests.scala @@ -0,0 +1,366 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl.ir.Circuit +import firrtl._ +import firrtl.passes._ +import firrtl.transforms._ +import firrtl.annotations._ +import firrtl.passes.memlib.SimpleTransform + +class DCETests extends FirrtlFlatSpec { + // Not using executeTest because it is for positive testing, we need to check that stuff got + // deleted + private val customTransforms = Seq( + new LowFirrtlOptimization, + new SimpleTransform(RemoveEmpty, LowForm) + ) + private def exec(input: String, check: String, annos: Seq[Annotation] = List.empty): Unit = { + val state = CircuitState(parse(input), ChirrtlForm, Some(AnnotationMap(annos))) + val finalState = (new LowFirrtlCompiler).compileAndEmit(state, customTransforms) + val res = finalState.getEmittedCircuit.value + // Convert to sets for comparison + val resSet = Set(parse(res).serialize.split("\n"):_*) + val checkSet = Set(parse(check).serialize.split("\n"):_*) + resSet should be (checkSet) + } + + "Unread wire" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | z <= x + | a <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Unread wire marked dont touch" should "NOT be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | z <= x + | a <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | z <= x + | a <= x""".stripMargin + exec(input, check, Seq(dontTouch("Top.a"))) + } + "Unread register" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | reg a : UInt<1>, clk + | a <= x + | node y = asUInt(clk) + | z <= or(x, y)""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | node y = asUInt(clk) + | z <= or(x, y)""".stripMargin + exec(input, check) + } + "Unread node" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Unused ports" should "be deleted" in { + val input = + """circuit Top : + | module Sub : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | inst sub of Sub + | sub.x <= x + | z <= sub.z""".stripMargin + val check = + """circuit Top : + | module Sub : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | input y : UInt<1> + | output z : UInt<1> + | inst sub of Sub + | sub.x <= x + | z <= sub.z""".stripMargin + exec(input, check) + } + "Chain of unread nodes" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | node b = or(a, a) + | node c = add(b, x) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Chain of unread wires and their connections" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | wire a : UInt<1> + | a <= x + | wire b : UInt<1> + | b <= a + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Read register" should "not be deleted" in { + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | reg r : UInt<1>, clk + | r <= x + | z <= r""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | reg r : UInt<1>, clk with : (reset => (UInt<1>("h0"), r)) + | r <= x + | z <= r""".stripMargin + exec(input, check) + } + "Logic that feeds into simulation constructs" should "not be deleted" in { + val input = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | stop(clk, a, 0) + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input clk : Clock + | input x : UInt<1> + | output z : UInt<1> + | node a = not(x) + | z <= x + | stop(clk, a, 0)""".stripMargin + exec(input, check) + } + "Globally dead module" should "should be deleted" in { + val input = + """circuit Top : + | module Dead : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + exec(input, check) + } + "Globally dead extmodule" should "NOT be deleted by default" in { + val input = + """circuit Top : + | extmodule Dead : + | input x : UInt<1> + | output z : UInt<1> + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + val check = + """circuit Top : + | extmodule Dead : + | input x : UInt<1> + | output z : UInt<1> + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + exec(input, check) + } + "Globally dead extmodule marked optimizable" should "be deleted" in { + val input = + """circuit Top : + | extmodule Dead : + | input x : UInt<1> + | output z : UInt<1> + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst dead of Dead + | dead.x <= x + | z <= x""".stripMargin + val check = + """circuit Top : + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | z <= x""".stripMargin + val doTouchAnno = OptimizableExtModuleAnnotation(ModuleName("Dead", CircuitName("Top"))) + exec(input, check, Seq(doTouchAnno)) + } + "Analog ports of extmodules" should "count as both inputs and outputs" in { + val input = + """circuit Top : + | extmodule BB1 : + | output bus : Analog<1> + | extmodule BB2 : + | output bus : Analog<1> + | output out : UInt<1> + | module Top : + | output out : UInt<1> + | inst bb1 of BB1 + | inst bb2 of BB2 + | attach (bb1.bus, bb2.bus) + | out <= bb2.out + """.stripMargin + exec(input, input) + } + // bar.z is not used and thus is dead code, but foo.z is used so this code isn't eliminated + "Module deduplication" should "should be preserved despite unused output of ONE instance" in { + val input = + """circuit Top : + | module Child : + | input x : UInt<1> + | output y : UInt<1> + | output z : UInt<1> + | y <= not(x) + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | inst bar of Child + | foo.x <= x + | bar.x <= x + | node t0 = or(foo.y, foo.z) + | z <= or(t0, bar.y)""".stripMargin + val check = + """circuit Top : + | module Child : + | input x : UInt<1> + | output y : UInt<1> + | output z : UInt<1> + | y <= not(x) + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | inst bar of Child + | foo.x <= x + | bar.x <= x + | node t0 = or(foo.y, foo.z) + | z <= or(t0, bar.y)""".stripMargin + exec(input, check) + } + // This currently does NOT work + behavior of "Single dead instances" + ignore should "should be deleted" in { + val input = + """circuit Top : + | module Child : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | inst bar of Child + | foo.x <= x + | bar.x <= x + | z <= foo.z""".stripMargin + val check = + """circuit Top : + | module Child : + | input x : UInt<1> + | output z : UInt<1> + | z <= x + | module Top : + | input x : UInt<1> + | output z : UInt<1> + | inst foo of Child + | skip + | foo.x <= x + | skip + | z <= foo.z""".stripMargin + exec(input, check) + } +} diff --git a/src/test/scala/firrtlTests/FirrtlSpec.scala b/src/test/scala/firrtlTests/FirrtlSpec.scala index f77b47f3..a45af8c7 100644 --- a/src/test/scala/firrtlTests/FirrtlSpec.scala +++ b/src/test/scala/firrtlTests/FirrtlSpec.scala @@ -12,7 +12,8 @@ import scala.io.Source import firrtl._ import firrtl.Parser.IgnoreInfo -import firrtl.annotations +import firrtl.annotations._ +import firrtl.transforms.{DontTouchAnnotation, NoDedupAnnotation} import firrtl.util.BackendCompilationUtilities trait FirrtlRunners extends BackendCompilationUtilities { @@ -82,6 +83,16 @@ trait FirrtlRunners extends BackendCompilationUtilities { } trait FirrtlMatchers extends Matchers { + def dontTouch(path: String): Annotation = { + val parts = path.split('.') + require(parts.size >= 2, "Must specify both module and component!") + val name = ComponentName(parts.tail.mkString("."), ModuleName(parts.head, CircuitName("Top"))) + DontTouchAnnotation(name) + } + def dontDedup(mod: String): Annotation = { + require(mod.split('.').size == 1, "Can only specify a Module, not a component or instance") + NoDedupAnnotation(ModuleName(mod, CircuitName("Top"))) + } // Replace all whitespace with a single space and remove leading and // trailing whitespace // Note this is intended for single-line strings, no newlines @@ -94,8 +105,13 @@ trait FirrtlMatchers extends Matchers { * compiler will be run on input then emitted result will each be split into * lines and normalized. */ - def executeTest(input: String, expected: Seq[String], compiler: Compiler) = { - val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm)) + def executeTest( + input: String, + expected: Seq[String], + compiler: Compiler, + annotations: Seq[Annotation] = Seq.empty) = { + val annoMap = AnnotationMap(annotations) + val finalState = compiler.compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(annoMap))) val lines = finalState.getEmittedCircuit.value split "\n" map normalized for (e <- expected) { lines should contain (e) diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 8367f152..25f845bc 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -5,6 +5,7 @@ package firrtlTests import firrtl._ import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms._ import firrtl.passes.memlib._ import annotations._ @@ -21,7 +22,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new SeqTransform { def inputForm = LowForm def outputForm = LowForm - def transforms = Seq(ConstProp, CommonSubexpressionElimination, DeadCodeElimination, RemoveEmpty) + def transforms = Seq(ConstProp, CommonSubexpressionElimination, new DeadCodeElimination, RemoveEmpty) } ) @@ -199,7 +200,7 @@ circuit CustomMemory : smem mem_1 : UInt<16>[7] read mport _T_17 = mem_0[io.rAddr], clock read mport _T_19 = mem_1[io.rAddr], clock - io.dO <= _T_17 + io.dO <= and(_T_17, _T_19) when io.wEn : write mport _T_18 = mem_0[io.wAddr], clock write mport _T_20 = mem_1[io.wAddr], clock @@ -218,7 +219,7 @@ circuit CustomMemory : case e: ExtModule => true case _ => false } - require(numExtMods == 2) + numExtMods should be (2) (new java.io.File(confLoc)).delete() } @@ -237,7 +238,7 @@ circuit CustomMemory : read mport _T_17 = mem_0[io.rAddr], clock read mport _T_19 = mem_1[io.rAddr], clock read mport _T_21 = mem_2[io.rAddr], clock - io.dO <= _T_17 + io.dO <= and(_T_17, and(_T_19, _T_21)) when io.wEn : write mport _T_18 = mem_0[io.wAddr], clock write mport _T_20 = mem_1[io.wAddr], clock @@ -258,7 +259,7 @@ circuit CustomMemory : case e: ExtModule => true case _ => false } - require(numExtMods == 2) + numExtMods should be (2) (new java.io.File(confLoc)).delete() } |
