summaryrefslogtreecommitdiff
path: root/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'src/test')
-rw-r--r--src/test/scala/chiselTests/Vec.scala150
1 files changed, 149 insertions, 1 deletions
diff --git a/src/test/scala/chiselTests/Vec.scala b/src/test/scala/chiselTests/Vec.scala
index f3160c2e..97aea909 100644
--- a/src/test/scala/chiselTests/Vec.scala
+++ b/src/test/scala/chiselTests/Vec.scala
@@ -9,6 +9,7 @@ import chisel3.stage.ChiselStage
import chisel3.testers.BasicTester
import chisel3.util._
import org.scalacheck.Shrink
+import scala.annotation.tailrec
class LitTesterMod(vecSize: Int) extends Module {
val io = IO(new Bundle {
@@ -114,6 +115,121 @@ class FillTester(n: Int, value: Int) extends BasicTester {
stop()
}
+object VecMultiDimTester {
+
+ @tailrec
+ private def assert2DIsCorrect(n: Int, arr: Vec[Vec[UInt]], compArr: Seq[Seq[Int]]): Unit = {
+ val compareRow = arr(n) zip compArr(n)
+ compareRow.foreach (x => assert(x._1 === x._2.U))
+ if (n != 0) assert2DIsCorrect(n-1, arr, compArr)
+ }
+
+ @tailrec
+ private def assert3DIsCorrect(n: Int, m: Int, arr: Vec[Vec[Vec[UInt]]], compArr: Seq[Seq[Seq[Int]]]): Unit = {
+ assert2DIsCorrect(m-1, arr(n), compArr(n))
+ if (n != 0) assert3DIsCorrect(n-1, m, arr, compArr)
+ }
+
+ class TabulateTester2D(n: Int, m: Int) extends BasicTester {
+ def gen(x: Int, y: Int): UInt = (x+y).asUInt
+ def genCompVec(x: Int, y:Int): Int = x+y
+ val vec = VecInit.tabulate(n, m){ gen }
+ val compArr = Seq.tabulate(n,m){ genCompVec }
+
+ assert2DIsCorrect(n-1, vec, compArr)
+ stop()
+ }
+
+ class TabulateTester3D(n: Int, m: Int, p: Int) extends BasicTester {
+ def gen(x: Int, y: Int, z: Int): UInt = (x+y+z).asUInt
+ def genCompVec(x: Int, y:Int, z: Int): Int = x+y+z
+ val vec = VecInit.tabulate(n, m, p){ gen }
+ val compArr = Seq.tabulate(n, m, p){ genCompVec }
+
+ assert3DIsCorrect(n-1, m, vec, compArr)
+ stop()
+ }
+
+ class Fill2DTester(n: Int, m: Int, value: Int) extends BasicTester {
+ val u = VecInit.fill(n,m)(value.U)
+ val compareArr = Seq.fill(n,m)(value)
+
+ assert2DIsCorrect(n-1, u, compareArr)
+ stop()
+ }
+
+ class Fill3DTester(n: Int, m: Int, p: Int, value: Int) extends BasicTester {
+ val u = VecInit.fill(n,m,p)(value.U)
+ val compareArr = Seq.fill(n,m,p)(value)
+
+ assert3DIsCorrect(n-1, m, u, compareArr)
+ stop()
+ }
+
+ class BidirectionalTester2DFill(n: Int, m: Int) extends BasicTester {
+ val mod = Module(new PassthroughModule)
+ val vec2D = VecInit.fill(n, m)(mod.io)
+ for {
+ vec1D <- vec2D
+ module <- vec1D
+ } yield {
+ module <> Module(new PassthroughModuleTester).io
+ }
+ stop()
+ }
+
+ class BidirectionalTester3DFill(n: Int, m: Int, p: Int) extends BasicTester {
+ val mod = Module(new PassthroughModule)
+ val vec3D = VecInit.fill(n, m, p)(mod.io)
+
+ for {
+ vec2D <- vec3D
+ vec1D <- vec2D
+ module <- vec1D
+ } yield {
+ module <> (Module(new PassthroughModuleTester).io)
+ }
+ stop()
+ }
+
+ class TabulateModuleTester(value: Int) extends Module {
+ val io = IO(Flipped(new PassthroughModuleIO))
+ // This drives the input of a PassthroughModule
+ io.in := value.U
+ }
+
+ class BidirectionalTester2DTabulate(n: Int, m: Int) extends BasicTester {
+ val vec2D = VecInit.tabulate(n, m) { (x, y) => Module(new TabulateModuleTester(x + y + 1)).io}
+
+ for {
+ x <- 0 until n
+ y <- 0 until m
+ } yield {
+ val value = x + y + 1
+ val receiveMod = Module(new PassthroughModule).io
+ vec2D(x)(y) <> receiveMod
+ assert(receiveMod.out === value.U)
+ }
+ stop()
+ }
+
+ class BidirectionalTester3DTabulate(n: Int, m: Int, p: Int) extends BasicTester {
+ val vec3D = VecInit.tabulate(n, m, p) { (x, y, z) => Module(new TabulateModuleTester(x + y + z + 1)).io }
+
+ for {
+ x <- 0 until n
+ y <- 0 until m
+ z <- 0 until p
+ } yield {
+ val value = x + y + z + 1
+ val receiveMod = Module(new PassthroughModule).io
+ vec3D(x)(y)(z) <> receiveMod
+ assert(receiveMod.out === value.U)
+ }
+ stop()
+ }
+}
+
class IterateTester(start: Int, len: Int)(f: UInt => UInt) extends BasicTester {
val controlVec = VecInit(Seq.iterate(start.U, len)(f))
val testVec = VecInit.iterate(start.U, len)(f)
@@ -178,7 +294,7 @@ class PassthroughModuleTester extends Module {
}
class ModuleIODynamicIndexTester(n: Int) extends BasicTester {
- val duts = VecInit(Seq.fill(n)(Module(new PassthroughModule).io))
+ val duts = VecInit.fill(n)(Module(new PassthroughModule).io)
val tester = Module(new PassthroughModuleTester)
val (cycle, done) = Counter(true.B, n)
@@ -239,10 +355,42 @@ class VecSpec extends ChiselPropSpec with Utils {
forAll(smallPosInts) { (n: Int) => assertTesterPasses{ new TabulateTester(n) } }
}
+ property("VecInit should tabulate 2D vec correctly") {
+ forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses { new VecMultiDimTester.TabulateTester2D(n, m) } }
+ }
+
+ property("VecInit should tabulate 3D vec correctly") {
+ forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.TabulateTester3D(n, m, p) } }
+ }
+
property("VecInit should fill correctly") {
forAll(smallPosInts, Gen.choose(0, 50)) { (n: Int, value: Int) => assertTesterPasses{ new FillTester(n, value) } }
}
+ property("VecInit should fill 2D vec correctly") {
+ forAll(smallPosInts, smallPosInts, Gen.choose(0, 50)) { (n: Int, m: Int, value: Int) => assertTesterPasses{ new VecMultiDimTester.Fill2DTester(n, m, value) } }
+ }
+
+ property("VecInit should fill 3D vec correctly") {
+ forAll(smallPosInts, smallPosInts, smallPosInts, Gen.choose(0, 50)) { (n: Int, m: Int, p: Int, value: Int) => assertTesterPasses{ new VecMultiDimTester.Fill3DTester(n, m, p, value) } }
+ }
+
+ property("VecInit should support 2D fill bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester2DFill(n, m) }}
+ }
+
+ property("VecInit should support 3D fill bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester3DFill(n, m, p) }}
+ }
+
+ property("VecInit should support 2D tabulate bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester2DTabulate(n, m) }}
+ }
+
+ property("VecInit should support 3D tabulate bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester3DTabulate(n, m, p) }}
+ }
+
property("VecInit should iterate correctly") {
forAll(Gen.choose(1, 10), smallPosInts) { (start: Int, len: Int) => assertTesterPasses{ new IterateTester(start, len)(x => x + 50.U)}}
}