summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/chisel3/Aggregate.scala28
-rw-r--r--src/test/scala/chiselTests/ReduceTreeSpec.scala106
2 files changed, 130 insertions, 4 deletions
diff --git a/core/src/main/scala/chisel3/Aggregate.scala b/core/src/main/scala/chisel3/Aggregate.scala
index 06ae36f3..cc5b83d9 100644
--- a/core/src/main/scala/chisel3/Aggregate.scala
+++ b/core/src/main/scala/chisel3/Aggregate.scala
@@ -14,6 +14,7 @@ import chisel3.internal.Builder.pushCommand
import chisel3.internal.firrtl._
import chisel3.internal.sourceinfo._
+import java.lang.Math.{floor, log10, pow}
import scala.collection.mutable
class AliasedAggregateFieldException(message: String) extends ChiselException(message)
@@ -381,11 +382,30 @@ sealed class Vec[T <: Data] private[chisel3] (gen: => T, val length: Int) extend
compileOptions: CompileOptions
): T = {
require(!isEmpty, "Cannot apply reduction on a vec of size 0")
- var curLayer: Seq[T] = this
- while (curLayer.length > 1) {
- curLayer = curLayer.grouped(2).map(x => if (x.length == 1) layerOp(x(0)) else redOp(x(0), x(1))).toSeq
+
+ def recReduce[T](s: Seq[T], op: (T, T) => T, lop: (T) => T): T = {
+
+ val n = s.length
+ n match {
+ case 1 => lop(s(0))
+ case 2 => op(s(0), s(1))
+ case _ =>
+ val m = pow(2, floor(log10(n - 1) / log10(2))).toInt // number of nodes in next level, will be a power of 2
+ val p = 2 * m - n // number of nodes promoted
+
+ val l = s.take(p).map(lop)
+ val r = s
+ .drop(p)
+ .grouped(2)
+ .map {
+ case Seq(a, b) => op(a, b)
+ }
+ .toVector
+ recReduce(l ++ r, op, lop)
+ }
}
- curLayer(0)
+
+ recReduce(this, redOp, layerOp)
}
/** Creates a Vec literal of this type with specified values. this must be a chisel type.
diff --git a/src/test/scala/chiselTests/ReduceTreeSpec.scala b/src/test/scala/chiselTests/ReduceTreeSpec.scala
new file mode 100644
index 00000000..3f078106
--- /dev/null
+++ b/src/test/scala/chiselTests/ReduceTreeSpec.scala
@@ -0,0 +1,106 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package chiselTests
+
+import chisel3._
+import chisel3.util._
+import chisel3.testers.BasicTester
+
+class Arbiter[T <: Data: Manifest](n: Int, private val gen: T) extends Module {
+ val io = IO(new Bundle {
+ val in = Flipped(Vec(n, new DecoupledIO(gen)))
+ val out = new DecoupledIO(gen)
+ })
+
+ def arbitrateTwo(a: DecoupledIO[T], b: DecoupledIO[T]) = {
+
+ val idleA :: idleB :: hasA :: hasB :: Nil = Enum(4)
+ val regData = Reg(gen)
+ val regState = RegInit(idleA)
+ val out = Wire(new DecoupledIO(gen))
+
+ a.ready := regState === idleA
+ b.ready := regState === idleB
+ out.valid := (regState === hasA || regState === hasB)
+
+ switch(regState) {
+ is(idleA) {
+ when(a.valid) {
+ regData := a.bits
+ regState := hasA
+ }.otherwise {
+ regState := idleB
+ }
+ }
+ is(idleB) {
+ when(b.valid) {
+ regData := b.bits
+ regState := hasB
+ }.otherwise {
+ regState := idleA
+ }
+ }
+ is(hasA) {
+ when(out.ready) {
+ regState := idleB
+ }
+ }
+ is(hasB) {
+ when(out.ready) {
+ regState := idleA
+ }
+ }
+ }
+
+ out.bits := regData.asUInt + 1.U
+ out
+ }
+
+ io.out <> io.in.reduceTree(arbitrateTwo)
+}
+
+class ReduceTreeBalancedTester(nodes: Int) extends BasicTester {
+
+ val cnt = RegInit(0.U(8.W))
+ val min = RegInit(99.U(8.W))
+ val max = RegInit(0.U(8.W))
+
+ val dut = Module(new Arbiter(nodes, UInt(16.W)))
+ for (i <- 0 until nodes) {
+ dut.io.in(i).valid := true.B
+ dut.io.in(i).bits := 0.U
+ }
+ dut.io.out.ready := true.B
+
+ when(dut.io.out.valid) {
+ val hops = dut.io.out.bits
+ when(hops < min) {
+ min := hops
+ }
+ when(hops > max) {
+ max := hops
+ }
+ }
+
+ when(!(max === 0.U || min === 99.U)) {
+ assert(max - min <= 1.U)
+ }
+
+ cnt := cnt + 1.U
+ when(cnt === 10.U) {
+ stop()
+ }
+}
+
+class ReduceTreeBalancedSpec extends ChiselPropSpec {
+ property("Tree shall be fair and shall have a maximum difference of one hop for each node") {
+
+ // This test will fail for 5 nodes due to an unbalanced tree.
+ // A fix is on the way.
+ for (n <- 1 to 5) {
+ assertTesterPasses {
+ new ReduceTreeBalancedTester(n)
+ }
+ }
+ }
+}