diff options
| -rw-r--r-- | chiselFrontend/src/main/scala/chisel3/core/Data.scala | 19 | ||||
| -rw-r--r-- | chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala | 70 | ||||
| -rw-r--r-- | src/test/scala/chiselTests/OneHotMuxSpec.scala | 286 |
3 files changed, 362 insertions, 13 deletions
diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala index c9cfe27b..556f2aeb 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala @@ -35,9 +35,10 @@ object DataMirror { * - For other types of the same class are are the same: clone of any of the elements * - Otherwise: fail */ +//scalastyle:off cyclomatic.complexity private[core] object cloneSupertype { def apply[T <: Data](elts: Seq[T], createdType: String)(implicit sourceInfo: SourceInfo, - compileOptions: CompileOptions): T = { + compileOptions: CompileOptions): T = { require(!elts.isEmpty, s"can't create $createdType with no inputs") if (elts forall {_.isInstanceOf[Bits]}) { @@ -45,7 +46,9 @@ private[core] object cloneSupertype { case (elt1: Bool, elt2: Bool) => elt1 case (elt1: Bool, elt2: UInt) => elt2 // TODO: what happens with zero width UInts? case (elt1: UInt, elt2: Bool) => elt1 // TODO: what happens with zero width UInts? - case (elt1: UInt, elt2: UInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 // TODO: perhaps redefine Widths to allow >= op? + case (elt1: UInt, elt2: UInt) => + // TODO: perhaps redefine Widths to allow >= op? + if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 case (elt1: SInt, elt2: SInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 case (elt1: FixedPoint, elt2: FixedPoint) => { (elt1.binaryPoint, elt2.binaryPoint, elt1.width, elt2.width) match { @@ -59,13 +62,17 @@ private[core] object cloneSupertype { } } case (elt1, elt2) => - throw new AssertionError(s"can't create $createdType with heterogeneous Bits types ${elt1.getClass} and ${elt2.getClass}") + throw new AssertionError( + s"can't create $createdType with heterogeneous Bits types ${elt1.getClass} and ${elt2.getClass}") }).asInstanceOf[T] } model.chiselCloneType - } else { + } + else { for (elt <- elts.tail) { - require(elt.getClass == elts.head.getClass, s"can't create $createdType with heterogeneous types ${elts.head.getClass} and ${elt.getClass}") - require(elt typeEquivalent elts.head, s"can't create $createdType with non-equivalent types ${elts.head} and ${elt}") + require(elt.getClass == elts.head.getClass, + s"can't create $createdType with heterogeneous types ${elts.head.getClass} and ${elt.getClass}") + require(elt typeEquivalent elts.head, + s"can't create $createdType with non-equivalent types ${elts.head} and ${elt}") } elts.head.chiselCloneType } diff --git a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala index c7b59d96..02382e57 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala @@ -2,10 +2,13 @@ package chisel3.core -import scala.language.experimental.macros +import chisel3.internal.throwException +import scala.language.experimental.macros import chisel3.internal.sourceinfo._ +//scalastyle:off method.name + private[chisel3] object SeqUtils { /** Concatenates the data elements of the input sequence, in sequence order, together. * The first element of the sequence forms the least significant bits, while the last element @@ -39,7 +42,8 @@ private[chisel3] object SeqUtils { */ def priorityMux[T <: Data](in: Seq[(Bool, T)]): T = macro CompileOptionsTransform.inArg - def do_priorityMux[T <: Data](in: Seq[(Bool, T)])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { + def do_priorityMux[T <: Data](in: Seq[(Bool, T)]) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { if (in.size == 1) { in.head._2 } else { @@ -48,18 +52,70 @@ private[chisel3] object SeqUtils { } /** Returns the data value corresponding to the lone true predicate. + * This is elaborated to firrtl using a structure that should be optimized into and and/or tree. * * @note assumes exactly one true predicate, results undefined otherwise + * FixedPoint values or aggregates containing FixedPoint values cause this optimized structure to be lost */ def oneHotMux[T <: Data](in: Iterable[(Bool, T)]): T = macro CompileOptionsTransform.inArg - def do_oneHotMux[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { + //scalastyle:off method.length cyclomatic.complexity + def do_oneHotMux[T <: Data](in: Iterable[(Bool, T)]) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { if (in.tail.isEmpty) { in.head._2 - } else { - val masked = for ((s, i) <- in) yield Mux(s, i.asUInt, 0.U) - val output = cloneSupertype(in.toSeq map {_._2}, "oneHotMux") - output.fromBits(masked.reduceLeft(_|_)) + } + else { + val output = cloneSupertype(in.toSeq map { _._2}, "oneHotMux") + + def buildAndOrMultiplexor[TT <: Data](inputs: Iterable[(Bool, TT)]): T = { + val masked = for ((s, i) <- inputs) yield Mux(s, i.asUInt(), 0.U) + output.fromBits(masked.reduceLeft(_ | _)) + } + + output match { + case _: SInt => + // SInt's have to be managed carefully so sign extension works + + val sInts: Iterable[(Bool, SInt)] = in.collect { case (s: Bool, f: SInt) => + (s, f.asTypeOf(output).asInstanceOf[SInt]) + } + + val masked = for ((s, i) <- sInts) yield Mux(s, i, 0.S) + output.fromBits(masked.reduceLeft(_ | _)) + + case _: FixedPoint => + val (sels, possibleOuts) = in.toSeq.unzip + + val (intWidths, binaryPoints) = in.toSeq.map { case (_, o) => + val fo = o.asInstanceOf[FixedPoint] + require(fo.binaryPoint.known, "Mux1H requires width/binary points to be defined") + (fo.getWidth - fo.binaryPoint.get, fo.binaryPoint.get) + }.unzip + + if (intWidths.distinct.length == 1 && binaryPoints.distinct.length == 1) { + buildAndOrMultiplexor(in) + } + else { + val maxIntWidth = intWidths.max + val maxBP = binaryPoints.max + val inWidthMatched = Seq.fill(intWidths.length)(Wire(FixedPoint((maxIntWidth + maxBP).W, maxBP.BP))) + inWidthMatched.zipWithIndex foreach { case (e, idx) => e := possibleOuts(idx).asInstanceOf[FixedPoint] } + buildAndOrMultiplexor(sels.zip(inWidthMatched)) + } + + case _: Aggregate => + val allDefineWidth = in.forall { case (_, element) => element.widthOption.isDefined } + if(allDefineWidth) { + buildAndOrMultiplexor(in) + } + else { + throwException(s"Cannot Mux1H with aggregates with inferred widths") + } + + case _ => + buildAndOrMultiplexor(in) + } } } } diff --git a/src/test/scala/chiselTests/OneHotMuxSpec.scala b/src/test/scala/chiselTests/OneHotMuxSpec.scala new file mode 100644 index 00000000..c2efb6f8 --- /dev/null +++ b/src/test/scala/chiselTests/OneHotMuxSpec.scala @@ -0,0 +1,286 @@ +// See LICENSE for license details. + +package chiselTests + +import Chisel.testers.BasicTester +import chisel3._ +import chisel3.experimental.FixedPoint +import chisel3.util.Mux1H +import org.scalatest._ + +//scalastyle:off magic.number + +class OneHotMuxSpec extends FreeSpec with Matchers with ChiselRunners { + "simple one hot mux with uint should work" in { + assertTesterPasses(new SimpleOneHotTester) + } + "simple one hot mux with sint should work" in { + assertTesterPasses(new SIntOneHotTester) + } + "simple one hot mux with fixed point should work" in { + assertTesterPasses(new FixedPointOneHotTester) + } + "simple one hot mux with all same fixed point should work" in { + assertTesterPasses(new AllSameFixedPointOneHotTester) + } + "simple one hot mux with all same parameterized sint values should work" in { + val values: Seq[SInt] = Seq((-3).S, (-5).S, (-7).S, (-11).S) + assertTesterPasses(new ParameterizedOneHotTester(values, SInt(8.W), -5.S(8.W))) + } + "simple one hot mux with all same parameterized aggregates containing fixed values should work" in { + assertTesterPasses(new ParameterizedAggregateOneHotTester) + } + "simple one hot mux with all aggregates containing inferred width fixed values should NOT work" in { + intercept[ChiselException] { + assertTesterPasses(new InferredWidthAggregateOneHotTester) + } + } + "simple one hot mux with all fixed width bundles but with different bundles should Not work" in { + intercept[IllegalArgumentException] { + assertTesterPasses(new DifferentBundleOneHotTester) + } + } +} + +class SimpleOneHotTester extends BasicTester { + val out = Wire(UInt()) + out := Mux1H(Seq( + false.B -> 2.U, + false.B -> 4.U, + true.B -> 8.U, + false.B -> 11.U + )) + + assert(out === 8.U) + + stop() +} + +class SIntOneHotTester extends BasicTester { + val out = Wire(SInt()) + out := Mux1H(Seq( + false.B -> (-3).S, + true.B -> (-5).S, + false.B -> (-7).S, + false.B -> (-11).S + )) + + assert(out === (-5).S) + + stop() +} + +class FixedPointOneHotTester extends BasicTester { + val out = Wire(FixedPoint(8.W, 4.BP)) + + out := Mux1H(Seq( + false.B -> (-1.5).F(1.BP), + true.B -> (-2.25).F(2.BP), + false.B -> (-4.125).F(3.BP), + false.B -> (-11.625).F(3.BP) + )) + + assert(out === (-2.25).F(4.BP)) + + stop() +} + +class AllSameFixedPointOneHotTester extends BasicTester { + val out = Wire(FixedPoint(12.W, 3.BP)) + + out := Mux1H(Seq( + false.B -> (-1.5).F(12.W, 3.BP), + true.B -> (-2.25).F(12.W, 3.BP), + false.B -> (-4.125).F(12.W, 3.BP), + false.B -> (-11.625).F(12.W, 3.BP) + )) + + assert(out === (-2.25).F(14.W, 4.BP)) + + stop() +} + +class ParameterizedOneHotTester[T <: Data](values: Seq[T], outGen: T, expected: T) extends BasicTester { + val dut = Module(new ParameterizedOneHot(values, outGen)) + dut.io.selectors(0) := false.B + dut.io.selectors(1) := true.B + dut.io.selectors(2) := false.B + dut.io.selectors(3) := false.B + + assert(dut.io.out.asUInt() === expected.asUInt()) + + stop() +} + +class Agg1 extends Bundle { + val v = Vec(2, FixedPoint(8.W, 4.BP)) + val a = new Bundle { + val f1 = FixedPoint(7.W, 3.BP) + val f2 = FixedPoint(9.W, 5.BP) + } +} + +object Agg1 extends HasMakeLit[Agg1] { + def makeLit(n: Int): Agg1 = { + val x = n.toDouble / 4.0 + val (d: Double, e: Double, f: Double, g: Double) = (x, x * 2.0, x * 3.0, x * 4.0) + + val w = Wire(new Agg1) + w.v(0) := Wire(d.F(4.BP)) + w.v(1) := Wire(e.F(4.BP)) + w.a.f1 := Wire(f.F(3.BP)) + w.a.f2 := Wire(g.F(5.BP)) + w + } +} +class Agg2 extends Bundle { + val v = Vec(2, FixedPoint(8.W, 4.BP)) + val a = new Bundle { + val f1 = FixedPoint(7.W, 3.BP) + val f2 = FixedPoint(9.W, 5.BP) + } +} + +object Agg2 extends HasMakeLit[Agg2] { + def makeLit(n: Int): Agg2 = { + val x = n.toDouble / 4.0 + val (d: Double, e: Double, f: Double, g: Double) = (x, x * 2.0, x * 3.0, x * 4.0) + + val w = Wire(new Agg2) + w.v(0) := Wire(d.F(4.BP)) + w.v(1) := Wire(e.F(4.BP)) + w.a.f1 := Wire(f.F(3.BP)) + w.a.f2 := Wire(g.F(5.BP)) + w + } +} + +class ParameterizedAggregateOneHotTester extends BasicTester { + val values = (0 until 4).map { n => Agg1.makeLit(n) } + + val dut = Module(new ParameterizedAggregateOneHot(Agg1, new Agg1)) + dut.io.selectors(0) := false.B + dut.io.selectors(1) := true.B + dut.io.selectors(2) := false.B + dut.io.selectors(3) := false.B + + assert(dut.io.out.asUInt() === values(1).asUInt()) + + stop() +} + +trait HasMakeLit[T] { + def makeLit(n: Int): T +} + +class ParameterizedOneHot[T <: Data](values: Seq[T], outGen: T) extends Module { + val io = IO(new Bundle { + val selectors = Input(Vec(4, Bool())) + val out = Output(outGen) + }) + + val terms = io.selectors.zip(values) + io.out := Mux1H(terms) +} + +class ParameterizedAggregateOneHot[T <: Data](valGen: HasMakeLit[T], outGen: T) extends Module { + val io = IO(new Bundle { + val selectors = Input(Vec(4, Bool())) + val out = Output(outGen) + }) + + + val values = (0 until 4).map { n => valGen.makeLit(n) } + val terms = io.selectors.zip(values) + io.out := Mux1H(terms) +} + +class Bundle1 extends Bundle { + val a = FixedPoint() + val b = new Bundle { + val c = FixedPoint() + } +} + +class InferredWidthAggregateOneHotTester extends BasicTester { + val b0 = Wire(new Bundle1) + b0.a := -0.25.F(2.BP) + b0.b.c := -0.125.F(3.BP) + + val b1 = Wire(new Bundle1) + b1.a := -0.0625.F(3.BP) + b1.b.c := -0.03125.F(4.BP) + + val b2 = Wire(new Bundle1) + b2.a := -0.015625.F(5.BP) + b2.b.c := -0.0078125.F(6.BP) + + val b3 = Wire(new Bundle1) + b3.a := -0.0078125.F(7.BP) + b3.b.c := -0.00390625.F(8.BP) + + val o1 = Mux1H(Seq( + false.B -> b0, + false.B -> b1, + true.B -> b2, + false.B -> b3 + )) + + assert(o1.a === -0.015625.F(5.BP)) + assert(o1.b.c === -0.0078125.F(6.BP)) + + val o2 = Mux1H(Seq( + false.B -> b0, + true.B -> b1, + false.B -> b2, + false.B -> b3 + )) + + assert(o2.a === -0.0625.F(3.BP)) + assert(o2.b.c === -0.03125.F(4.BP)) + + stop() +} + +class Bundle2 extends Bundle { + val a = FixedPoint(10.W, 4.BP) + val b = new Bundle { + val c = FixedPoint(10.W, 4.BP) + } +} + +class Bundle3 extends Bundle { + val a = FixedPoint(10.W, 4.BP) + val b = new Bundle { + val c = FixedPoint(10.W, 4.BP) + } +} + +class DifferentBundleOneHotTester extends BasicTester { + val b0 = Wire(new Bundle2) + b0.a := -0.25.F(2.BP) + b0.b.c := -0.125.F(3.BP) + + val b1 = Wire(new Bundle2) + b1.a := -0.0625.F(3.BP) + b1.b.c := -0.03125.F(4.BP) + + val b2 = Wire(new Bundle3) + b2.a := -0.015625.F(5.BP) + b2.b.c := -0.0078125.F(6.BP) + + val b3 = Wire(new Bundle3) + b3.a := -0.0078125.F(7.BP) + b3.b.c := -0.00390625.F(8.BP) + + val o1 = Mux1H(Seq( + false.B -> b0, + false.B -> b1, + true.B -> b2, + false.B -> b3 + )) + + stop() +} + + |
