aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/FlattenRegUpdate.scala96
-rw-r--r--src/test/scala/firrtlTests/RegisterUpdateSpec.scala100
2 files changed, 169 insertions, 27 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
index b272f134..a2399b5a 100644
--- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
+++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala
@@ -7,7 +7,7 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.Utils._
import firrtl.options.Dependency
-import firrtl.InfoExpr.orElse
+import firrtl.InfoExpr.{orElse, unwrap}
import scala.collection.mutable
@@ -58,38 +58,80 @@ object FlattenRegUpdate {
* @return [[firrtl.ir.Module Module]] with register updates flattened
*/
def flattenReg(mod: Module): Module = {
- // We want to flatten Mux trees for reg updates into if-trees for
- // improved QoR for conditional updates. However, unbounded recursion
- // would take exponential time, so don't redundantly flatten the same
- // Mux more than a bounded number of times, preserving linear runtime.
- // The threshold is empirical but ample.
- val flattenThreshold = 4
- val numTimesFlattened = mutable.HashMap[Mux, Int]()
- def canFlatten(m: Mux): Boolean = {
- val n = numTimesFlattened.getOrElse(m, 0)
- numTimesFlattened(m) = n + 1
- n < flattenThreshold
- }
+ // We want to flatten Mux trees for reg updates into if-trees for improved QoR for conditional
+ // updates. Sometimes the fan-in for a register has a mux structure with repeated
+ // sub-expressions that are themselves complex mux structures. These repeated structures can
+ // cause explosions in the size and complexity of the Verilog. In addition, user code that
+ // follows such structure often will have conditions in the sub-trees that are mutually
+ // exclusive with the conditions in the muxes closer to the register input. For example:
+ //
+ // when a : ; when 1
+ // r <= foo
+ // when b : ; when 2
+ // when a :
+ // r <= bar ; when 3
+ //
+ // After expand whens, when 1 is a common sub-expression that will show up twice in the mux
+ // structure from when 2:
+ //
+ // _GEN_0 = mux(a, foo, r)
+ // _GEN_1 = mux(a, bar, _GEN_0)
+ // r <= mux(b, _GEN_1, _GEN_0)
+ //
+ // Inlining _GEN_0 into _GEN_1 would result in unreachable lines in the Verilog. While we could
+ // do some optimizations here, this is *not* really a problem, it's just that Verilog metrics
+ // are based on the assumption of human-written code and as such it results in unreachable
+ // lines. Simply not inlining avoids this issue and leaves the optimizations up to synthesis
+ // tools which do a great job here.
+ val maxDepth = 4
val regUpdates = mutable.ArrayBuffer.empty[Connect]
val netlist = buildNetlist(mod)
- def constructRegUpdate(e: Expression): (Info, Expression) = {
- import InfoExpr.unwrap
- // Only walk netlist for nodes and wires, NOT registers or other state
- val (info, expr) = kind(e) match {
- case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e))
- case _ => unwrap(e)
+ // First traversal marks expression that would be inlined multiple times as endpoints
+ // Note that we could traverse more than maxDepth times - this corresponds to an expression that
+ // is already a very deeply nested mux
+ def determineEndpoints(expr: Expression): collection.Set[WrappedExpression] = {
+ val seen = mutable.HashSet.empty[WrappedExpression]
+ val endpoint = mutable.HashSet.empty[WrappedExpression]
+ def rec(depth: Int)(e: Expression): Unit = {
+ val (_, ex) = kind(e) match {
+ case NodeKind | WireKind if depth < maxDepth && !seen(e) =>
+ seen += e
+ unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
+ }
+ ex match {
+ case Mux(_, tval, fval, _) =>
+ rec(depth + 1)(tval)
+ rec(depth + 1)(fval)
+ case _ =>
+ // Mark e not ex because original reference is the endpoint, not op or whatever
+ endpoint += ex
+ }
}
- expr match {
- case mux: Mux if canFlatten(mux) =>
- val (tinfo, tvalx) = constructRegUpdate(mux.tval)
- val (finfo, fvalx) = constructRegUpdate(mux.fval)
- val infox = combineInfos(info, tinfo, finfo)
- (infox, mux.copy(tval = tvalx, fval = fvalx))
- // Return the original expression to end flattening
- case _ => unwrap(e)
+ rec(0)(expr)
+ endpoint
+ }
+
+ def constructRegUpdate(start: Expression): (Info, Expression) = {
+ val endpoints = determineEndpoints(start)
+ def rec(e: Expression): (Info, Expression) = {
+ val (info, expr) = kind(e) match {
+ case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e))
+ case _ => unwrap(e)
+ }
+ expr match {
+ case Mux(cond, tval, fval, tpe) =>
+ val (tinfo, tvalx) = rec(tval)
+ val (finfo, fvalx) = rec(fval)
+ val infox = combineInfos(info, tinfo, finfo)
+ (infox, Mux(cond, tvalx, fvalx, tpe))
+ // Return the original expression to end flattening
+ case _ => unwrap(e)
+ }
}
+ rec(start)
}
def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match {
diff --git a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala
new file mode 100644
index 00000000..dfef5955
--- /dev/null
+++ b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala
@@ -0,0 +1,100 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.ir._
+import firrtl.transforms.FlattenRegUpdate
+import firrtl.annotations.NoTargetAnnotation
+import firrtl.stage.transforms.Compiler
+import firrtl.options.Dependency
+import firrtl.testutils._
+import FirrtlCheckers._
+import scala.util.matching.Regex
+
+object RegisterUpdateSpec {
+ case class CaptureStateAnno(value: CircuitState) extends NoTargetAnnotation
+ // Capture the CircuitState between FlattenRegUpdate and VerilogEmitter
+ //Emit captured state as FIRRTL for use in testing
+ class CaptureCircuitState extends Transform with DependencyAPIMigration {
+ override def prerequisites = Dependency[FlattenRegUpdate] :: Nil
+ override def optionalPrerequisiteOf = Dependency[VerilogEmitter] :: Nil
+ override def invalidates(a: Transform): Boolean = false
+ def execute(state: CircuitState): CircuitState = {
+ val emittedAnno = EmittedFirrtlCircuitAnnotation(
+ EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir"))
+ val capturedState = state.copy(annotations = emittedAnno +: state.annotations)
+ state.copy(annotations = CaptureStateAnno(capturedState) +: state.annotations)
+ }
+ }
+}
+
+class RegisterUpdateSpec extends FirrtlFlatSpec {
+ import RegisterUpdateSpec._
+ def compile(input: String): CircuitState = {
+ val compiler = new Compiler(Seq(Dependency[CaptureCircuitState], Dependency[VerilogEmitter]))
+ compiler.execute(CircuitState(parse(input), EmitCircuitAnnotation(classOf[VerilogEmitter]) :: Nil))
+ }
+ def compileBody(body: String) = {
+ val str = """
+ |circuit Test :
+ | module Test :
+ |""".stripMargin + body.split("\n").mkString(" ", "\n ", "")
+ compile(str)
+ }
+
+ "Register update logic" should "not duplicate common subtrees" in {
+ val result = compileBody(s"""
+ |input clock : Clock
+ |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>}
+ |reg r : UInt<8>, clock
+ |when io.a :
+ | r <= io.in
+ |when io.b :
+ | when io.c :
+ | r <= UInt(2)
+ |io.out <= r""".stripMargin
+ )
+ // Checking intermediate state between FlattenRegUpdate and Verilog emission
+ val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get
+ fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""")
+ // Checking the Verilog
+ val verilog = result.getEmittedCircuit.value
+ result shouldNot containLine ("r <= io_in;")
+ verilog shouldNot include ("if (io_a) begin")
+ result should containLine ("r <= _GEN_0;")
+ }
+
+ it should "not let duplicate subtrees on one register affect another" in {
+
+ val result = compileBody(s"""
+ |input clock : Clock
+ |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>}
+
+ |reg r : UInt<8>, clock
+ |reg r2 : UInt<8>, clock
+ |when io.a :
+ | r <= io.in
+ | r2 <= io.in
+ |when io.b :
+ | r2 <= UInt(3)
+ | when io.c :
+ | r <= UInt(2)
+ |io.out <= and(r, r2)""".stripMargin
+ )
+ // Checking intermediate state between FlattenRegUpdate and Verilog emission
+ val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get
+ fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""")
+ fstate should containLine ("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""")
+ // Checking the Verilog
+ val verilog = result.getEmittedCircuit.value
+ result shouldNot containLine ("r <= io_in;")
+ result should containLine ("r <= _GEN_0;")
+ result should containLine ("r2 <= io_in;")
+ verilog should include ("if (io_a) begin") // For r2
+ // 1 time for r2, old versions would have 3 occurences
+ Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be (1)
+ }
+
+}
+