diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms/FlattenRegUpdate.scala')
| -rw-r--r-- | src/main/scala/firrtl/transforms/FlattenRegUpdate.scala | 117 |
1 files changed, 117 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala new file mode 100644 index 00000000..07cb9cb5 --- /dev/null +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -0,0 +1,117 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.Utils._ + +import scala.collection.mutable + +object FlattenRegUpdate { + + /** Mapping from references to the [[Expression]]s that drive them */ + type Netlist = mutable.HashMap[WrappedExpression, Expression] + + /** Build a [[Netlist]] from a Module's connections and Nodes + * + * This assumes [[LowForm]] + * + * @param mod [[Module]] from which to build a [[Netlist]] + * @return [[Netlist]] of the module's connections and nodes + */ + def buildNetlist(mod: Module): Netlist = { + val netlist = new Netlist() + def onStmt(stmt: Statement): Statement = { + stmt.map(onStmt) match { + case Connect(_, lhs, rhs) => + netlist(lhs) = rhs + case DefNode(_, nname, rhs) => + netlist(WRef(nname)) = rhs + case _: IsInvalid => throwInternalError(Some("Unexpected IsInvalid, should have been removed by now")) + case _ => // Do nothing + } + stmt + } + mod.map(onStmt) + netlist + } + + /** Flatten Register Updates + * + * Constructs nested mux trees (up to a certain arbitrary threshold) for register updates. This + * can result in dead code that this function does NOT remove. + * + * @param mod [[Module]] to transform + * @return [[Module]] with register updates flattened + */ + def flattenReg(mod: Module): Module = { + // We want to flatten Mux trees for reg updates into if-trees for + // improved QoR for conditional updates. However, unbounded recursion + // would take exponential time, so don't redundantly flatten the same + // Mux more than a bounded number of times, preserving linear runtime. + // The threshold is empirical but ample. + val flattenThreshold = 4 + val numTimesFlattened = mutable.HashMap[Mux, Int]() + def canFlatten(m: Mux): Boolean = { + val n = numTimesFlattened.getOrElse(m, 0) + numTimesFlattened(m) = n + 1 + n < flattenThreshold + } + + val regUpdates = mutable.ArrayBuffer.empty[Connect] + val netlist = buildNetlist(mod) + + def constructRegUpdate(e: Expression): Expression = { + // Only walk netlist for nodes and wires, NOT registers or other state + val expr = kind(e) match { + case NodeKind | WireKind => netlist.getOrElse(e, e) + case _ => e + } + expr match { + case mux: Mux if canFlatten(mux) => + val tvalx = constructRegUpdate(mux.tval) + val fvalx = constructRegUpdate(mux.fval) + mux.copy(tval = tvalx, fval = fvalx) + // Return the original expression to end flattening + case _ => e + } + } + + def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { + case reg @ DefRegister(_, rname, _,_, resetCond, _) => + assert(resetCond == Utils.zero, "Register reset should have already been made explicit!") + val ref = WRef(reg) + val update = Connect(NoInfo, ref, constructRegUpdate(netlist.getOrElse(ref, ref))) + regUpdates += update + reg + // Remove connections to Registers so we preserve LowFirrtl single-connection semantics + case Connect(_, lhs, _) if kind(lhs) == RegKind => EmptyStmt + case other => other + } + + val bodyx = onStmt(mod.body) + mod.copy(body = Block(bodyx +: regUpdates)) + } + +} + +/** Flatten register update + * + * This transform flattens register updates into a single expression on the rhs of connection to + * the register + */ +// TODO Preserve source locators +class FlattenRegUpdate extends Transform { + def inputForm = MidForm + def outputForm = MidForm + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map { + case mod: Module => FlattenRegUpdate.flattenReg(mod) + case ext: ExtModule => ext + } + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} |
