diff options
| -rw-r--r-- | core/src/main/scala/chisel3/Aggregate.scala | 129 | ||||
| -rw-r--r-- | docs/src/cookbooks/cookbook.md | 46 | ||||
| -rw-r--r-- | macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala | 15 | ||||
| -rw-r--r-- | src/test/scala/chiselTests/Vec.scala | 150 |
4 files changed, 327 insertions, 13 deletions
diff --git a/core/src/main/scala/chisel3/Aggregate.scala b/core/src/main/scala/chisel3/Aggregate.scala index 58bc5ccb..17e46cb3 100644 --- a/core/src/main/scala/chisel3/Aggregate.scala +++ b/core/src/main/scala/chisel3/Aggregate.scala @@ -488,6 +488,21 @@ sealed class Vec[T <: Data] private[chisel3] (gen: => T, val length: Int) } object VecInit extends SourceInfoDoc { + + /** Gets the correct connect operation (directed hardware assign or bulk connect) for element in Vec. + */ + private def getConnectOpFromDirectionality[T <: Data](proto: T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): (T, T) => Unit = proto.direction match { + case ActualDirection.Input | ActualDirection.Output | ActualDirection.Unspecified => + // When internal wires are involved, driver / sink must be specified explicitly, otherwise + // the system is unable to infer which is driver / sink + (x, y) => x := y + case ActualDirection.Bidirectional(_) => + // For bidirectional, must issue a bulk connect so subelements are resolved correctly. + // Bulk connecting two wires may not succeed because Chisel frontend does not infer + // directions. + (x, y) => x <> y + } + /** Creates a new [[Vec]] composed of elements of the input Seq of [[Data]] * nodes. * @@ -513,18 +528,10 @@ object VecInit extends SourceInfoDoc { elts.foreach(requireIsHardware(_, "vec element")) val vec = Wire(Vec(elts.length, cloneSupertype(elts, "Vec"))) - - // TODO: try to remove the logic for this mess - elts.head.direction match { - case ActualDirection.Input | ActualDirection.Output | ActualDirection.Unspecified => - // When internal wires are involved, driver / sink must be specified explicitly, otherwise - // the system is unable to infer which is driver / sink - (vec zip elts).foreach(x => x._1 := x._2) - case ActualDirection.Bidirectional(_) => - // For bidirectional, must issue a bulk connect so subelements are resolved correctly. - // Bulk connecting two wires may not succeed because Chisel frontend does not infer - // directions. - (vec zip elts).foreach(x => x._1 <> x._2) + val op = getConnectOpFromDirectionality(vec.head) + + (vec zip elts).foreach{ x => + op(x._1, x._2) } vec } @@ -557,6 +564,73 @@ object VecInit extends SourceInfoDoc { def do_tabulate[T <: Data](n: Int)(gen: (Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[T] = apply((0 until n).map(i => gen(i))) + /** Creates a new 2D [[Vec]] of length `n by m` composed of the results of the given + * function applied over a range of integer values starting from 0. + * + * @param n number of 1D vectors inside outer vector + * @param m number of elements in each 1D vector (the function is applied from + * 0 to `n-1`) + * @param gen function that takes in an Int (the index) and returns a + * [[Data]] that becomes the output element + */ + def tabulate[T <: Data](n: Int, m: Int)(gen: (Int, Int) => T): Vec[Vec[T]] = macro VecTransform.tabulate2D + + /** @group SourceInfoTransformMacro */ + def do_tabulate[T <: Data](n: Int, m: Int)(gen: (Int, Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[T]] = { + // TODO make this lazy (requires LazyList and cross compilation, beyond the scope of this PR) + val elts = Seq.tabulate(n, m)(gen) + val flatElts = elts.flatten + + require(flatElts.nonEmpty, "Vec hardware values are not allowed to be empty") + flatElts.foreach(requireIsHardware(_, "vec element")) + + val tpe = cloneSupertype(flatElts, "Vec.tabulate") + val myVec = Wire(Vec(n, Vec(m, tpe))) + val op = getConnectOpFromDirectionality(myVec.head.head) + for ( + (xs1D, ys1D) <- myVec zip elts; + (x, y) <- xs1D zip ys1D + ) { + op(x, y) + } + myVec + } + + /** Creates a new 3D [[Vec]] of length `n by m by p` composed of the results of the given + * function applied over a range of integer values starting from 0. + * + * @param n number of 2D vectors inside outer vector + * @param m number of 1D vectors in each 2D vector + * @param p number of elements in each 1D vector + * @param gen function that takes in an Int (the index) and returns a + * [[Data]] that becomes the output element + */ + def tabulate[T <: Data](n: Int, m: Int, p: Int)(gen: (Int, Int, Int) => T): Vec[Vec[Vec[T]]] = macro VecTransform.tabulate3D + + /** @group SourceInfoTransformMacro */ + def do_tabulate[T <: Data](n: Int, m: Int, p: Int)(gen: (Int, Int, Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[Vec[T]]] = { + // TODO make this lazy (requires LazyList and cross compilation, beyond the scope of this PR) + val elts = Seq.tabulate(n, m, p)(gen) + val flatElts = elts.flatten.flatten + + require(flatElts.nonEmpty, "Vec hardware values are not allowed to be empty") + flatElts.foreach(requireIsHardware(_, "vec element")) + + val tpe = cloneSupertype(flatElts, "Vec.tabulate") + val myVec = Wire(Vec(n, Vec(m, Vec(p, tpe)))) + val op = getConnectOpFromDirectionality(myVec.head.head.head) + + for ( + (xs2D, ys2D) <- myVec zip elts; + (xs1D, ys1D) <- xs2D zip ys2D; + (x, y) <- xs1D zip ys1D + ) { + op(x, y) + } + + myVec + } + /** Creates a new [[Vec]] of length `n` composed of the result of the given * function applied to an element of data type T. * @@ -570,6 +644,37 @@ object VecInit extends SourceInfoDoc { def do_fill[T <: Data](n: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[T] = apply(Seq.fill(n)(gen)) + /** Creates a new 2D [[Vec]] of length `n by m` composed of the result of the given + * function applied to an element of data type T. + * + * @param n number of inner vectors (rows) in the outer vector + * @param m number of elements in each inner vector (column) + * @param gen function that takes in an element T and returns an output + * element of the same type + */ + def fill[T <: Data](n: Int, m: Int)(gen: => T): Vec[Vec[T]] = macro VecTransform.fill2D + + /** @group SourceInfoTransformMacro */ + def do_fill[T <: Data](n: Int, m: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[T]] = { + do_tabulate(n, m)((_, _) => gen) + } + + /** Creates a new 3D [[Vec]] of length `n by m by p` composed of the result of the given + * function applied to an element of data type T. + * + * @param n number of 2D vectors inside outer vector + * @param m number of 1D vectors in each 2D vector + * @param p number of elements in each 1D vector + * @param gen function that takes in an element T and returns an output + * element of the same type + */ + def fill[T <: Data](n: Int, m: Int, p: Int)(gen: => T): Vec[Vec[Vec[T]]] = macro VecTransform.fill3D + + /** @group SourceInfoTransformMacro */ + def do_fill[T <: Data](n: Int, m: Int, p: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[Vec[T]]] = { + do_tabulate(n, m, p)((_, _, _) => gen) + } + /** Creates a new [[Vec]] of length `n` composed of the result of the given * function applied to an element of data type T. * diff --git a/docs/src/cookbooks/cookbook.md b/docs/src/cookbooks/cookbook.md index cff7a5b2..ce49b668 100644 --- a/docs/src/cookbooks/cookbook.md +++ b/docs/src/cookbooks/cookbook.md @@ -17,6 +17,7 @@ Please note that these examples make use of [Chisel's scala-style printing](../e * [How do I create a UInt from a Vec of Bool?](#how-do-i-create-a-uint-from-a-vec-of-bool) * [How do I connect a subset of Bundle fields?](#how-do-i-connect-a-subset-of-bundle-fields) * Vectors and Registers + * [Can I make a 2D or 3D Vector?](#can-i-make-a-2D-or-3D-Vector) * [How do I create a Vector of Registers?](#how-do-i-create-a-vector-of-registers) * [How do I create a Reg of type Vec?](#how-do-i-create-a-reg-of-type-vec) * [How do I create a finite state machine?](#how-do-i-create-a-finite-state-machine-fsm) @@ -157,6 +158,51 @@ See the [DataView cookbook](dataview#how-do-i-connect-a-subset-of-bundle-fields) ## Vectors and Registers +### Can I make a 2D or 3D Vector? + +Yes. Using `VecInit` you can make Vectors that hold Vectors of Chisel types. Methods `fill` and `tabulate` make these multi-dimensional Vectors. + +```scala mdoc:silent:reset +import chisel3._ + +class MyBundle extends Bundle { + val foo = UInt(4.W) + val bar = UInt(4.W) +} + +class Foo extends Module { + //2D Fill + val twoDVec = VecInit.fill(2, 3)(5.U) + //3D Fill + val myBundle = Wire(new MyBundle) + myBundle.foo := 0xc.U + myBundle.bar := 0x3.U + val threeDVec = VecInit.fill(1, 2, 3)(myBundle) + assert(threeDVec(0)(0)(0).foo === 0xc.U && threeDVec(0)(0)(0).bar === 0x3.U) + + //2D Tabulate + val indexTiedVec = VecInit.tabulate(2, 2){ (x, y) => (x + y).U } + assert(indexTiedVec(0)(0) === 0.U) + assert(indexTiedVec(0)(1) === 1.U) + assert(indexTiedVec(1)(0) === 1.U) + assert(indexTiedVec(1)(1) === 2.U) + //3D Tabulate + val indexTiedVec3D = VecInit.tabulate(2, 3, 4){ (x, y, z) => (x + y * z).U } + assert(indexTiedVec3D(0)(0)(0) === 0.U) + assert(indexTiedVec3D(1)(1)(1) === 2.U) + assert(indexTiedVec3D(1)(1)(2) === 3.U) + assert(indexTiedVec3D(1)(1)(3) === 4.U) + assert(indexTiedVec3D(1)(2)(3) === 7.U) +} +``` +```scala mdoc:invisible +// Hidden but will make sure this actually compiles +import chisel3.stage.ChiselStage + +ChiselStage.emitVerilog(new Foo) +``` + + ### How do I create a Vector of Registers? **Rule! Use Reg of Vec not Vec of Reg!** diff --git a/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala b/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala index d7c301e9..6121bc1e 100644 --- a/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala +++ b/macros/src/main/scala/chisel3/internal/sourceinfo/SourceInfoTransform.scala @@ -81,9 +81,24 @@ class VecTransform(val c: Context) extends SourceInfoTransformMacro { def tabulate(n: c.Tree)(gen: c.Tree): c.Tree = { q"$thisObj.do_tabulate($n)($gen)($implicitSourceInfo, $implicitCompileOptions)" } + def tabulate2D(n: c.Tree, m: c.Tree)(gen: c.Tree): c.Tree = { + q"$thisObj.do_tabulate($n,$m)($gen)($implicitSourceInfo, $implicitCompileOptions)" + } + def tabulate3D(n: c.Tree, m: c.Tree, p: c.Tree)(gen: c.Tree): c.Tree = { + q"$thisObj.do_tabulate($n,$m,$p)($gen)($implicitSourceInfo, $implicitCompileOptions)" + } def fill(n: c.Tree)(gen: c.Tree): c.Tree = { q"$thisObj.do_fill($n)($gen)($implicitSourceInfo, $implicitCompileOptions)" } + def fill2D(n: c.Tree, m: c.Tree)(gen: c.Tree): c.Tree = { + q"$thisObj.do_fill($n,$m)($gen)($implicitSourceInfo, $implicitCompileOptions)" + } + def fill3D(n: c.Tree, m: c.Tree, p: c.Tree)(gen: c.Tree): c.Tree = { + q"$thisObj.do_fill($n,$m,$p)($gen)($implicitSourceInfo, $implicitCompileOptions)" + } + def fill4D(n: c.Tree, m: c.Tree, p: c.Tree, q: c.Tree)(gen: c.Tree): c.Tree = { + q"$thisObj.do_fill($n,$m,$p,$q)($gen)($implicitSourceInfo, $implicitCompileOptions)" + } def iterate(start: c.Tree, len: c.Tree)(f: c.Tree): c.Tree = { q"$thisObj.do_iterate($start,$len)($f)($implicitSourceInfo, $implicitCompileOptions)" } diff --git a/src/test/scala/chiselTests/Vec.scala b/src/test/scala/chiselTests/Vec.scala index f3160c2e..97aea909 100644 --- a/src/test/scala/chiselTests/Vec.scala +++ b/src/test/scala/chiselTests/Vec.scala @@ -9,6 +9,7 @@ import chisel3.stage.ChiselStage import chisel3.testers.BasicTester import chisel3.util._ import org.scalacheck.Shrink +import scala.annotation.tailrec class LitTesterMod(vecSize: Int) extends Module { val io = IO(new Bundle { @@ -114,6 +115,121 @@ class FillTester(n: Int, value: Int) extends BasicTester { stop() } +object VecMultiDimTester { + + @tailrec + private def assert2DIsCorrect(n: Int, arr: Vec[Vec[UInt]], compArr: Seq[Seq[Int]]): Unit = { + val compareRow = arr(n) zip compArr(n) + compareRow.foreach (x => assert(x._1 === x._2.U)) + if (n != 0) assert2DIsCorrect(n-1, arr, compArr) + } + + @tailrec + private def assert3DIsCorrect(n: Int, m: Int, arr: Vec[Vec[Vec[UInt]]], compArr: Seq[Seq[Seq[Int]]]): Unit = { + assert2DIsCorrect(m-1, arr(n), compArr(n)) + if (n != 0) assert3DIsCorrect(n-1, m, arr, compArr) + } + + class TabulateTester2D(n: Int, m: Int) extends BasicTester { + def gen(x: Int, y: Int): UInt = (x+y).asUInt + def genCompVec(x: Int, y:Int): Int = x+y + val vec = VecInit.tabulate(n, m){ gen } + val compArr = Seq.tabulate(n,m){ genCompVec } + + assert2DIsCorrect(n-1, vec, compArr) + stop() + } + + class TabulateTester3D(n: Int, m: Int, p: Int) extends BasicTester { + def gen(x: Int, y: Int, z: Int): UInt = (x+y+z).asUInt + def genCompVec(x: Int, y:Int, z: Int): Int = x+y+z + val vec = VecInit.tabulate(n, m, p){ gen } + val compArr = Seq.tabulate(n, m, p){ genCompVec } + + assert3DIsCorrect(n-1, m, vec, compArr) + stop() + } + + class Fill2DTester(n: Int, m: Int, value: Int) extends BasicTester { + val u = VecInit.fill(n,m)(value.U) + val compareArr = Seq.fill(n,m)(value) + + assert2DIsCorrect(n-1, u, compareArr) + stop() + } + + class Fill3DTester(n: Int, m: Int, p: Int, value: Int) extends BasicTester { + val u = VecInit.fill(n,m,p)(value.U) + val compareArr = Seq.fill(n,m,p)(value) + + assert3DIsCorrect(n-1, m, u, compareArr) + stop() + } + + class BidirectionalTester2DFill(n: Int, m: Int) extends BasicTester { + val mod = Module(new PassthroughModule) + val vec2D = VecInit.fill(n, m)(mod.io) + for { + vec1D <- vec2D + module <- vec1D + } yield { + module <> Module(new PassthroughModuleTester).io + } + stop() + } + + class BidirectionalTester3DFill(n: Int, m: Int, p: Int) extends BasicTester { + val mod = Module(new PassthroughModule) + val vec3D = VecInit.fill(n, m, p)(mod.io) + + for { + vec2D <- vec3D + vec1D <- vec2D + module <- vec1D + } yield { + module <> (Module(new PassthroughModuleTester).io) + } + stop() + } + + class TabulateModuleTester(value: Int) extends Module { + val io = IO(Flipped(new PassthroughModuleIO)) + // This drives the input of a PassthroughModule + io.in := value.U + } + + class BidirectionalTester2DTabulate(n: Int, m: Int) extends BasicTester { + val vec2D = VecInit.tabulate(n, m) { (x, y) => Module(new TabulateModuleTester(x + y + 1)).io} + + for { + x <- 0 until n + y <- 0 until m + } yield { + val value = x + y + 1 + val receiveMod = Module(new PassthroughModule).io + vec2D(x)(y) <> receiveMod + assert(receiveMod.out === value.U) + } + stop() + } + + class BidirectionalTester3DTabulate(n: Int, m: Int, p: Int) extends BasicTester { + val vec3D = VecInit.tabulate(n, m, p) { (x, y, z) => Module(new TabulateModuleTester(x + y + z + 1)).io } + + for { + x <- 0 until n + y <- 0 until m + z <- 0 until p + } yield { + val value = x + y + z + 1 + val receiveMod = Module(new PassthroughModule).io + vec3D(x)(y)(z) <> receiveMod + assert(receiveMod.out === value.U) + } + stop() + } +} + class IterateTester(start: Int, len: Int)(f: UInt => UInt) extends BasicTester { val controlVec = VecInit(Seq.iterate(start.U, len)(f)) val testVec = VecInit.iterate(start.U, len)(f) @@ -178,7 +294,7 @@ class PassthroughModuleTester extends Module { } class ModuleIODynamicIndexTester(n: Int) extends BasicTester { - val duts = VecInit(Seq.fill(n)(Module(new PassthroughModule).io)) + val duts = VecInit.fill(n)(Module(new PassthroughModule).io) val tester = Module(new PassthroughModuleTester) val (cycle, done) = Counter(true.B, n) @@ -239,10 +355,42 @@ class VecSpec extends ChiselPropSpec with Utils { forAll(smallPosInts) { (n: Int) => assertTesterPasses{ new TabulateTester(n) } } } + property("VecInit should tabulate 2D vec correctly") { + forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses { new VecMultiDimTester.TabulateTester2D(n, m) } } + } + + property("VecInit should tabulate 3D vec correctly") { + forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.TabulateTester3D(n, m, p) } } + } + property("VecInit should fill correctly") { forAll(smallPosInts, Gen.choose(0, 50)) { (n: Int, value: Int) => assertTesterPasses{ new FillTester(n, value) } } } + property("VecInit should fill 2D vec correctly") { + forAll(smallPosInts, smallPosInts, Gen.choose(0, 50)) { (n: Int, m: Int, value: Int) => assertTesterPasses{ new VecMultiDimTester.Fill2DTester(n, m, value) } } + } + + property("VecInit should fill 3D vec correctly") { + forAll(smallPosInts, smallPosInts, smallPosInts, Gen.choose(0, 50)) { (n: Int, m: Int, p: Int, value: Int) => assertTesterPasses{ new VecMultiDimTester.Fill3DTester(n, m, p, value) } } + } + + property("VecInit should support 2D fill bidirectional wire connection") { + forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester2DFill(n, m) }} + } + + property("VecInit should support 3D fill bidirectional wire connection") { + forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester3DFill(n, m, p) }} + } + + property("VecInit should support 2D tabulate bidirectional wire connection") { + forAll(smallPosInts, smallPosInts) { (n: Int, m: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester2DTabulate(n, m) }} + } + + property("VecInit should support 3D tabulate bidirectional wire connection") { + forAll(smallPosInts, smallPosInts, smallPosInts) { (n: Int, m: Int, p: Int) => assertTesterPasses{ new VecMultiDimTester.BidirectionalTester3DTabulate(n, m, p) }} + } + property("VecInit should iterate correctly") { forAll(Gen.choose(1, 10), smallPosInts) { (start: Int, len: Int) => assertTesterPasses{ new IterateTester(start, len)(x => x + 50.U)}} } |
