aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2020-07-31 15:33:05 -0700
committerGitHub2020-07-31 22:33:05 +0000
commit17279da1f9f07bbd690f248c454656a231af18ae (patch)
treeb508b8008b950daae11c0117c4d2e78e4bab4de3 /src
parent31132333d5c0cbef52035cf76b677edd9b208b5e (diff)
Avoid repeated inlining in FlattenRegUpdate (#1727)
* Avoid repeated inlining in FlattenRegUpdate When-else structure can lead to the same complex mux structure being the default on several branches in register update logic. When these are inlined, it can lead to artifical unreachable branches that show up as coverage holes in coverage of the emitted Verilog. This commit changes the inlining logic to prevent inlining any reference expression that shows up multiple times because this is a common indicator of the problematic case. * Add tests for improved register update logic emission * Improve FlattenRegUpdate comment and add more tests * [skip formal checks] ICache equivalence check verified locally
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala96
-rw-r--r--src/test/scala/firrtlTests/RegisterUpdateSpec.scala100
2 files changed, 169 insertions, 27 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index b272f134..a2399b5a 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -7,7 +7,7 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.options.Dependency
-import firrtl.InfoExpr.orElse
+import firrtl.InfoExpr.{orElse, unwrap}
import scala.collection.mutable
@@ -58,38 +58,80 @@ object FlattenRegUpdate {
* @return [[firrtl.ir.Module Module]] with register updates flattened
*/
def flattenReg(mod: Module): Module = {
- // We want to flatten Mux trees for reg updates into if-trees for
- // improved QoR for conditional updates. However, unbounded recursion
- // would take exponential time, so don't redundantly flatten the same
- // Mux more than a bounded number of times, preserving linear runtime.
- // The threshold is empirical but ample.
- val flattenThreshold = 4
- val numTimesFlattened = mutable.HashMap[Mux, Int]()
- def canFlatten(m: Mux): Boolean = {
- val n = numTimesFlattened.getOrElse(m, 0)
- numTimesFlattened(m) = n + 1
- n < flattenThreshold
- }
+ // We want to flatten Mux trees for reg updates into if-trees for improved QoR for conditional
+ // updates. Sometimes the fan-in for a register has a mux structure with repeated
+ // sub-expressions that are themselves complex mux structures. These repeated structures can
+ // cause explosions in the size and complexity of the Verilog. In addition, user code that
+ // follows such structure often will have conditions in the sub-trees that are mutually
+ // exclusive with the conditions in the muxes closer to the register input. For example:
+ //
+ // when a : ; when 1
+ // r <= foo
+ // when b : ; when 2
+ // when a :
+ // r <= bar ; when 3
+ //
+ // After expand whens, when 1 is a common sub-expression that will show up twice in the mux
+ // structure from when 2:
+ //
+ // _GEN_0 = mux(a, foo, r)
+ // _GEN_1 = mux(a, bar, _GEN_0)
+ // r <= mux(b, _GEN_1, _GEN_0)
+ //
+ // Inlining _GEN_0 into _GEN_1 would result in unreachable lines in the Verilog. While we could
+ // do some optimizations here, this is *not* really a problem, it's just that Verilog metrics
+ // are based on the assumption of human-written code and as such it results in unreachable
+ // lines. Simply not inlining avoids this issue and leaves the optimizations up to synthesis
+ // tools which do a great job here.
+ val maxDepth = 4
val regUpdates = mutable.ArrayBuffer.empty[Connect]
val netlist = buildNetlist(mod)
- def constructRegUpdate(e: Expression): (Info, Expression) = {
- import InfoExpr.unwrap
- // Only walk netlist for nodes and wires, NOT registers or other state
- val (info, expr) = kind(e) match {
- case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e))
- case _ => unwrap(e)
+ // First traversal marks expression that would be inlined multiple times as endpoints
+ // Note that we could traverse more than maxDepth times - this corresponds to an expression that
+ // is already a very deeply nested mux
+ def determineEndpoints(expr: Expression): collection.Set[WrappedExpression] = {
+ val seen = mutable.HashSet.empty[WrappedExpression]
+ val endpoint = mutable.HashSet.empty[WrappedExpression]
+ def rec(depth: Int)(e: Expression): Unit = {
+ val (_, ex) = kind(e) match {
+ case NodeKind | WireKind if depth < maxDepth && !seen(e) =>
+ seen += e
+ unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
+ }
+ ex match {
+ case Mux(_, tval, fval, _) =>
+ rec(depth + 1)(tval)
+ rec(depth + 1)(fval)
+ case _ =>
+ // Mark e not ex because original reference is the endpoint, not op or whatever
+ endpoint += ex
+ }
}
- expr match {
- case mux: Mux if canFlatten(mux) =>
- val (tinfo, tvalx) = constructRegUpdate(mux.tval)
- val (finfo, fvalx) = constructRegUpdate(mux.fval)
- val infox = combineInfos(info, tinfo, finfo)
- (infox, mux.copy(tval = tvalx, fval = fvalx))
- // Return the original expression to end flattening
- case _ => unwrap(e)
+ rec(0)(expr)
+ endpoint
+ }
+
+ def constructRegUpdate(start: Expression): (Info, Expression) = {
+ val endpoints = determineEndpoints(start)
+ def rec(e: Expression): (Info, Expression) = {
+ val (info, expr) = kind(e) match {
+ case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
+ }
+ expr match {
+ case Mux(cond, tval, fval, tpe) =>
+ val (tinfo, tvalx) = rec(tval)
+ val (finfo, fvalx) = rec(fval)
+ val infox = combineInfos(info, tinfo, finfo)
+ (infox, Mux(cond, tvalx, fvalx, tpe))
+ // Return the original expression to end flattening
+ case _ => unwrap(e)
+ }
}
+ rec(start)
}
def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
diff --git a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala
new file mode 100644
index 00000000..dfef5955
--- /dev/null
+++ b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala
@@ -0,0 +1,100 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.ir._
+import firrtl.transforms.FlattenRegUpdate
+import firrtl.annotations.NoTargetAnnotation
+import firrtl.stage.transforms.Compiler
+import firrtl.options.Dependency
+import firrtl.testutils._
+import FirrtlCheckers._
+import scala.util.matching.Regex
+
+object RegisterUpdateSpec {
+ case class CaptureStateAnno(value: CircuitState) extends NoTargetAnnotation
+ // Capture the CircuitState between FlattenRegUpdate and VerilogEmitter
+ //Emit captured state as FIRRTL for use in testing
+ class CaptureCircuitState extends Transform with DependencyAPIMigration {
+ override def prerequisites = Dependency[FlattenRegUpdate] :: Nil
+ override def optionalPrerequisiteOf = Dependency[VerilogEmitter] :: Nil
+ override def invalidates(a: Transform): Boolean = false
+ def execute(state: CircuitState): CircuitState = {
+ val emittedAnno = EmittedFirrtlCircuitAnnotation(
+ EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir"))
+ val capturedState = state.copy(annotations = emittedAnno +: state.annotations)
+ state.copy(annotations = CaptureStateAnno(capturedState) +: state.annotations)
+ }
+ }
+}
+
+class RegisterUpdateSpec extends FirrtlFlatSpec {
+ import RegisterUpdateSpec._
+ def compile(input: String): CircuitState = {
+ val compiler = new Compiler(Seq(Dependency[CaptureCircuitState], Dependency[VerilogEmitter]))
+ compiler.execute(CircuitState(parse(input), EmitCircuitAnnotation(classOf[VerilogEmitter]) :: Nil))
+ }
+ def compileBody(body: String) = {
+ val str = """
+ |circuit Test :
+ | module Test :
+ |""".stripMargin + body.split("\n").mkString(" ", "\n ", "")
+ compile(str)
+ }
+
+ "Register update logic" should "not duplicate common subtrees" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>}
+ |reg r : UInt<8>, clock
+ |when io.a :
+ | r <= io.in
+ |when io.b :
+ | when io.c :
+ | r <= UInt(2)
+ |io.out <= r""".stripMargin
+ )
+ // Checking intermediate state between FlattenRegUpdate and Verilog emission
+ val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get
+ fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""")
+ // Checking the Verilog
+ val verilog = result.getEmittedCircuit.value
+ result shouldNot containLine ("r <= io_in;")
+ verilog shouldNot include ("if (io_a) begin")
+ result should containLine ("r <= _GEN_0;")
+ }
+
+ it should "not let duplicate subtrees on one register affect another" in {
+
+ val result = compileBody(s"""
+ |input clock : Clock
+ |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>}
+
+ |reg r : UInt<8>, clock
+ |reg r2 : UInt<8>, clock
+ |when io.a :
+ | r <= io.in
+ | r2 <= io.in
+ |when io.b :
+ | r2 <= UInt(3)
+ | when io.c :
+ | r <= UInt(2)
+ |io.out <= and(r, r2)""".stripMargin
+ )
+ // Checking intermediate state between FlattenRegUpdate and Verilog emission
+ val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get
+ fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""")
+ fstate should containLine ("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""")
+ // Checking the Verilog
+ val verilog = result.getEmittedCircuit.value
+ result shouldNot containLine ("r <= io_in;")
+ result should containLine ("r <= _GEN_0;")
+ result should containLine ("r2 <= io_in;")
+ verilog should include ("if (io_a) begin") // For r2
+ // 1 time for r2, old versions would have 3 occurences
+ Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be (1)
+ }
+
+}
+