diff options
| author | mergify[bot] | 2022-04-20 21:10:07 +0000 |
|---|---|---|
| committer | GitHub | 2022-04-20 21:10:07 +0000 |
| commit | a16a8a52a3b2d72d80a27434217aaeba7be2d3a8 (patch) | |
| tree | db51f76087a33c3ed2b72f449fe261fe7c3e7586 | |
| parent | 70da39e140e96a9302a94864f077529e02596ef5 (diff) | |
Generate a balanced tree with reduceTree (#2318) (#2499)
The difference in logic depth for various paths now has a maximum of 1.
Also make treeReduce order the same for 2.12 and 2.13
.grouped(_) returns an Iterator
.toSeq on an Iterator returns a Stream in 2.12 and a List in 2.13
This can lead to changes in order when bumping from 2.12 to 2.13 that
can be avoided by simply using an eager collection explicitly.
Co-authored-by: Jack Koenig <koenig@sifive.com>
(cherry picked from commit 6975f77f3325dec46c613552eac663c29011a67c)
Co-authored-by: Martin Schoeberl <martin@jopdesign.com>
| -rw-r--r-- | core/src/main/scala/chisel3/Aggregate.scala | 28 | ||||
| -rw-r--r-- | src/test/scala/chiselTests/ReduceTreeSpec.scala | 106 |
2 files changed, 130 insertions, 4 deletions
diff --git a/core/src/main/scala/chisel3/Aggregate.scala b/core/src/main/scala/chisel3/Aggregate.scala index 06ae36f3..cc5b83d9 100644 --- a/core/src/main/scala/chisel3/Aggregate.scala +++ b/core/src/main/scala/chisel3/Aggregate.scala @@ -14,6 +14,7 @@ import chisel3.internal.Builder.pushCommand import chisel3.internal.firrtl._ import chisel3.internal.sourceinfo._ +import java.lang.Math.{floor, log10, pow} import scala.collection.mutable class AliasedAggregateFieldException(message: String) extends ChiselException(message) @@ -381,11 +382,30 @@ sealed class Vec[T <: Data] private[chisel3] (gen: => T, val length: Int) extend compileOptions: CompileOptions ): T = { require(!isEmpty, "Cannot apply reduction on a vec of size 0") - var curLayer: Seq[T] = this - while (curLayer.length > 1) { - curLayer = curLayer.grouped(2).map(x => if (x.length == 1) layerOp(x(0)) else redOp(x(0), x(1))).toSeq + + def recReduce[T](s: Seq[T], op: (T, T) => T, lop: (T) => T): T = { + + val n = s.length + n match { + case 1 => lop(s(0)) + case 2 => op(s(0), s(1)) + case _ => + val m = pow(2, floor(log10(n - 1) / log10(2))).toInt // number of nodes in next level, will be a power of 2 + val p = 2 * m - n // number of nodes promoted + + val l = s.take(p).map(lop) + val r = s + .drop(p) + .grouped(2) + .map { + case Seq(a, b) => op(a, b) + } + .toVector + recReduce(l ++ r, op, lop) + } } - curLayer(0) + + recReduce(this, redOp, layerOp) } /** Creates a Vec literal of this type with specified values. this must be a chisel type. diff --git a/src/test/scala/chiselTests/ReduceTreeSpec.scala b/src/test/scala/chiselTests/ReduceTreeSpec.scala new file mode 100644 index 00000000..3f078106 --- /dev/null +++ b/src/test/scala/chiselTests/ReduceTreeSpec.scala @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: Apache-2.0 + +package chiselTests + +import chisel3._ +import chisel3.util._ +import chisel3.testers.BasicTester + +class Arbiter[T <: Data: Manifest](n: Int, private val gen: T) extends Module { + val io = IO(new Bundle { + val in = Flipped(Vec(n, new DecoupledIO(gen))) + val out = new DecoupledIO(gen) + }) + + def arbitrateTwo(a: DecoupledIO[T], b: DecoupledIO[T]) = { + + val idleA :: idleB :: hasA :: hasB :: Nil = Enum(4) + val regData = Reg(gen) + val regState = RegInit(idleA) + val out = Wire(new DecoupledIO(gen)) + + a.ready := regState === idleA + b.ready := regState === idleB + out.valid := (regState === hasA || regState === hasB) + + switch(regState) { + is(idleA) { + when(a.valid) { + regData := a.bits + regState := hasA + }.otherwise { + regState := idleB + } + } + is(idleB) { + when(b.valid) { + regData := b.bits + regState := hasB + }.otherwise { + regState := idleA + } + } + is(hasA) { + when(out.ready) { + regState := idleB + } + } + is(hasB) { + when(out.ready) { + regState := idleA + } + } + } + + out.bits := regData.asUInt + 1.U + out + } + + io.out <> io.in.reduceTree(arbitrateTwo) +} + +class ReduceTreeBalancedTester(nodes: Int) extends BasicTester { + + val cnt = RegInit(0.U(8.W)) + val min = RegInit(99.U(8.W)) + val max = RegInit(0.U(8.W)) + + val dut = Module(new Arbiter(nodes, UInt(16.W))) + for (i <- 0 until nodes) { + dut.io.in(i).valid := true.B + dut.io.in(i).bits := 0.U + } + dut.io.out.ready := true.B + + when(dut.io.out.valid) { + val hops = dut.io.out.bits + when(hops < min) { + min := hops + } + when(hops > max) { + max := hops + } + } + + when(!(max === 0.U || min === 99.U)) { + assert(max - min <= 1.U) + } + + cnt := cnt + 1.U + when(cnt === 10.U) { + stop() + } +} + +class ReduceTreeBalancedSpec extends ChiselPropSpec { + property("Tree shall be fair and shall have a maximum difference of one hop for each node") { + + // This test will fail for 5 nodes due to an unbalanced tree. + // A fix is on the way. + for (n <- 1 to 5) { + assertTesterPasses { + new ReduceTreeBalancedTester(n) + } + } + } +} |
