aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala96
1 files changed, 69 insertions, 27 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index b272f134..a2399b5a 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -7,7 +7,7 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.options.Dependency
-import firrtl.InfoExpr.orElse
+import firrtl.InfoExpr.{orElse, unwrap}
import scala.collection.mutable
@@ -58,38 +58,80 @@ object FlattenRegUpdate {
* @return [[firrtl.ir.Module Module]] with register updates flattened
*/
def flattenReg(mod: Module): Module = {
- // We want to flatten Mux trees for reg updates into if-trees for
- // improved QoR for conditional updates. However, unbounded recursion
- // would take exponential time, so don't redundantly flatten the same
- // Mux more than a bounded number of times, preserving linear runtime.
- // The threshold is empirical but ample.
- val flattenThreshold = 4
- val numTimesFlattened = mutable.HashMap[Mux, Int]()
- def canFlatten(m: Mux): Boolean = {
- val n = numTimesFlattened.getOrElse(m, 0)
- numTimesFlattened(m) = n + 1
- n < flattenThreshold
- }
+ // We want to flatten Mux trees for reg updates into if-trees for improved QoR for conditional
+ // updates. Sometimes the fan-in for a register has a mux structure with repeated
+ // sub-expressions that are themselves complex mux structures. These repeated structures can
+ // cause explosions in the size and complexity of the Verilog. In addition, user code that
+ // follows such structure often will have conditions in the sub-trees that are mutually
+ // exclusive with the conditions in the muxes closer to the register input. For example:
+ //
+ // when a : ; when 1
+ // r <= foo
+ // when b : ; when 2
+ // when a :
+ // r <= bar ; when 3
+ //
+ // After expand whens, when 1 is a common sub-expression that will show up twice in the mux
+ // structure from when 2:
+ //
+ // _GEN_0 = mux(a, foo, r)
+ // _GEN_1 = mux(a, bar, _GEN_0)
+ // r <= mux(b, _GEN_1, _GEN_0)
+ //
+ // Inlining _GEN_0 into _GEN_1 would result in unreachable lines in the Verilog. While we could
+ // do some optimizations here, this is *not* really a problem, it's just that Verilog metrics
+ // are based on the assumption of human-written code and as such it results in unreachable
+ // lines. Simply not inlining avoids this issue and leaves the optimizations up to synthesis
+ // tools which do a great job here.
+ val maxDepth = 4
val regUpdates = mutable.ArrayBuffer.empty[Connect]
val netlist = buildNetlist(mod)
- def constructRegUpdate(e: Expression): (Info, Expression) = {
- import InfoExpr.unwrap
- // Only walk netlist for nodes and wires, NOT registers or other state
- val (info, expr) = kind(e) match {
- case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e))
- case _ => unwrap(e)
+ // First traversal marks expression that would be inlined multiple times as endpoints
+ // Note that we could traverse more than maxDepth times - this corresponds to an expression that
+ // is already a very deeply nested mux
+ def determineEndpoints(expr: Expression): collection.Set[WrappedExpression] = {
+ val seen = mutable.HashSet.empty[WrappedExpression]
+ val endpoint = mutable.HashSet.empty[WrappedExpression]
+ def rec(depth: Int)(e: Expression): Unit = {
+ val (_, ex) = kind(e) match {
+ case NodeKind | WireKind if depth < maxDepth && !seen(e) =>
+ seen += e
+ unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
+ }
+ ex match {
+ case Mux(_, tval, fval, _) =>
+ rec(depth + 1)(tval)
+ rec(depth + 1)(fval)
+ case _ =>
+ // Mark e not ex because original reference is the endpoint, not op or whatever
+ endpoint += ex
+ }
}
- expr match {
- case mux: Mux if canFlatten(mux) =>
- val (tinfo, tvalx) = constructRegUpdate(mux.tval)
- val (finfo, fvalx) = constructRegUpdate(mux.fval)
- val infox = combineInfos(info, tinfo, finfo)
- (infox, mux.copy(tval = tvalx, fval = fvalx))
- // Return the original expression to end flattening
- case _ => unwrap(e)
+ rec(0)(expr)
+ endpoint
+ }
+
+ def constructRegUpdate(start: Expression): (Info, Expression) = {
+ val endpoints = determineEndpoints(start)
+ def rec(e: Expression): (Info, Expression) = {
+ val (info, expr) = kind(e) match {
+ case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
+ }
+ expr match {
+ case Mux(cond, tval, fval, tpe) =>
+ val (tinfo, tvalx) = rec(tval)
+ val (finfo, fvalx) = rec(fval)
+ val infox = combineInfos(info, tinfo, finfo)
+ (infox, Mux(cond, tvalx, fvalx, tpe))
+ // Return the original expression to end flattening
+ case _ => unwrap(e)
+ }
}
+ rec(start)
}
def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {