summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chiselFrontend/src/main/scala/chisel3/core/Data.scala19
-rw-r--r--chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala70
-rw-r--r--src/test/scala/chiselTests/OneHotMuxSpec.scala286
3 files changed, 362 insertions, 13 deletions
diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala
index c9cfe27b..556f2aeb 100644
--- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala
+++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala
@@ -35,9 +35,10 @@ object DataMirror {
* - For other types of the same class are are the same: clone of any of the elements
* - Otherwise: fail
*/
+//scalastyle:off cyclomatic.complexity
private[core] object cloneSupertype {
def apply[T <: Data](elts: Seq[T], createdType: String)(implicit sourceInfo: SourceInfo,
- compileOptions: CompileOptions): T = {
+ compileOptions: CompileOptions): T = {
require(!elts.isEmpty, s"can't create $createdType with no inputs")
if (elts forall {_.isInstanceOf[Bits]}) {
@@ -45,7 +46,9 @@ private[core] object cloneSupertype {
case (elt1: Bool, elt2: Bool) => elt1
case (elt1: Bool, elt2: UInt) => elt2 // TODO: what happens with zero width UInts?
case (elt1: UInt, elt2: Bool) => elt1 // TODO: what happens with zero width UInts?
- case (elt1: UInt, elt2: UInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 // TODO: perhaps redefine Widths to allow >= op?
+ case (elt1: UInt, elt2: UInt) =>
+ // TODO: perhaps redefine Widths to allow >= op?
+ if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2
case (elt1: SInt, elt2: SInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2
case (elt1: FixedPoint, elt2: FixedPoint) => {
(elt1.binaryPoint, elt2.binaryPoint, elt1.width, elt2.width) match {
@@ -59,13 +62,17 @@ private[core] object cloneSupertype {
}
}
case (elt1, elt2) =>
- throw new AssertionError(s"can't create $createdType with heterogeneous Bits types ${elt1.getClass} and ${elt2.getClass}")
+ throw new AssertionError(
+ s"can't create $createdType with heterogeneous Bits types ${elt1.getClass} and ${elt2.getClass}")
}).asInstanceOf[T] }
model.chiselCloneType
- } else {
+ }
+ else {
for (elt <- elts.tail) {
- require(elt.getClass == elts.head.getClass, s"can't create $createdType with heterogeneous types ${elts.head.getClass} and ${elt.getClass}")
- require(elt typeEquivalent elts.head, s"can't create $createdType with non-equivalent types ${elts.head} and ${elt}")
+ require(elt.getClass == elts.head.getClass,
+ s"can't create $createdType with heterogeneous types ${elts.head.getClass} and ${elt.getClass}")
+ require(elt typeEquivalent elts.head,
+ s"can't create $createdType with non-equivalent types ${elts.head} and ${elt}")
}
elts.head.chiselCloneType
}
diff --git a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala
index c7b59d96..02382e57 100644
--- a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala
+++ b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala
@@ -2,10 +2,13 @@
package chisel3.core
-import scala.language.experimental.macros
+import chisel3.internal.throwException
+import scala.language.experimental.macros
import chisel3.internal.sourceinfo._
+//scalastyle:off method.name
+
private[chisel3] object SeqUtils {
/** Concatenates the data elements of the input sequence, in sequence order, together.
* The first element of the sequence forms the least significant bits, while the last element
@@ -39,7 +42,8 @@ private[chisel3] object SeqUtils {
*/
def priorityMux[T <: Data](in: Seq[(Bool, T)]): T = macro CompileOptionsTransform.inArg
- def do_priorityMux[T <: Data](in: Seq[(Bool, T)])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = {
+ def do_priorityMux[T <: Data](in: Seq[(Bool, T)])
+ (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = {
if (in.size == 1) {
in.head._2
} else {
@@ -48,18 +52,70 @@ private[chisel3] object SeqUtils {
}
/** Returns the data value corresponding to the lone true predicate.
+ * This is elaborated to firrtl using a structure that should be optimized into and and/or tree.
*
* @note assumes exactly one true predicate, results undefined otherwise
+ * FixedPoint values or aggregates containing FixedPoint values cause this optimized structure to be lost
*/
def oneHotMux[T <: Data](in: Iterable[(Bool, T)]): T = macro CompileOptionsTransform.inArg
- def do_oneHotMux[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = {
+ //scalastyle:off method.length cyclomatic.complexity
+ def do_oneHotMux[T <: Data](in: Iterable[(Bool, T)])
+ (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = {
if (in.tail.isEmpty) {
in.head._2
- } else {
- val masked = for ((s, i) <- in) yield Mux(s, i.asUInt, 0.U)
- val output = cloneSupertype(in.toSeq map {_._2}, "oneHotMux")
- output.fromBits(masked.reduceLeft(_|_))
+ }
+ else {
+ val output = cloneSupertype(in.toSeq map { _._2}, "oneHotMux")
+
+ def buildAndOrMultiplexor[TT <: Data](inputs: Iterable[(Bool, TT)]): T = {
+ val masked = for ((s, i) <- inputs) yield Mux(s, i.asUInt(), 0.U)
+ output.fromBits(masked.reduceLeft(_ | _))
+ }
+
+ output match {
+ case _: SInt =>
+ // SInt's have to be managed carefully so sign extension works
+
+ val sInts: Iterable[(Bool, SInt)] = in.collect { case (s: Bool, f: SInt) =>
+ (s, f.asTypeOf(output).asInstanceOf[SInt])
+ }
+
+ val masked = for ((s, i) <- sInts) yield Mux(s, i, 0.S)
+ output.fromBits(masked.reduceLeft(_ | _))
+
+ case _: FixedPoint =>
+ val (sels, possibleOuts) = in.toSeq.unzip
+
+ val (intWidths, binaryPoints) = in.toSeq.map { case (_, o) =>
+ val fo = o.asInstanceOf[FixedPoint]
+ require(fo.binaryPoint.known, "Mux1H requires width/binary points to be defined")
+ (fo.getWidth - fo.binaryPoint.get, fo.binaryPoint.get)
+ }.unzip
+
+ if (intWidths.distinct.length == 1 && binaryPoints.distinct.length == 1) {
+ buildAndOrMultiplexor(in)
+ }
+ else {
+ val maxIntWidth = intWidths.max
+ val maxBP = binaryPoints.max
+ val inWidthMatched = Seq.fill(intWidths.length)(Wire(FixedPoint((maxIntWidth + maxBP).W, maxBP.BP)))
+ inWidthMatched.zipWithIndex foreach { case (e, idx) => e := possibleOuts(idx).asInstanceOf[FixedPoint] }
+ buildAndOrMultiplexor(sels.zip(inWidthMatched))
+ }
+
+ case _: Aggregate =>
+ val allDefineWidth = in.forall { case (_, element) => element.widthOption.isDefined }
+ if(allDefineWidth) {
+ buildAndOrMultiplexor(in)
+ }
+ else {
+ throwException(s"Cannot Mux1H with aggregates with inferred widths")
+ }
+
+ case _ =>
+ buildAndOrMultiplexor(in)
+ }
}
}
}
diff --git a/src/test/scala/chiselTests/OneHotMuxSpec.scala b/src/test/scala/chiselTests/OneHotMuxSpec.scala
new file mode 100644
index 00000000..c2efb6f8
--- /dev/null
+++ b/src/test/scala/chiselTests/OneHotMuxSpec.scala
@@ -0,0 +1,286 @@
+// See LICENSE for license details.
+
+package chiselTests
+
+import Chisel.testers.BasicTester
+import chisel3._
+import chisel3.experimental.FixedPoint
+import chisel3.util.Mux1H
+import org.scalatest._
+
+//scalastyle:off magic.number
+
+class OneHotMuxSpec extends FreeSpec with Matchers with ChiselRunners {
+ "simple one hot mux with uint should work" in {
+ assertTesterPasses(new SimpleOneHotTester)
+ }
+ "simple one hot mux with sint should work" in {
+ assertTesterPasses(new SIntOneHotTester)
+ }
+ "simple one hot mux with fixed point should work" in {
+ assertTesterPasses(new FixedPointOneHotTester)
+ }
+ "simple one hot mux with all same fixed point should work" in {
+ assertTesterPasses(new AllSameFixedPointOneHotTester)
+ }
+ "simple one hot mux with all same parameterized sint values should work" in {
+ val values: Seq[SInt] = Seq((-3).S, (-5).S, (-7).S, (-11).S)
+ assertTesterPasses(new ParameterizedOneHotTester(values, SInt(8.W), -5.S(8.W)))
+ }
+ "simple one hot mux with all same parameterized aggregates containing fixed values should work" in {
+ assertTesterPasses(new ParameterizedAggregateOneHotTester)
+ }
+ "simple one hot mux with all aggregates containing inferred width fixed values should NOT work" in {
+ intercept[ChiselException] {
+ assertTesterPasses(new InferredWidthAggregateOneHotTester)
+ }
+ }
+ "simple one hot mux with all fixed width bundles but with different bundles should Not work" in {
+ intercept[IllegalArgumentException] {
+ assertTesterPasses(new DifferentBundleOneHotTester)
+ }
+ }
+}
+
+class SimpleOneHotTester extends BasicTester {
+ val out = Wire(UInt())
+ out := Mux1H(Seq(
+ false.B -> 2.U,
+ false.B -> 4.U,
+ true.B -> 8.U,
+ false.B -> 11.U
+ ))
+
+ assert(out === 8.U)
+
+ stop()
+}
+
+class SIntOneHotTester extends BasicTester {
+ val out = Wire(SInt())
+ out := Mux1H(Seq(
+ false.B -> (-3).S,
+ true.B -> (-5).S,
+ false.B -> (-7).S,
+ false.B -> (-11).S
+ ))
+
+ assert(out === (-5).S)
+
+ stop()
+}
+
+class FixedPointOneHotTester extends BasicTester {
+ val out = Wire(FixedPoint(8.W, 4.BP))
+
+ out := Mux1H(Seq(
+ false.B -> (-1.5).F(1.BP),
+ true.B -> (-2.25).F(2.BP),
+ false.B -> (-4.125).F(3.BP),
+ false.B -> (-11.625).F(3.BP)
+ ))
+
+ assert(out === (-2.25).F(4.BP))
+
+ stop()
+}
+
+class AllSameFixedPointOneHotTester extends BasicTester {
+ val out = Wire(FixedPoint(12.W, 3.BP))
+
+ out := Mux1H(Seq(
+ false.B -> (-1.5).F(12.W, 3.BP),
+ true.B -> (-2.25).F(12.W, 3.BP),
+ false.B -> (-4.125).F(12.W, 3.BP),
+ false.B -> (-11.625).F(12.W, 3.BP)
+ ))
+
+ assert(out === (-2.25).F(14.W, 4.BP))
+
+ stop()
+}
+
+class ParameterizedOneHotTester[T <: Data](values: Seq[T], outGen: T, expected: T) extends BasicTester {
+ val dut = Module(new ParameterizedOneHot(values, outGen))
+ dut.io.selectors(0) := false.B
+ dut.io.selectors(1) := true.B
+ dut.io.selectors(2) := false.B
+ dut.io.selectors(3) := false.B
+
+ assert(dut.io.out.asUInt() === expected.asUInt())
+
+ stop()
+}
+
+class Agg1 extends Bundle {
+ val v = Vec(2, FixedPoint(8.W, 4.BP))
+ val a = new Bundle {
+ val f1 = FixedPoint(7.W, 3.BP)
+ val f2 = FixedPoint(9.W, 5.BP)
+ }
+}
+
+object Agg1 extends HasMakeLit[Agg1] {
+ def makeLit(n: Int): Agg1 = {
+ val x = n.toDouble / 4.0
+ val (d: Double, e: Double, f: Double, g: Double) = (x, x * 2.0, x * 3.0, x * 4.0)
+
+ val w = Wire(new Agg1)
+ w.v(0) := Wire(d.F(4.BP))
+ w.v(1) := Wire(e.F(4.BP))
+ w.a.f1 := Wire(f.F(3.BP))
+ w.a.f2 := Wire(g.F(5.BP))
+ w
+ }
+}
+class Agg2 extends Bundle {
+ val v = Vec(2, FixedPoint(8.W, 4.BP))
+ val a = new Bundle {
+ val f1 = FixedPoint(7.W, 3.BP)
+ val f2 = FixedPoint(9.W, 5.BP)
+ }
+}
+
+object Agg2 extends HasMakeLit[Agg2] {
+ def makeLit(n: Int): Agg2 = {
+ val x = n.toDouble / 4.0
+ val (d: Double, e: Double, f: Double, g: Double) = (x, x * 2.0, x * 3.0, x * 4.0)
+
+ val w = Wire(new Agg2)
+ w.v(0) := Wire(d.F(4.BP))
+ w.v(1) := Wire(e.F(4.BP))
+ w.a.f1 := Wire(f.F(3.BP))
+ w.a.f2 := Wire(g.F(5.BP))
+ w
+ }
+}
+
+class ParameterizedAggregateOneHotTester extends BasicTester {
+ val values = (0 until 4).map { n => Agg1.makeLit(n) }
+
+ val dut = Module(new ParameterizedAggregateOneHot(Agg1, new Agg1))
+ dut.io.selectors(0) := false.B
+ dut.io.selectors(1) := true.B
+ dut.io.selectors(2) := false.B
+ dut.io.selectors(3) := false.B
+
+ assert(dut.io.out.asUInt() === values(1).asUInt())
+
+ stop()
+}
+
+trait HasMakeLit[T] {
+ def makeLit(n: Int): T
+}
+
+class ParameterizedOneHot[T <: Data](values: Seq[T], outGen: T) extends Module {
+ val io = IO(new Bundle {
+ val selectors = Input(Vec(4, Bool()))
+ val out = Output(outGen)
+ })
+
+ val terms = io.selectors.zip(values)
+ io.out := Mux1H(terms)
+}
+
+class ParameterizedAggregateOneHot[T <: Data](valGen: HasMakeLit[T], outGen: T) extends Module {
+ val io = IO(new Bundle {
+ val selectors = Input(Vec(4, Bool()))
+ val out = Output(outGen)
+ })
+
+
+ val values = (0 until 4).map { n => valGen.makeLit(n) }
+ val terms = io.selectors.zip(values)
+ io.out := Mux1H(terms)
+}
+
+class Bundle1 extends Bundle {
+ val a = FixedPoint()
+ val b = new Bundle {
+ val c = FixedPoint()
+ }
+}
+
+class InferredWidthAggregateOneHotTester extends BasicTester {
+ val b0 = Wire(new Bundle1)
+ b0.a := -0.25.F(2.BP)
+ b0.b.c := -0.125.F(3.BP)
+
+ val b1 = Wire(new Bundle1)
+ b1.a := -0.0625.F(3.BP)
+ b1.b.c := -0.03125.F(4.BP)
+
+ val b2 = Wire(new Bundle1)
+ b2.a := -0.015625.F(5.BP)
+ b2.b.c := -0.0078125.F(6.BP)
+
+ val b3 = Wire(new Bundle1)
+ b3.a := -0.0078125.F(7.BP)
+ b3.b.c := -0.00390625.F(8.BP)
+
+ val o1 = Mux1H(Seq(
+ false.B -> b0,
+ false.B -> b1,
+ true.B -> b2,
+ false.B -> b3
+ ))
+
+ assert(o1.a === -0.015625.F(5.BP))
+ assert(o1.b.c === -0.0078125.F(6.BP))
+
+ val o2 = Mux1H(Seq(
+ false.B -> b0,
+ true.B -> b1,
+ false.B -> b2,
+ false.B -> b3
+ ))
+
+ assert(o2.a === -0.0625.F(3.BP))
+ assert(o2.b.c === -0.03125.F(4.BP))
+
+ stop()
+}
+
+class Bundle2 extends Bundle {
+ val a = FixedPoint(10.W, 4.BP)
+ val b = new Bundle {
+ val c = FixedPoint(10.W, 4.BP)
+ }
+}
+
+class Bundle3 extends Bundle {
+ val a = FixedPoint(10.W, 4.BP)
+ val b = new Bundle {
+ val c = FixedPoint(10.W, 4.BP)
+ }
+}
+
+class DifferentBundleOneHotTester extends BasicTester {
+ val b0 = Wire(new Bundle2)
+ b0.a := -0.25.F(2.BP)
+ b0.b.c := -0.125.F(3.BP)
+
+ val b1 = Wire(new Bundle2)
+ b1.a := -0.0625.F(3.BP)
+ b1.b.c := -0.03125.F(4.BP)
+
+ val b2 = Wire(new Bundle3)
+ b2.a := -0.015625.F(5.BP)
+ b2.b.c := -0.0078125.F(6.BP)
+
+ val b3 = Wire(new Bundle3)
+ b3.a := -0.0078125.F(7.BP)
+ b3.b.c := -0.00390625.F(8.BP)
+
+ val o1 = Mux1H(Seq(
+ false.B -> b0,
+ false.B -> b1,
+ true.B -> b2,
+ false.B -> b3
+ ))
+
+ stop()
+}
+
+