summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoranniej-sifive2021-08-23 14:37:09 -0700
committerGitHub2021-08-23 14:37:09 -0700
commitf50ce19406e45982390162777fb62c8563c962c7 (patch)
tree010f8ecf120509d112b995a0a2866a40f6b12d98
parenta3d51e4c91059362b20296eaa00f06f96ec7a4e1 (diff)
Add multiple dimensions to VecInit fill and iterate (#2065)
Co-authored-by: Jack Koenig <koenig@sifive.com>
-rw-r--r--core/src/main/scala/chisel3/Aggregate.scala129
-rw-r--r--docs/src/cookbooks/cookbook.md46
-rw-r--r--macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala15
-rw-r--r--src/test/scala/chiselTests/Vec.scala150
4 files changed, 327 insertions, 13 deletions
diff --git a/core/src/main/scala/chisel3/Aggregate.scala b/core/src/main/scala/chisel3/Aggregate.scala
index 58bc5ccb..17e46cb3 100644
--- a/core/src/main/scala/chisel3/Aggregate.scala
+++ b/core/src/main/scala/chisel3/Aggregate.scala
@@ -488,6 +488,21 @@ sealed class Vec[T <: Data] private[chisel3] (gen: => T, val length: Int)
}
object VecInit extends SourceInfoDoc {
+
+ /** Gets the correct connect operation (directed hardware assign or bulk connect) for element in Vec.
+ */
+ private def getConnectOpFromDirectionality[T <: Data](proto: T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): (T, T) => Unit = proto.direction match {
+ case ActualDirection.Input | ActualDirection.Output | ActualDirection.Unspecified =>
+ // When internal wires are involved, driver / sink must be specified explicitly, otherwise
+ // the system is unable to infer which is driver / sink
+ (x, y) => x := y
+ case ActualDirection.Bidirectional(_) =>
+ // For bidirectional, must issue a bulk connect so subelements are resolved correctly.
+ // Bulk connecting two wires may not succeed because Chisel frontend does not infer
+ // directions.
+ (x, y) => x <> y
+ }
+
/** Creates a new [[Vec]] composed of elements of the input Seq of [[Data]]
* nodes.
*
@@ -513,18 +528,10 @@ object VecInit extends SourceInfoDoc {
elts.foreach(requireIsHardware(_, "vec element"))
val vec = Wire(Vec(elts.length, cloneSupertype(elts, "Vec")))
-
- // TODO: try to remove the logic for this mess
- elts.head.direction match {
- case ActualDirection.Input | ActualDirection.Output | ActualDirection.Unspecified =>
- // When internal wires are involved, driver / sink must be specified explicitly, otherwise
- // the system is unable to infer which is driver / sink
- (vec zip elts).foreach(x => x._1 := x._2)
- case ActualDirection.Bidirectional(_) =>
- // For bidirectional, must issue a bulk connect so subelements are resolved correctly.
- // Bulk connecting two wires may not succeed because Chisel frontend does not infer
- // directions.
- (vec zip elts).foreach(x => x._1 <> x._2)
+ val op = getConnectOpFromDirectionality(vec.head)
+
+ (vec zip elts).foreach{ x =>
+ op(x._1, x._2)
}
vec
}
@@ -557,6 +564,73 @@ object VecInit extends SourceInfoDoc {
def do_tabulate[T <: Data](n: Int)(gen: (Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[T] =
apply((0 until n).map(i => gen(i)))
+ /** Creates a new 2D [[Vec]] of length `n by m` composed of the results of the given
+ * function applied over a range of integer values starting from 0.
+ *
+ * @param n number of 1D vectors inside outer vector
+ * @param m number of elements in each 1D vector (the function is applied from
+ * 0 to `n-1`)
+ * @param gen function that takes in an Int (the index) and returns a
+ * [[Data]] that becomes the output element
+ */
+ def tabulate[T <: Data](n: Int, m: Int)(gen: (Int, Int) => T): Vec[Vec[T]] = macro VecTransform.tabulate2D
+
+ /** @group SourceInfoTransformMacro */
+ def do_tabulate[T <: Data](n: Int, m: Int)(gen: (Int, Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[T]] = {
+ // TODO make this lazy (requires LazyList and cross compilation, beyond the scope of this PR)
+ val elts = Seq.tabulate(n, m)(gen)
+ val flatElts = elts.flatten
+
+ require(flatElts.nonEmpty, "Vec hardware values are not allowed to be empty")
+ flatElts.foreach(requireIsHardware(_, "vec element"))
+
+ val tpe = cloneSupertype(flatElts, "Vec.tabulate")
+ val myVec = Wire(Vec(n, Vec(m, tpe)))
+ val op = getConnectOpFromDirectionality(myVec.head.head)
+ for (
+ (xs1D, ys1D) <- myVec zip elts;
+ (x, y) <- xs1D zip ys1D
+ ) {
+ op(x, y)
+ }
+ myVec
+ }
+
+ /** Creates a new 3D [[Vec]] of length `n by m by p` composed of the results of the given
+ * function applied over a range of integer values starting from 0.
+ *
+ * @param n number of 2D vectors inside outer vector
+ * @param m number of 1D vectors in each 2D vector
+ * @param p number of elements in each 1D vector
+ * @param gen function that takes in an Int (the index) and returns a
+ * [[Data]] that becomes the output element
+ */
+ def tabulate[T <: Data](n: Int, m: Int, p: Int)(gen: (Int, Int, Int) => T): Vec[Vec[Vec[T]]] = macro VecTransform.tabulate3D
+
+ /** @group SourceInfoTransformMacro */
+ def do_tabulate[T <: Data](n: Int, m: Int, p: Int)(gen: (Int, Int, Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[Vec[T]]] = {
+ // TODO make this lazy (requires LazyList and cross compilation, beyond the scope of this PR)
+ val elts = Seq.tabulate(n, m, p)(gen)
+ val flatElts = elts.flatten.flatten
+
+ require(flatElts.nonEmpty, "Vec hardware values are not allowed to be empty")
+ flatElts.foreach(requireIsHardware(_, "vec element"))
+
+ val tpe = cloneSupertype(flatElts, "Vec.tabulate")
+ val myVec = Wire(Vec(n, Vec(m, Vec(p, tpe))))
+ val op = getConnectOpFromDirectionality(myVec.head.head.head)
+
+ for (
+ (xs2D, ys2D) <- myVec zip elts;
+ (xs1D, ys1D) <- xs2D zip ys2D;
+ (x, y) <- xs1D zip ys1D
+ ) {
+ op(x, y)
+ }
+
+ myVec
+ }
+
/** Creates a new [[Vec]] of length `n` composed of the result of the given
* function applied to an element of data type T.
*
@@ -570,6 +644,37 @@ object VecInit extends SourceInfoDoc {
def do_fill[T <: Data](n: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[T] =
apply(Seq.fill(n)(gen))
+ /** Creates a new 2D [[Vec]] of length `n by m` composed of the result of the given
+ * function applied to an element of data type T.
+ *
+ * @param n number of inner vectors (rows) in the outer vector
+ * @param m number of elements in each inner vector (column)
+ * @param gen function that takes in an element T and returns an output
+ * element of the same type
+ */
+ def fill[T <: Data](n: Int, m: Int)(gen: => T): Vec[Vec[T]] = macro VecTransform.fill2D
+
+ /** @group SourceInfoTransformMacro */
+ def do_fill[T <: Data](n: Int, m: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[T]] = {
+ do_tabulate(n, m)((_, _) => gen)
+ }
+
+ /** Creates a new 3D [[Vec]] of length `n by m by p` composed of the result of the given
+ * function applied to an element of data type T.
+ *
+ * @param n number of 2D vectors inside outer vector
+ * @param m number of 1D vectors in each 2D vector
+ * @param p number of elements in each 1D vector
+ * @param gen function that takes in an element T and returns an output
+ * element of the same type
+ */
+ def fill[T <: Data](n: Int, m: Int, p: Int)(gen: => T): Vec[Vec[Vec[T]]] = macro VecTransform.fill3D
+
+ /** @group SourceInfoTransformMacro */
+ def do_fill[T <: Data](n: Int, m: Int, p: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[Vec[T]]] = {
+ do_tabulate(n, m, p)((_, _, _) => gen)
+ }
+
/** Creates a new [[Vec]] of length `n` composed of the result of the given
* function applied to an element of data type T.
*
diff --git a/docs/src/cookbooks/cookbook.md b/docs/src/cookbooks/cookbook.md
index cff7a5b2..ce49b668 100644
--- a/docs/src/cookbooks/cookbook.md
+++ b/docs/src/cookbooks/cookbook.md
@@ -17,6 +17,7 @@ Please note that these examples make use of [Chisel's scala-style printing](../e
* [How do I create a UInt from a Vec of Bool?](#how-do-i-create-a-uint-from-a-vec-of-bool)
* [How do I connect a subset of Bundle fields?](#how-do-i-connect-a-subset-of-bundle-fields)
* Vectors and Registers
+ * [Can I make a 2D or 3D Vector?](#can-i-make-a-2D-or-3D-Vector)
* [How do I create a Vector of Registers?](#how-do-i-create-a-vector-of-registers)
* [How do I create a Reg of type Vec?](#how-do-i-create-a-reg-of-type-vec)
* [How do I create a finite state machine?](#how-do-i-create-a-finite-state-machine-fsm)
@@ -157,6 +158,51 @@ See the [DataView cookbook](dataview#how-do-i-connect-a-subset-of-bundle-fields)
## Vectors and Registers
+### Can I make a 2D or 3D Vector?
+
+Yes. Using `VecInit` you can make Vectors that hold Vectors of Chisel types. Methods `fill` and `tabulate` make these multi-dimensional Vectors.
+
+```scala mdoc:silent:reset
+import chisel3._
+
+class MyBundle extends Bundle {
+ val foo = UInt(4.W)
+ val bar = UInt(4.W)
+}
+
+class Foo extends Module {
+ //2D Fill
+ val twoDVec = VecInit.fill(2, 3)(5.U)
+ //3D Fill
+ val myBundle = Wire(new MyBundle)
+ myBundle.foo := 0xc.U
+ myBundle.bar := 0x3.U
+ val threeDVec = VecInit.fill(1, 2, 3)(myBundle)
+ assert(threeDVec(0)(0)(0).foo === 0xc.U && threeDVec(0)(0)(0).bar === 0x3.U)
+
+ //2D Tabulate
+ val indexTiedVec = VecInit.tabulate(2, 2){ (x, y) => (x + y).U }
+ assert(indexTiedVec(0)(0) === 0.U)
+ assert(indexTiedVec(0)(1) === 1.U)
+ assert(indexTiedVec(1)(0) === 1.U)
+ assert(indexTiedVec(1)(1) === 2.U)
+ //3D Tabulate
+ val indexTiedVec3D = VecInit.tabulate(2, 3, 4){ (x, y, z) => (x + y * z).U }
+ assert(indexTiedVec3D(0)(0)(0) === 0.U)
+ assert(indexTiedVec3D(1)(1)(1) === 2.U)
+ assert(indexTiedVec3D(1)(1)(2) === 3.U)
+ assert(indexTiedVec3D(1)(1)(3) === 4.U)
+ assert(indexTiedVec3D(1)(2)(3) === 7.U)
+}
+```
+```scala mdoc:invisible
+// Hidden but will make sure this actually compiles
+import chisel3.stage.ChiselStage
+
+ChiselStage.emitVerilog(new Foo)
+```
+
+
### How do I create a Vector of Registers?
**Rule! Use Reg of Vec not Vec of Reg!**
diff --git a/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala b/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala
index d7c301e9..6121bc1e 100644
--- a/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala
+++ b/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala
@@ -81,9 +81,24 @@ class VecTransform(val c: Context) extends SourceInfoTransformMacro {
def tabulate(n: c.Tree)(gen: c.Tree): c.Tree = {
q"$thisObj.do_tabulate($n)($gen)($implicitSourceInfo, $implicitCompileOptions)"
}
+ def tabulate2D(n: c.Tree, m: c.Tree)(gen: c.Tree): c.Tree = {
+ q"$thisObj.do_tabulate($n,$m)($gen)($implicitSourceInfo, $implicitCompileOptions)"
+ }
+ def tabulate3D(n: c.Tree, m: c.Tree, p: c.Tree)(gen: c.Tree): c.Tree = {
+ q"$thisObj.do_tabulate($n,$m,$p)($gen)($implicitSourceInfo, $implicitCompileOptions)"
+ }
def fill(n: c.Tree)(gen: c.Tree): c.Tree = {
q"$thisObj.do_fill($n)($gen)($implicitSourceInfo, $implicitCompileOptions)"
}
+ def fill2D(n: c.Tree, m: c.Tree)(gen: c.Tree): c.Tree = {
+ q"$thisObj.do_fill($n,$m)($gen)($implicitSourceInfo, $implicitCompileOptions)"
+ }
+ def fill3D(n: c.Tree, m: c.Tree, p: c.Tree)(gen: c.Tree): c.Tree = {
+ q"$thisObj.do_fill($n,$m,$p)($gen)($implicitSourceInfo, $implicitCompileOptions)"
+ }
+ def fill4D(n: c.Tree, m: c.Tree, p: c.Tree, q: c.Tree)(gen: c.Tree): c.Tree = {
+ q"$thisObj.do_fill($n,$m,$p,$q)($gen)($implicitSourceInfo, $implicitCompileOptions)"
+ }
def iterate(start: c.Tree, len: c.Tree)(f: c.Tree): c.Tree = {
q"$thisObj.do_iterate($start,$len)($f)($implicitSourceInfo, $implicitCompileOptions)"
}
diff --git a/src/test/scala/chiselTests/Vec.scala b/src/test/scala/chiselTests/Vec.scala
index f3160c2e..97aea909 100644
--- a/src/test/scala/chiselTests/Vec.scala
+++ b/src/test/scala/chiselTests/Vec.scala
@@ -9,6 +9,7 @@ import chisel3.stage.ChiselStage
import chisel3.testers.BasicTester
import chisel3.util._
import org.scalacheck.Shrink
+import scala.annotation.tailrec
class LitTesterMod(vecSize: Int) extends Module {
val io = IO(new Bundle {
@@ -114,6 +115,121 @@ class FillTester(n: Int, value: Int) extends BasicTester {
stop()
}
+object VecMultiDimTester {
+
+ @tailrec
+ private def assert2DIsCorrect(n: Int, arr: Vec[Vec[UInt]], compArr: Seq[Seq[Int]]): Unit = {
+ val compareRow = arr(n) zip compArr(n)
+ compareRow.foreach (x => assert(x._1 === x._2.U))
+ if (n != 0) assert2DIsCorrect(n-1, arr, compArr)
+ }
+
+ @tailrec
+ private def assert3DIsCorrect(n: Int, m: Int, arr: Vec[Vec[Vec[UInt]]], compArr: Seq[Seq[Seq[Int]]]): Unit = {
+ assert2DIsCorrect(m-1, arr(n), compArr(n))
+ if (n != 0) assert3DIsCorrect(n-1, m, arr, compArr)
+ }
+
+ class TabulateTester2D(n: Int, m: Int) extends BasicTester {
+ def gen(x: Int, y: Int): UInt = (x+y).asUInt
+ def genCompVec(x: Int, y:Int): Int = x+y
+ val vec = VecInit.tabulate(n, m){ gen }
+ val compArr = Seq.tabulate(n,m){ genCompVec }
+
+ assert2DIsCorrect(n-1, vec, compArr)
+ stop()
+ }
+
+ class TabulateTester3D(n: Int, m: Int, p: Int) extends BasicTester {
+ def gen(x: Int, y: Int, z: Int): UInt = (x+y+z).asUInt
+ def genCompVec(x: Int, y:Int, z: Int): Int = x+y+z
+ val vec = VecInit.tabulate(n, m, p){ gen }
+ val compArr = Seq.tabulate(n, m, p){ genCompVec }
+
+ assert3DIsCorrect(n-1, m, vec, compArr)
+ stop()
+ }
+
+ class Fill2DTester(n: Int, m: Int, value: Int) extends BasicTester {
+ val u = VecInit.fill(n,m)(value.U)
+ val compareArr = Seq.fill(n,m)(value)
+
+ assert2DIsCorrect(n-1, u, compareArr)
+ stop()
+ }
+
+ class Fill3DTester(n: Int, m: Int, p: Int, value: Int) extends BasicTester {
+ val u = VecInit.fill(n,m,p)(value.U)
+ val compareArr = Seq.fill(n,m,p)(value)
+
+ assert3DIsCorrect(n-1, m, u, compareArr)
+ stop()
+ }
+
+ class BidirectionalTester2DFill(n: Int, m: Int) extends BasicTester {
+ val mod = Module(new PassthroughModule)
+ val vec2D = VecInit.fill(n, m)(mod.io)
+ for {
+ vec1D <- vec2D
+ module <- vec1D
+ } yield {
+ module <> Module(new PassthroughModuleTester).io
+ }
+ stop()
+ }
+
+ class BidirectionalTester3DFill(n: Int, m: Int, p: Int) extends BasicTester {
+ val mod = Module(new PassthroughModule)
+ val vec3D = VecInit.fill(n, m, p)(mod.io)
+
+ for {
+ vec2D <- vec3D
+ vec1D <- vec2D
+ module <- vec1D
+ } yield {
+ module <> (Module(new PassthroughModuleTester).io)
+ }
+ stop()
+ }
+
+ class TabulateModuleTester(value: Int) extends Module {
+ val io = IO(Flipped(new PassthroughModuleIO))
+ // This drives the input of a PassthroughModule
+ io.in := value.U
+ }
+
+ class BidirectionalTester2DTabulate(n: Int, m: Int) extends BasicTester {
+ val vec2D = VecInit.tabulate(n, m) { (x, y) => Module(new TabulateModuleTester(x + y + 1)).io}
+
+ for {
+ x <- 0 until n
+ y <- 0 until m
+ } yield {
+ val value = x + y + 1
+ val receiveMod = Module(new PassthroughModule).io
+ vec2D(x)(y) <> receiveMod
+ assert(receiveMod.out === value.U)
+ }
+ stop()
+ }
+
+ class BidirectionalTester3DTabulate(n: Int, m: Int, p: Int) extends BasicTester {
+ val vec3D = VecInit.tabulate(n, m, p) { (x, y, z) => Module(new TabulateModuleTester(x + y + z + 1)).io }
+
+ for {
+ x <- 0 until n
+ y <- 0 until m
+ z <- 0 until p
+ } yield {
+ val value = x + y + z + 1
+ val receiveMod = Module(new PassthroughModule).io
+ vec3D(x)(y)(z) <> receiveMod
+ assert(receiveMod.out === value.U)
+ }
+ stop()
+ }
+}
+
class IterateTester(start: Int, len: Int)(f: UInt => UInt) extends BasicTester {
val controlVec = VecInit(Seq.iterate(start.U, len)(f))
val testVec = VecInit.iterate(start.U, len)(f)
@@ -178,7 +294,7 @@ class PassthroughModuleTester extends Module {
}
class ModuleIODynamicIndexTester(n: Int) extends BasicTester {
- val duts = VecInit(Seq.fill(n)(Module(new PassthroughModule).io))
+ val duts = VecInit.fill(n)(Module(new PassthroughModule).io)
val tester = Module(new PassthroughModuleTester)
val (cycle, done) = Counter(true.B, n)
@@ -239,10 +355,42 @@ class VecSpec extends ChiselPropSpec with Utils {
forAll(smallPosInts) { (n: Int) => assertTesterPasses{ new TabulateTester(n) } }
}
+ property("VecInit should tabulate 2D vec correctly") {
+ forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses { new VecMultiDimTester.TabulateTester2D(n, m) } }
+ }
+
+ property("VecInit should tabulate 3D vec correctly") {
+ forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.TabulateTester3D(n, m, p) } }
+ }
+
property("VecInit should fill correctly") {
forAll(smallPosInts, Gen.choose(0, 50)) { (n: Int, value: Int) => assertTesterPasses{ new FillTester(n, value) } }
}
+ property("VecInit should fill 2D vec correctly") {
+ forAll(smallPosInts, smallPosInts, Gen.choose(0, 50)) { (n: Int, m: Int, value: Int) => assertTesterPasses{ new VecMultiDimTester.Fill2DTester(n, m, value) } }
+ }
+
+ property("VecInit should fill 3D vec correctly") {
+ forAll(smallPosInts, smallPosInts, smallPosInts, Gen.choose(0, 50)) { (n: Int, m: Int, p: Int, value: Int) => assertTesterPasses{ new VecMultiDimTester.Fill3DTester(n, m, p, value) } }
+ }
+
+ property("VecInit should support 2D fill bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester2DFill(n, m) }}
+ }
+
+ property("VecInit should support 3D fill bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester3DFill(n, m, p) }}
+ }
+
+ property("VecInit should support 2D tabulate bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester2DTabulate(n, m) }}
+ }
+
+ property("VecInit should support 3D tabulate bidirectional wire connection") {
+ forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester3DTabulate(n, m, p) }}
+ }
+
property("VecInit should iterate correctly") {
forAll(Gen.choose(1, 10), smallPosInts) { (start: Int, len: Int) => assertTesterPasses{ new IterateTester(start, len)(x => x + 50.U)}}
}