diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/FlattenRegUpdate.scala | 96 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/RegisterUpdateSpec.scala | 100 |
2 files changed, 169 insertions, 27 deletions
diff --git a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala index b272f134..a2399b5a 100644 --- a/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala +++ b/src/main/scala/firrtl/transforms/FlattenRegUpdate.scala @@ -7,7 +7,7 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.Utils._ import firrtl.options.Dependency -import firrtl.InfoExpr.orElse +import firrtl.InfoExpr.{orElse, unwrap} import scala.collection.mutable @@ -58,38 +58,80 @@ object FlattenRegUpdate { * @return [[firrtl.ir.Module Module]] with register updates flattened */ def flattenReg(mod: Module): Module = { - // We want to flatten Mux trees for reg updates into if-trees for - // improved QoR for conditional updates. However, unbounded recursion - // would take exponential time, so don't redundantly flatten the same - // Mux more than a bounded number of times, preserving linear runtime. - // The threshold is empirical but ample. - val flattenThreshold = 4 - val numTimesFlattened = mutable.HashMap[Mux, Int]() - def canFlatten(m: Mux): Boolean = { - val n = numTimesFlattened.getOrElse(m, 0) - numTimesFlattened(m) = n + 1 - n < flattenThreshold - } + // We want to flatten Mux trees for reg updates into if-trees for improved QoR for conditional + // updates. Sometimes the fan-in for a register has a mux structure with repeated + // sub-expressions that are themselves complex mux structures. These repeated structures can + // cause explosions in the size and complexity of the Verilog. In addition, user code that + // follows such structure often will have conditions in the sub-trees that are mutually + // exclusive with the conditions in the muxes closer to the register input. For example: + // + // when a : ; when 1 + // r <= foo + // when b : ; when 2 + // when a : + // r <= bar ; when 3 + // + // After expand whens, when 1 is a common sub-expression that will show up twice in the mux + // structure from when 2: + // + // _GEN_0 = mux(a, foo, r) + // _GEN_1 = mux(a, bar, _GEN_0) + // r <= mux(b, _GEN_1, _GEN_0) + // + // Inlining _GEN_0 into _GEN_1 would result in unreachable lines in the Verilog. While we could + // do some optimizations here, this is *not* really a problem, it's just that Verilog metrics + // are based on the assumption of human-written code and as such it results in unreachable + // lines. Simply not inlining avoids this issue and leaves the optimizations up to synthesis + // tools which do a great job here. + val maxDepth = 4 val regUpdates = mutable.ArrayBuffer.empty[Connect] val netlist = buildNetlist(mod) - def constructRegUpdate(e: Expression): (Info, Expression) = { - import InfoExpr.unwrap - // Only walk netlist for nodes and wires, NOT registers or other state - val (info, expr) = kind(e) match { - case NodeKind | WireKind => unwrap(netlist.getOrElse(e, e)) - case _ => unwrap(e) + // First traversal marks expression that would be inlined multiple times as endpoints + // Note that we could traverse more than maxDepth times - this corresponds to an expression that + // is already a very deeply nested mux + def determineEndpoints(expr: Expression): collection.Set[WrappedExpression] = { + val seen = mutable.HashSet.empty[WrappedExpression] + val endpoint = mutable.HashSet.empty[WrappedExpression] + def rec(depth: Int)(e: Expression): Unit = { + val (_, ex) = kind(e) match { + case NodeKind | WireKind if depth < maxDepth && !seen(e) => + seen += e + unwrap(netlist.getOrElse(e, e)) + case _ => unwrap(e) + } + ex match { + case Mux(_, tval, fval, _) => + rec(depth + 1)(tval) + rec(depth + 1)(fval) + case _ => + // Mark e not ex because original reference is the endpoint, not op or whatever + endpoint += ex + } } - expr match { - case mux: Mux if canFlatten(mux) => - val (tinfo, tvalx) = constructRegUpdate(mux.tval) - val (finfo, fvalx) = constructRegUpdate(mux.fval) - val infox = combineInfos(info, tinfo, finfo) - (infox, mux.copy(tval = tvalx, fval = fvalx)) - // Return the original expression to end flattening - case _ => unwrap(e) + rec(0)(expr) + endpoint + } + + def constructRegUpdate(start: Expression): (Info, Expression) = { + val endpoints = determineEndpoints(start) + def rec(e: Expression): (Info, Expression) = { + val (info, expr) = kind(e) match { + case NodeKind | WireKind if !endpoints(e) => unwrap(netlist.getOrElse(e, e)) + case _ => unwrap(e) + } + expr match { + case Mux(cond, tval, fval, tpe) => + val (tinfo, tvalx) = rec(tval) + val (finfo, fvalx) = rec(fval) + val infox = combineInfos(info, tinfo, finfo) + (infox, Mux(cond, tvalx, fvalx, tpe)) + // Return the original expression to end flattening + case _ => unwrap(e) + } } + rec(start) } def onStmt(stmt: Statement): Statement = stmt.map(onStmt) match { diff --git a/src/test/scala/firrtlTests/RegisterUpdateSpec.scala b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala new file mode 100644 index 00000000..dfef5955 --- /dev/null +++ b/src/test/scala/firrtlTests/RegisterUpdateSpec.scala @@ -0,0 +1,100 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import firrtl.transforms.FlattenRegUpdate +import firrtl.annotations.NoTargetAnnotation +import firrtl.stage.transforms.Compiler +import firrtl.options.Dependency +import firrtl.testutils._ +import FirrtlCheckers._ +import scala.util.matching.Regex + +object RegisterUpdateSpec { + case class CaptureStateAnno(value: CircuitState) extends NoTargetAnnotation + // Capture the CircuitState between FlattenRegUpdate and VerilogEmitter + //Emit captured state as FIRRTL for use in testing + class CaptureCircuitState extends Transform with DependencyAPIMigration { + override def prerequisites = Dependency[FlattenRegUpdate] :: Nil + override def optionalPrerequisiteOf = Dependency[VerilogEmitter] :: Nil + override def invalidates(a: Transform): Boolean = false + def execute(state: CircuitState): CircuitState = { + val emittedAnno = EmittedFirrtlCircuitAnnotation( + EmittedFirrtlCircuit(state.circuit.main, state.circuit.serialize, ".fir")) + val capturedState = state.copy(annotations = emittedAnno +: state.annotations) + state.copy(annotations = CaptureStateAnno(capturedState) +: state.annotations) + } + } +} + +class RegisterUpdateSpec extends FirrtlFlatSpec { + import RegisterUpdateSpec._ + def compile(input: String): CircuitState = { + val compiler = new Compiler(Seq(Dependency[CaptureCircuitState], Dependency[VerilogEmitter])) + compiler.execute(CircuitState(parse(input), EmitCircuitAnnotation(classOf[VerilogEmitter]) :: Nil)) + } + def compileBody(body: String) = { + val str = """ + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + compile(str) + } + + "Register update logic" should "not duplicate common subtrees" in { + val result = compileBody(s""" + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + |reg r : UInt<8>, clock + |when io.a : + | r <= io.in + |when io.b : + | when io.c : + | r <= UInt(2) + |io.out <= r""".stripMargin + ) + // Checking intermediate state between FlattenRegUpdate and Verilog emission + val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get + fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + // Checking the Verilog + val verilog = result.getEmittedCircuit.value + result shouldNot containLine ("r <= io_in;") + verilog shouldNot include ("if (io_a) begin") + result should containLine ("r <= _GEN_0;") + } + + it should "not let duplicate subtrees on one register affect another" in { + + val result = compileBody(s""" + |input clock : Clock + |output io : { flip in : UInt<8>, flip a : UInt<1>, flip b : UInt<1>, flip c : UInt<1>, out : UInt<8>} + + |reg r : UInt<8>, clock + |reg r2 : UInt<8>, clock + |when io.a : + | r <= io.in + | r2 <= io.in + |when io.b : + | r2 <= UInt(3) + | when io.c : + | r <= UInt(2) + |io.out <= and(r, r2)""".stripMargin + ) + // Checking intermediate state between FlattenRegUpdate and Verilog emission + val fstate = result.annotations.collectFirst { case CaptureStateAnno(x) => x }.get + fstate should containLine ("""r <= mux(io_b, mux(io_c, UInt<8>("h2"), _GEN_0), _GEN_0)""") + fstate should containLine ("""r2 <= mux(io_b, UInt<8>("h3"), mux(io_a, io_in, r2))""") + // Checking the Verilog + val verilog = result.getEmittedCircuit.value + result shouldNot containLine ("r <= io_in;") + result should containLine ("r <= _GEN_0;") + result should containLine ("r2 <= io_in;") + verilog should include ("if (io_a) begin") // For r2 + // 1 time for r2, old versions would have 3 occurences + Regex.quote("if (io_a) begin").r.findAllMatchIn(verilog).size should be (1) + } + +} + |
