aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2018-03-23 18:16:21 -0700
committerGitHub2018-03-23 18:16:21 -0700
commitf806b26ec377882f5adae43f101aa53e92b13f5c (patch)
tree46b94ee2a3d9fabd4ff36bddb15052c2d2eba321 /src
parentebb6847e9d01b424424ae11a0067448a4094e46d (diff)
Make Register Update Flattening a Transform and Delete Dangling Nodes (#692)
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Emitter.scala79
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala8
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala117
-rw-r--r--src/test/scala/firrtlTests/DCETests.scala23
4 files changed, 172 insertions, 55 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index a94ed37f..2c874392 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -12,6 +12,7 @@ import scala.io.Source
import firrtl.ir._
import firrtl.passes._
+import firrtl.transforms.{DeadCodeElimination, FlattenRegUpdate}
import firrtl.annotations._
import firrtl.Mappers._
import firrtl.PrimOps._
@@ -363,59 +364,31 @@ class VerilogEmitter extends SeqTransform with Emitter {
assigns += Seq("assign ", e, " = ", rand_string(e.tpe), ";")
assigns += Seq("`endif // RANDOMIZE_INVALID_ASSIGN")
}
- def update_and_reset(r: Expression, clk: Expression, reset: Expression, init: Expression) = {
- // We want to flatten Mux trees for reg updates into if-trees for
- // improved QoR for conditional updates. However, unbounded recursion
- // would take exponential time, so don't redundantly flatten the same
- // Mux more than a bounded number of times, preserving linear runtime.
- // The threshold is empirical but ample.
- val flattenThreshold = 4
- val numTimesFlattened = collection.mutable.HashMap[Mux, Int]()
- def canFlatten(m: Mux) = {
- val n = numTimesFlattened.getOrElse(m, 0)
- numTimesFlattened(m) = n + 1
- n < flattenThreshold
- }
-
- def addUpdate(e: Expression, tabs: String): Seq[Seq[Any]] = {
- if (weq(e, r)) Nil // Don't bother emitting connection of register to itself
- else {
- // Only walk netlist for nodes and wires, NOT registers or other state
- val expr = kind(e) match {
- case NodeKind | WireKind => netlist.getOrElse(e, e)
- case _ => e
- }
- expr match {
- case m: Mux if canFlatten(m) =>
- if(m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly")
- val ifStatement = Seq(tabs, "if (", m.cond, ") begin")
- val trueCase = addUpdate(m.tval, tabs + tab)
- val elseStatement = Seq(tabs, "end else begin")
- val ifNotStatement = Seq(tabs, "if (!(", m.cond, ")) begin")
- val falseCase = addUpdate(m.fval, tabs + tab)
- val endStatement = Seq(tabs, "end")
-
- ((trueCase.nonEmpty, falseCase.nonEmpty): @ unchecked) match {
- case (true, true) =>
- ifStatement +: trueCase ++: elseStatement +: falseCase :+ endStatement
- case (true, false) =>
- ifStatement +: trueCase :+ endStatement
- case (false, true) =>
- ifNotStatement +: falseCase :+ endStatement
- }
- case _ => Seq(Seq(tabs, r, " <= ", e, ";"))
- }
+ def regUpdate(r: Expression, clk: Expression) = {
+ def addUpdate(expr: Expression, tabs: String): Seq[Seq[Any]] = {
+ if (weq(expr, r)) Nil // Don't bother emitting connection of register to itself
+ else expr match {
+ case m: Mux =>
+ if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly")
+ def ifStatement = Seq(tabs, "if (", m.cond, ") begin")
+ val trueCase = addUpdate(m.tval, tabs + tab)
+ val elseStatement = Seq(tabs, "end else begin")
+ def ifNotStatement = Seq(tabs, "if (!(", m.cond, ")) begin")
+ val falseCase = addUpdate(m.fval, tabs + tab)
+ val endStatement = Seq(tabs, "end")
+
+ ((trueCase.nonEmpty, falseCase.nonEmpty): @ unchecked) match {
+ case (true, true) =>
+ ifStatement +: trueCase ++: elseStatement +: falseCase :+ endStatement
+ case (true, false) =>
+ ifStatement +: trueCase :+ endStatement
+ case (false, true) =>
+ ifNotStatement +: falseCase :+ endStatement
+ }
+ case e => Seq(Seq(tabs, r, " <= ", e, ";"))
}
}
-
- at_clock.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= {
- val tv = init
- val fv = netlist(r)
- if (weq(tv, r))
- addUpdate(fv, "")
- else
- addUpdate(Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), "")
- }
+ at_clock.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(netlist(r), "")
}
def update(e: Expression, value: Expression, clk: Expression, en: Expression, info: Info) = {
@@ -519,7 +492,7 @@ class VerilogEmitter extends SeqTransform with Emitter {
case sx: DefRegister =>
declare("reg", sx.name, sx.tpe, sx.info)
val e = wref(sx.name, sx.tpe)
- update_and_reset(e, sx.clock, sx.reset, sx.init)
+ regUpdate(e, sx.clock)
initialize(e)
sx
case sx: DefNode =>
@@ -686,6 +659,8 @@ class VerilogEmitter extends SeqTransform with Emitter {
/** Preamble for every emitted Verilog file */
def transforms = Seq(
+ new FlattenRegUpdate,
+ new DeadCodeElimination,
passes.VerilogModulusCleanup,
passes.VerilogWrap,
passes.VerilogRename,
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index 8b6b5c85..ecfa7393 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -178,7 +178,8 @@ class DeadCodeElimination extends Transform {
private def deleteDeadCode(instMap: collection.Map[String, String],
deadNodes: collection.Set[LogicNode],
moduleMap: collection.Map[String, DefModule],
- renames: RenameMap)
+ renames: RenameMap,
+ topName: String)
(mod: DefModule): Option[DefModule] = {
// For log-level debug
def deleteMsg(decl: IsDeclaration): String = {
@@ -249,7 +250,8 @@ class DeadCodeElimination extends Transform {
mod match {
case Module(info, name, _, body) =>
val bodyx = onStmt(body)
- if (emptyBody && portsx.isEmpty) {
+ // We don't delete the top module, even if it's empty
+ if (emptyBody && portsx.isEmpty && name != topName) {
logger.debug(deleteMsg(mod))
None
} else {
@@ -307,7 +309,7 @@ class DeadCodeElimination extends Transform {
// current status of the modulesxMap is used to either delete instances or update their types
val modulesxMap = mutable.HashMap.empty[String, DefModule]
topoSortedModules.foreach { case mod =>
- deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames)(mod) match {
+ deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main)(mod) match {
case Some(m) => modulesxMap += m.name -> m
case None => renames.delete(ModuleName(mod.name, CircuitName(c.main)))
}
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
new file mode 100644
index 00000000..07cb9cb5
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -0,0 +1,117 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.Utils._
+
+import scala.collection.mutable
+
+object FlattenRegUpdate {
+
+ /** Mapping from references to the [[Expression]]s that drive them */
+ type Netlist = mutable.HashMap[WrappedExpression, Expression]
+
+ /** Build a [[Netlist]] from a Module's connections and Nodes
+ *
+ * This assumes [[LowForm]]
+ *
+ * @param mod [[Module]] from which to build a [[Netlist]]
+ * @return [[Netlist]] of the module's connections and nodes
+ */
+ def buildNetlist(mod: Module): Netlist = {
+ val netlist = new Netlist()
+ def onStmt(stmt: Statement): Statement = {
+ stmt.map(onStmt) match {
+ case Connect(_, lhs, rhs) =>
+ netlist(lhs) = rhs
+ case DefNode(_, nname, rhs) =>
+ netlist(WRef(nname)) = rhs
+ case _: IsInvalid => throwInternalError(Some("Unexpected IsInvalid, should have been removed by now"))
+ case _ => // Do nothing
+ }
+ stmt
+ }
+ mod.map(onStmt)
+ netlist
+ }
+
+ /** Flatten Register Updates
+ *
+ * Constructs nested mux trees (up to a certain arbitrary threshold) for register updates. This
+ * can result in dead code that this function does NOT remove.
+ *
+ * @param mod [[Module]] to transform
+ * @return [[Module]] with register updates flattened
+ */
+ def flattenReg(mod: Module): Module = {
+ // We want to flatten Mux trees for reg updates into if-trees for
+ // improved QoR for conditional updates. However, unbounded recursion
+ // would take exponential time, so don't redundantly flatten the same
+ // Mux more than a bounded number of times, preserving linear runtime.
+ // The threshold is empirical but ample.
+ val flattenThreshold = 4
+ val numTimesFlattened = mutable.HashMap[Mux, Int]()
+ def canFlatten(m: Mux): Boolean = {
+ val n = numTimesFlattened.getOrElse(m, 0)
+ numTimesFlattened(m) = n + 1
+ n < flattenThreshold
+ }
+
+ val regUpdates = mutable.ArrayBuffer.empty[Connect]
+ val netlist = buildNetlist(mod)
+
+ def constructRegUpdate(e: Expression): Expression = {
+ // Only walk netlist for nodes and wires, NOT registers or other state
+ val expr = kind(e) match {
+ case NodeKind | WireKind => netlist.getOrElse(e, e)
+ case _ => e
+ }
+ expr match {
+ case mux: Mux if canFlatten(mux) =>
+ val tvalx = constructRegUpdate(mux.tval)
+ val fvalx = constructRegUpdate(mux.fval)
+ mux.copy(tval = tvalx, fval = fvalx)
+ // Return the original expression to end flattening
+ case _ => e
+ }
+ }
+
+ def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
+ case reg @ DefRegister(_, rname, _,_, resetCond, _) =>
+ assert(resetCond == Utils.zero, "Register reset should have already been made explicit!")
+ val ref = WRef(reg)
+ val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref)))
+ regUpdates += update
+ reg
+ // Remove connections to Registers so we preserve LowFirrtl single-connection semantics
+ case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt
+ case other => other
+ }
+
+ val bodyx = onStmt(mod.body)
+ mod.copy(body = Block(bodyx +: regUpdates))
+ }
+
+}
+
+/** Flatten register update
+ *
+ * This transform flattens register updates into a single expression on the rhs of connection to
+ * the register
+ */
+// TODO Preserve source locators
+class FlattenRegUpdate extends Transform {
+ def inputForm = MidForm
+ def outputForm = MidForm
+
+ def execute(state: CircuitState): CircuitState = {
+ val modulesx = state.circuit.modules.map {
+ case mod: Module => FlattenRegUpdate.flattenReg(mod)
+ case ext: ExtModule => ext
+ }
+ state.copy(circuit = state.circuit.copy(modules = modulesx))
+ }
+}
diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala
index 97c1c146..b8345093 100644
--- a/src/test/scala/firrtlTests/DCETests.scala
+++ b/src/test/scala/firrtlTests/DCETests.scala
@@ -391,6 +391,29 @@ class DCETests extends FirrtlFlatSpec {
| z <= foo.z""".stripMargin
exec(input, check)
}
+
+ "Emitted Verilog" should "not contain dead \"register update\" code" in {
+ val input = parse(
+ """circuit test :
+ | module test :
+ | input clock : Clock
+ | input a : UInt<1>
+ | input x : UInt<8>
+ | output z : UInt<8>
+ | reg r : UInt, clock
+ | when a :
+ | r <= x
+ | z <= r""".stripMargin
+ )
+
+ val state = CircuitState(input, ChirrtlForm)
+ val result = (new VerilogCompiler).compileAndEmit(state, List.empty)
+ val verilog = result.getEmittedCircuit.value
+ // Check that mux is removed!
+ verilog shouldNot include regex ("""a \? x : r;""")
+ // Check for register update
+ verilog should include regex ("""(?m)if \(a\) begin\n\s*r <= x;\s*end""")
+ }
}
class DCECommandLineSpec extends FirrtlFlatSpec {