diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/transforms/FlattenRegUpdate.scala | 96 |
1 files changed, 69 insertions, 27 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index b272f134..a2399b5a 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -7,7 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.options.Dependency -import firrtl.InfoExpr.orElse +import firrtl.InfoExpr.{orElse, unwrap} import scala.collection.mutable @@ -58,38 +58,80 @@ object FlattenRegUpdate { * @return [[firrtl.ir.Module Module]] with register updates flattened */ def flattenReg(mod: Module): Module = { - // We want to flatten Mux trees for reg updates into if-trees for - // improved QoR for conditional updates. However, unbounded recursion - // would take exponential time, so don't redundantly flatten the same - // Mux more than a bounded number of times, preserving linear runtime. - // The threshold is empirical but ample. - val flattenThreshold = 4 - val numTimesFlattened = mutable.HashMap[Mux, Int]() - def canFlatten(m: Mux): Boolean = { - val n = numTimesFlattened.getOrElse(m, 0) - numTimesFlattened(m) = n + 1 - n < flattenThreshold - } + // We want to flatten Mux trees for reg updates into if-trees for improved QoR for conditional + // updates. Sometimes the fan-in for a register has a mux structure with repeated + // sub-expressions that are themselves complex mux structures. These repeated structures can + // cause explosions in the size and complexity of the Verilog. In addition, user code that + // follows such structure often will have conditions in the sub-trees that are mutually + // exclusive with the conditions in the muxes closer to the register input. For example: + // + // when a : ; when 1 + // r <= foo + // when b : ; when 2 + // when a : + // r <= bar ; when 3 + // + // After expand whens, when 1 is a common sub-expression that will show up twice in the mux + // structure from when 2: + // + // _GEN_0 = mux(a, foo, r) + // _GEN_1 = mux(a, bar, _GEN_0) + // r <= mux(b, _GEN_1, _GEN_0) + // + // Inlining _GEN_0 into _GEN_1 would result in unreachable lines in the Verilog. While we could + // do some optimizations here, this is *not* really a problem, it's just that Verilog metrics + // are based on the assumption of human-written code and as such it results in unreachable + // lines. Simply not inlining avoids this issue and leaves the optimizations up to synthesis + // tools which do a great job here. + val maxDepth = 4 val regUpdates = mutable.ArrayBuffer.empty[Connect] val netlist = buildNetlist(mod) - def constructRegUpdate(e: Expression): (Info, Expression) = { - import InfoExpr.unwrap - // Only walk netlist for nodes and wires, NOT registers or other state - val (info, expr) = kind(e) match { - case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e)) - case _ => unwrap(e) + // First traversal marks expression that would be inlined multiple times as endpoints + // Note that we could traverse more than maxDepth times - this corresponds to an expression that + // is already a very deeply nested mux + def determineEndpoints(expr: Expression): collection.Set[WrappedExpression] = { + val seen = mutable.HashSet.empty[WrappedExpression] + val endpoint = mutable.HashSet.empty[WrappedExpression] + def rec(depth: Int)(e: Expression): Unit = { + val (_, ex) = kind(e) match { + case NodeKind | WireKind if depth < maxDepth && !seen(e) => + seen += e + unwrap(netlist.getOrElse(e, e)) + case _ => unwrap(e) + } + ex match { + case Mux(_, tval, fval, _) => + rec(depth + 1)(tval) + rec(depth + 1)(fval) + case _ => + // Mark e not ex because original reference is the endpoint, not op or whatever + endpoint += ex + } } - expr match { - case mux: Mux if canFlatten(mux) => - val (tinfo, tvalx) = constructRegUpdate(mux.tval) - val (finfo, fvalx) = constructRegUpdate(mux.fval) - val infox = combineInfos(info, tinfo, finfo) - (infox, mux.copy(tval = tvalx, fval = fvalx)) - // Return the original expression to end flattening - case _ => unwrap(e) + rec(0)(expr) + endpoint + } + + def constructRegUpdate(start: Expression): (Info, Expression) = { + val endpoints = determineEndpoints(start) + def rec(e: Expression): (Info, Expression) = { + val (info, expr) = kind(e) match { + case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e)) + case _ => unwrap(e) + } + expr match { + case Mux(cond, tval, fval, tpe) => + val (tinfo, tvalx) = rec(tval) + val (finfo, fvalx) = rec(fval) + val infox = combineInfos(info, tinfo, finfo) + (infox, Mux(cond, tvalx, fvalx, tpe)) + // Return the original expression to end flattening + case _ => unwrap(e) + } } + rec(start) } def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { |
