summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2019-09-16 15:54:31 -0700
committerJim Lawson2019-09-16 15:54:31 -0700
commite939181aebb5a0131562609a5782e1f1df88699d (patch)
treecb01c1e3438ae123061f2147600fd961ce509f1e
parenta83647a91326bcd2ac0e3c664056b5ced212cc67 (diff)
Da steve101 tree reduce (#485)
* Add a tree reduce function to Vec * Change function names of reduce operation function in Vec * Change reference to single layer operation in Vec.reduce * Commint name change for pair macro * Remove pair, call not necessary and can just be used from grouped(2) and map * Changed to reduceTree, added default identity function for single reduce. * Change style of Vec.reduceTree and tests to chisel3 and canonical Scala style * Cleanup Vec initialization, implicitCompileOptions
-rw-r--r--chiselFrontend/src/main/scala/chisel3/Aggregate.scala31
-rw-r--r--coreMacros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala6
-rw-r--r--src/test/scala/chiselTests/AdderTree.scala33
3 files changed, 70 insertions, 0 deletions
diff --git a/chiselFrontend/src/main/scala/chisel3/Aggregate.scala b/chiselFrontend/src/main/scala/chisel3/Aggregate.scala
index dfba1caf..4640cb0f 100644
--- a/chiselFrontend/src/main/scala/chisel3/Aggregate.scala
+++ b/chiselFrontend/src/main/scala/chisel3/Aggregate.scala
@@ -263,6 +263,37 @@ sealed class Vec[T <: Data] private[chisel3] (gen: => T, val length: Int)
// scalastyle:on if.brace
PString("Vec(") + Printables(elts) + PString(")")
}
+
+ /** A reduce operation in a tree like structure instead of sequentially
+ * @example An adder tree
+ * {{{
+ * val sumOut = inputNums.reduceTree((a: T, b: T) => (a + b))
+ * }}}
+ */
+ def reduceTree(redOp: (T, T) => T): T = macro VecTransform.reduceTreeDefault
+
+ /** A reduce operation in a tree like structure instead of sequentially
+ * @example A pipelined adder tree
+ * {{{
+ * val sumOut = inputNums.reduceTree(
+ * (a: T, b: T) => RegNext(a + b),
+ * (a: T) => RegNext(a)
+ * )
+ * }}}
+ */
+ def reduceTree(redOp: (T, T) => T, layerOp: (T) => T): T = macro VecTransform.reduceTree
+
+ def do_reduceTree(redOp: (T, T) => T, layerOp: (T) => T = (x: T) => x)
+ (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) : T = {
+ require(!isEmpty, "Cannot apply reduction on a vec of size 0")
+ var curLayer = this
+ while (curLayer.length > 1) {
+ curLayer = VecInit(curLayer.grouped(2).map( x =>
+ if (x.length == 1) layerOp(x(0)) else redOp(x(0), x(1))
+ ).toSeq)
+ }
+ curLayer(0)
+ }
}
object VecInit extends SourceInfoDoc {
diff --git a/coreMacros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala b/coreMacros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala
index e69c569a..d38396f4 100644
--- a/coreMacros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala
+++ b/coreMacros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala
@@ -84,6 +84,12 @@ class VecTransform(val c: Context) extends SourceInfoTransformMacro {
def contains(x: c.Tree)(ev: c.Tree): c.Tree = {
q"$thisObj.do_contains($x)($implicitSourceInfo, $ev, $implicitCompileOptions)"
}
+ def reduceTree(redOp: c.Tree, layerOp: c.Tree): c.Tree = {
+ q"$thisObj.do_reduceTree($redOp,$layerOp)($implicitSourceInfo, $implicitCompileOptions)"
+ }
+ def reduceTreeDefault(redOp: c.Tree ): c.Tree = {
+ q"$thisObj.do_reduceTree($redOp)($implicitSourceInfo, $implicitCompileOptions)"
+ }
}
/** "Automatic" source information transform / insertion macros, which generate the function name
diff --git a/src/test/scala/chiselTests/AdderTree.scala b/src/test/scala/chiselTests/AdderTree.scala
new file mode 100644
index 00000000..4e7ad1a6
--- /dev/null
+++ b/src/test/scala/chiselTests/AdderTree.scala
@@ -0,0 +1,33 @@
+package chiselTests
+
+import chisel3._
+import chisel3.testers.BasicTester
+
+class AdderTree[T <: Bits with Num[T]](genType: T, vecSize: Int) extends Module {
+ val io = IO(new Bundle {
+ val numIn = Input(Vec(vecSize, genType))
+ val numOut = Output(genType)
+ })
+ io.numOut := io.numIn.reduceTree((a : T, b : T) => (a + b))
+}
+
+class AdderTreeTester(bitWidth: Int, numsToAdd: List[Int]) extends BasicTester {
+ val genType = UInt(bitWidth.W)
+ val dut = Module(new AdderTree(genType, numsToAdd.size))
+ dut.io.numIn := VecInit(numsToAdd.map(x => x.asUInt(bitWidth.W)))
+ val sumCorrect = dut.io.numOut === (numsToAdd.reduce(_+_) % (1 << bitWidth)).asUInt(bitWidth.W)
+ assert(sumCorrect)
+ stop()
+}
+
+class AdderTreeSpec extends ChiselPropSpec {
+ property("All numbers should be added correctly by an Adder Tree") {
+ forAll(safeUIntN(20)) {
+ case (w: Int, v: List[Int]) => {
+ whenever(v.size > 0 && w > 0) {
+ assertTesterPasses { new AdderTreeTester(w, v.map(x => math.abs(x) % ( 1 << w )).toList) }
+ }
+ }
+ }
+ }
+}