diff options
| author | anniej-sifive | 2021-08-23 14:37:09 -0700 |
|---|---|---|
| committer | GitHub | 2021-08-23 14:37:09 -0700 |
| commit | f50ce19406e45982390162777fb62c8563c962c7 (patch) | |
| tree | 010f8ecf120509d112b995a0a2866a40f6b12d98 /core | |
| parent | a3d51e4c91059362b20296eaa00f06f96ec7a4e1 (diff) | |
Add multiple dimensions to VecInit fill and iterate (#2065)
Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'core')
| -rw-r--r-- | core/src/main/scala/chisel3/Aggregate.scala | 129 |
1 files changed, 117 insertions, 12 deletions
diff --git a/core/src/main/scala/chisel3/Aggregate.scala b/core/src/main/scala/chisel3/Aggregate.scala index 58bc5ccb..17e46cb3 100644 --- a/core/src/main/scala/chisel3/Aggregate.scala +++ b/core/src/main/scala/chisel3/Aggregate.scala @@ -488,6 +488,21 @@ sealed class Vec[T <: Data] private[chisel3] (gen: => T, val length: Int) } object VecInit extends SourceInfoDoc { + + /** Gets the correct connect operation (directed hardware assign or bulk connect) for element in Vec. + */ + private def getConnectOpFromDirectionality[T <: Data](proto: T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): (T, T) => Unit = proto.direction match { + case ActualDirection.Input | ActualDirection.Output | ActualDirection.Unspecified => + // When internal wires are involved, driver / sink must be specified explicitly, otherwise + // the system is unable to infer which is driver / sink + (x, y) => x := y + case ActualDirection.Bidirectional(_) => + // For bidirectional, must issue a bulk connect so subelements are resolved correctly. + // Bulk connecting two wires may not succeed because Chisel frontend does not infer + // directions. + (x, y) => x <> y + } + /** Creates a new [[Vec]] composed of elements of the input Seq of [[Data]] * nodes. * @@ -513,18 +528,10 @@ object VecInit extends SourceInfoDoc { elts.foreach(requireIsHardware(_, "vec element")) val vec = Wire(Vec(elts.length, cloneSupertype(elts, "Vec"))) - - // TODO: try to remove the logic for this mess - elts.head.direction match { - case ActualDirection.Input | ActualDirection.Output | ActualDirection.Unspecified => - // When internal wires are involved, driver / sink must be specified explicitly, otherwise - // the system is unable to infer which is driver / sink - (vec zip elts).foreach(x => x._1 := x._2) - case ActualDirection.Bidirectional(_) => - // For bidirectional, must issue a bulk connect so subelements are resolved correctly. - // Bulk connecting two wires may not succeed because Chisel frontend does not infer - // directions. - (vec zip elts).foreach(x => x._1 <> x._2) + val op = getConnectOpFromDirectionality(vec.head) + + (vec zip elts).foreach{ x => + op(x._1, x._2) } vec } @@ -557,6 +564,73 @@ object VecInit extends SourceInfoDoc { def do_tabulate[T <: Data](n: Int)(gen: (Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[T] = apply((0 until n).map(i => gen(i))) + /** Creates a new 2D [[Vec]] of length `n by m` composed of the results of the given + * function applied over a range of integer values starting from 0. + * + * @param n number of 1D vectors inside outer vector + * @param m number of elements in each 1D vector (the function is applied from + * 0 to `n-1`) + * @param gen function that takes in an Int (the index) and returns a + * [[Data]] that becomes the output element + */ + def tabulate[T <: Data](n: Int, m: Int)(gen: (Int, Int) => T): Vec[Vec[T]] = macro VecTransform.tabulate2D + + /** @group SourceInfoTransformMacro */ + def do_tabulate[T <: Data](n: Int, m: Int)(gen: (Int, Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[T]] = { + // TODO make this lazy (requires LazyList and cross compilation, beyond the scope of this PR) + val elts = Seq.tabulate(n, m)(gen) + val flatElts = elts.flatten + + require(flatElts.nonEmpty, "Vec hardware values are not allowed to be empty") + flatElts.foreach(requireIsHardware(_, "vec element")) + + val tpe = cloneSupertype(flatElts, "Vec.tabulate") + val myVec = Wire(Vec(n, Vec(m, tpe))) + val op = getConnectOpFromDirectionality(myVec.head.head) + for ( + (xs1D, ys1D) <- myVec zip elts; + (x, y) <- xs1D zip ys1D + ) { + op(x, y) + } + myVec + } + + /** Creates a new 3D [[Vec]] of length `n by m by p` composed of the results of the given + * function applied over a range of integer values starting from 0. + * + * @param n number of 2D vectors inside outer vector + * @param m number of 1D vectors in each 2D vector + * @param p number of elements in each 1D vector + * @param gen function that takes in an Int (the index) and returns a + * [[Data]] that becomes the output element + */ + def tabulate[T <: Data](n: Int, m: Int, p: Int)(gen: (Int, Int, Int) => T): Vec[Vec[Vec[T]]] = macro VecTransform.tabulate3D + + /** @group SourceInfoTransformMacro */ + def do_tabulate[T <: Data](n: Int, m: Int, p: Int)(gen: (Int, Int, Int) => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[Vec[T]]] = { + // TODO make this lazy (requires LazyList and cross compilation, beyond the scope of this PR) + val elts = Seq.tabulate(n, m, p)(gen) + val flatElts = elts.flatten.flatten + + require(flatElts.nonEmpty, "Vec hardware values are not allowed to be empty") + flatElts.foreach(requireIsHardware(_, "vec element")) + + val tpe = cloneSupertype(flatElts, "Vec.tabulate") + val myVec = Wire(Vec(n, Vec(m, Vec(p, tpe)))) + val op = getConnectOpFromDirectionality(myVec.head.head.head) + + for ( + (xs2D, ys2D) <- myVec zip elts; + (xs1D, ys1D) <- xs2D zip ys2D; + (x, y) <- xs1D zip ys1D + ) { + op(x, y) + } + + myVec + } + /** Creates a new [[Vec]] of length `n` composed of the result of the given * function applied to an element of data type T. * @@ -570,6 +644,37 @@ object VecInit extends SourceInfoDoc { def do_fill[T <: Data](n: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[T] = apply(Seq.fill(n)(gen)) + /** Creates a new 2D [[Vec]] of length `n by m` composed of the result of the given + * function applied to an element of data type T. + * + * @param n number of inner vectors (rows) in the outer vector + * @param m number of elements in each inner vector (column) + * @param gen function that takes in an element T and returns an output + * element of the same type + */ + def fill[T <: Data](n: Int, m: Int)(gen: => T): Vec[Vec[T]] = macro VecTransform.fill2D + + /** @group SourceInfoTransformMacro */ + def do_fill[T <: Data](n: Int, m: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[T]] = { + do_tabulate(n, m)((_, _) => gen) + } + + /** Creates a new 3D [[Vec]] of length `n by m by p` composed of the result of the given + * function applied to an element of data type T. + * + * @param n number of 2D vectors inside outer vector + * @param m number of 1D vectors in each 2D vector + * @param p number of elements in each 1D vector + * @param gen function that takes in an element T and returns an output + * element of the same type + */ + def fill[T <: Data](n: Int, m: Int, p: Int)(gen: => T): Vec[Vec[Vec[T]]] = macro VecTransform.fill3D + + /** @group SourceInfoTransformMacro */ + def do_fill[T <: Data](n: Int, m: Int, p: Int)(gen: => T)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Vec[Vec[Vec[T]]] = { + do_tabulate(n, m, p)((_, _, _) => gen) + } + /** Creates a new [[Vec]] of length `n` composed of the result of the given * function applied to an element of data type T. * |
