diff options
| author | Jack Koenig | 2018-03-23 18:16:21 -0700 |
|---|---|---|
| committer | GitHub | 2018-03-23 18:16:21 -0700 |
| commit | f806b26ec377882f5adae43f101aa53e92b13f5c (patch) | |
| tree | 46b94ee2a3d9fabd4ff36bddb15052c2d2eba321 /src | |
| parent | ebb6847e9d01b424424ae11a0067448a4094e46d (diff) | |
Make Register Update Flattening a Transform and Delete Dangling Nodes (#692)
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 79 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/DeadCodeElimination.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/FlattenRegUpdate.scala | 117 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/DCETests.scala | 23 |
4 files changed, 172 insertions, 55 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index a94ed37f..2c874392 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -12,6 +12,7 @@ import scala.io.Source import firrtl.ir._ import firrtl.passes._ +import firrtl.transforms.{DeadCodeElimination, FlattenRegUpdate} import firrtl.annotations._ import firrtl.Mappers._ import firrtl.PrimOps._ @@ -363,59 +364,31 @@ class VerilogEmitter extends SeqTransform with Emitter { assigns += Seq("assign ", e, " = ", rand_string(e.tpe), ";") assigns += Seq("`endif // RANDOMIZE_INVALID_ASSIGN") } - def update_and_reset(r: Expression, clk: Expression, reset: Expression, init: Expression) = { - // We want to flatten Mux trees for reg updates into if-trees for - // improved QoR for conditional updates. However, unbounded recursion - // would take exponential time, so don't redundantly flatten the same - // Mux more than a bounded number of times, preserving linear runtime. - // The threshold is empirical but ample. - val flattenThreshold = 4 - val numTimesFlattened = collection.mutable.HashMap[Mux, Int]() - def canFlatten(m: Mux) = { - val n = numTimesFlattened.getOrElse(m, 0) - numTimesFlattened(m) = n + 1 - n < flattenThreshold - } - - def addUpdate(e: Expression, tabs: String): Seq[Seq[Any]] = { - if (weq(e, r)) Nil // Don't bother emitting connection of register to itself - else { - // Only walk netlist for nodes and wires, NOT registers or other state - val expr = kind(e) match { - case NodeKind | WireKind => netlist.getOrElse(e, e) - case _ => e - } - expr match { - case m: Mux if canFlatten(m) => - if(m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") - val ifStatement = Seq(tabs, "if (", m.cond, ") begin") - val trueCase = addUpdate(m.tval, tabs + tab) - val elseStatement = Seq(tabs, "end else begin") - val ifNotStatement = Seq(tabs, "if (!(", m.cond, ")) begin") - val falseCase = addUpdate(m.fval, tabs + tab) - val endStatement = Seq(tabs, "end") - - ((trueCase.nonEmpty, falseCase.nonEmpty): @ unchecked) match { - case (true, true) => - ifStatement +: trueCase ++: elseStatement +: falseCase :+ endStatement - case (true, false) => - ifStatement +: trueCase :+ endStatement - case (false, true) => - ifNotStatement +: falseCase :+ endStatement - } - case _ => Seq(Seq(tabs, r, " <= ", e, ";")) - } + def regUpdate(r: Expression, clk: Expression) = { + def addUpdate(expr: Expression, tabs: String): Seq[Seq[Any]] = { + if (weq(expr, r)) Nil // Don't bother emitting connection of register to itself + else expr match { + case m: Mux => + if (m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") + def ifStatement = Seq(tabs, "if (", m.cond, ") begin") + val trueCase = addUpdate(m.tval, tabs + tab) + val elseStatement = Seq(tabs, "end else begin") + def ifNotStatement = Seq(tabs, "if (!(", m.cond, ")) begin") + val falseCase = addUpdate(m.fval, tabs + tab) + val endStatement = Seq(tabs, "end") + + ((trueCase.nonEmpty, falseCase.nonEmpty): @ unchecked) match { + case (true, true) => + ifStatement +: trueCase ++: elseStatement +: falseCase :+ endStatement + case (true, false) => + ifStatement +: trueCase :+ endStatement + case (false, true) => + ifNotStatement +: falseCase :+ endStatement + } + case e => Seq(Seq(tabs, r, " <= ", e, ";")) } } - - at_clock.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= { - val tv = init - val fv = netlist(r) - if (weq(tv, r)) - addUpdate(fv, "") - else - addUpdate(Mux(reset, tv, fv, mux_type_and_widths(tv, fv)), "") - } + at_clock.getOrElseUpdate(clk, ArrayBuffer[Seq[Any]]()) ++= addUpdate(netlist(r), "") } def update(e: Expression, value: Expression, clk: Expression, en: Expression, info: Info) = { @@ -519,7 +492,7 @@ class VerilogEmitter extends SeqTransform with Emitter { case sx: DefRegister => declare("reg", sx.name, sx.tpe, sx.info) val e = wref(sx.name, sx.tpe) - update_and_reset(e, sx.clock, sx.reset, sx.init) + regUpdate(e, sx.clock) initialize(e) sx case sx: DefNode => @@ -686,6 +659,8 @@ class VerilogEmitter extends SeqTransform with Emitter { /** Preamble for every emitted Verilog file */ def transforms = Seq( + new FlattenRegUpdate, + new DeadCodeElimination, passes.VerilogModulusCleanup, passes.VerilogWrap, passes.VerilogRename, diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index 8b6b5c85..ecfa7393 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -178,7 +178,8 @@ class DeadCodeElimination extends Transform { private def deleteDeadCode(instMap: collection.Map[String, String], deadNodes: collection.Set[LogicNode], moduleMap: collection.Map[String, DefModule], - renames: RenameMap) + renames: RenameMap, + topName: String) (mod: DefModule): Option[DefModule] = { // For log-level debug def deleteMsg(decl: IsDeclaration): String = { @@ -249,7 +250,8 @@ class DeadCodeElimination extends Transform { mod match { case Module(info, name, _, body) => val bodyx = onStmt(body) - if (emptyBody && portsx.isEmpty) { + // We don't delete the top module, even if it's empty + if (emptyBody && portsx.isEmpty && name != topName) { logger.debug(deleteMsg(mod)) None } else { @@ -307,7 +309,7 @@ class DeadCodeElimination extends Transform { // current status of the modulesxMap is used to either delete instances or update their types val modulesxMap = mutable.HashMap.empty[String, DefModule] topoSortedModules.foreach { case mod => - deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames)(mod) match { + deleteDeadCode(moduleDeps(mod.name), deadNodes, modulesxMap, renames, c.main)(mod) match { case Some(m) => modulesxMap += m.name -> m case None => renames.delete(ModuleName(mod.name, CircuitName(c.main))) } diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala new file mode 100644 index 00000000..07cb9cb5 --- /dev/null +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -0,0 +1,117 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.Utils._ + +import scala.collection.mutable + +object FlattenRegUpdate { + + /** Mapping from references to the [[Expression]]s that drive them */ + type Netlist = mutable.HashMap[WrappedExpression, Expression] + + /** Build a [[Netlist]] from a Module's connections and Nodes + * + * This assumes [[LowForm]] + * + * @param mod [[Module]] from which to build a [[Netlist]] + * @return [[Netlist]] of the module's connections and nodes + */ + def buildNetlist(mod: Module): Netlist = { + val netlist = new Netlist() + def onStmt(stmt: Statement): Statement = { + stmt.map(onStmt) match { + case Connect(_, lhs, rhs) => + netlist(lhs) = rhs + case DefNode(_, nname, rhs) => + netlist(WRef(nname)) = rhs + case _: IsInvalid => throwInternalError(Some("Unexpected IsInvalid, should have been removed by now")) + case _ => // Do nothing + } + stmt + } + mod.map(onStmt) + netlist + } + + /** Flatten Register Updates + * + * Constructs nested mux trees (up to a certain arbitrary threshold) for register updates. This + * can result in dead code that this function does NOT remove. + * + * @param mod [[Module]] to transform + * @return [[Module]] with register updates flattened + */ + def flattenReg(mod: Module): Module = { + // We want to flatten Mux trees for reg updates into if-trees for + // improved QoR for conditional updates. However, unbounded recursion + // would take exponential time, so don't redundantly flatten the same + // Mux more than a bounded number of times, preserving linear runtime. + // The threshold is empirical but ample. + val flattenThreshold = 4 + val numTimesFlattened = mutable.HashMap[Mux, Int]() + def canFlatten(m: Mux): Boolean = { + val n = numTimesFlattened.getOrElse(m, 0) + numTimesFlattened(m) = n + 1 + n < flattenThreshold + } + + val regUpdates = mutable.ArrayBuffer.empty[Connect] + val netlist = buildNetlist(mod) + + def constructRegUpdate(e: Expression): Expression = { + // Only walk netlist for nodes and wires, NOT registers or other state + val expr = kind(e) match { + case NodeKind | WireKind => netlist.getOrElse(e, e) + case _ => e + } + expr match { + case mux: Mux if canFlatten(mux) => + val tvalx = constructRegUpdate(mux.tval) + val fvalx = constructRegUpdate(mux.fval) + mux.copy(tval = tvalx, fval = fvalx) + // Return the original expression to end flattening + case _ => e + } + } + + def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { + case reg @ DefRegister(_, rname, _,_, resetCond, _) => + assert(resetCond == Utils.zero, "Register reset should have already been made explicit!") + val ref = WRef(reg) + val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref))) + regUpdates += update + reg + // Remove connections to Registers so we preserve LowFirrtl single-connection semantics + case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt + case other => other + } + + val bodyx = onStmt(mod.body) + mod.copy(body = Block(bodyx +: regUpdates)) + } + +} + +/** Flatten register update + * + * This transform flattens register updates into a single expression on the rhs of connection to + * the register + */ +// TODO Preserve source locators +class FlattenRegUpdate extends Transform { + def inputForm = MidForm + def outputForm = MidForm + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map { + case mod: Module => FlattenRegUpdate.flattenReg(mod) + case ext: ExtModule => ext + } + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} diff --git a/src/test/scala/firrtlTests/DCETests.scala b/src/test/scala/firrtlTests/DCETests.scala index 97c1c146..b8345093 100644 --- a/src/test/scala/firrtlTests/DCETests.scala +++ b/src/test/scala/firrtlTests/DCETests.scala @@ -391,6 +391,29 @@ class DCETests extends FirrtlFlatSpec { | z <= foo.z""".stripMargin exec(input, check) } + + "Emitted Verilog" should "not contain dead \"register update\" code" in { + val input = parse( + """circuit test : + | module test : + | input clock : Clock + | input a : UInt<1> + | input x : UInt<8> + | output z : UInt<8> + | reg r : UInt, clock + | when a : + | r <= x + | z <= r""".stripMargin + ) + + val state = CircuitState(input, ChirrtlForm) + val result = (new VerilogCompiler).compileAndEmit(state, List.empty) + val verilog = result.getEmittedCircuit.value + // Check that mux is removed! + verilog shouldNot include regex ("""a \? x : r;""") + // Check for register update + verilog should include regex ("""(?m)if \(a\) begin\n\s*r <= x;\s*end""") + } } class DCECommandLineSpec extends FirrtlFlatSpec { |
