From 97902cdc53eec52aa0cd806b8cb49a0e3f2fb769 Mon Sep 17 00:00:00 2001 From: Chick Markley Date: Wed, 12 Apr 2017 20:55:37 -0700 Subject: Fix one hot mux (#573) * still trying to find right mix * Making some progress on Mux1H * Mux1H that works in non-optimzed fashion for FixedPoint, works pretty well in general Catches some additional problem edge cases Some tests that illustrate most of this * Moved in Angie's code for handling FixedPoint case Cleaned up tests considerably, per @ducky64 review * Just a bit more cleanup --- .../src/main/scala/chisel3/core/Data.scala | 19 ++++-- .../src/main/scala/chisel3/core/SeqUtils.scala | 70 +++++++++++++++++++--- 2 files changed, 76 insertions(+), 13 deletions(-) (limited to 'chiselFrontend') diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala index c9cfe27b..556f2aeb 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala @@ -35,9 +35,10 @@ object DataMirror { * - For other types of the same class are are the same: clone of any of the elements * - Otherwise: fail */ +//scalastyle:off cyclomatic.complexity private[core] object cloneSupertype { def apply[T <: Data](elts: Seq[T], createdType: String)(implicit sourceInfo: SourceInfo, - compileOptions: CompileOptions): T = { + compileOptions: CompileOptions): T = { require(!elts.isEmpty, s"can't create $createdType with no inputs") if (elts forall {_.isInstanceOf[Bits]}) { @@ -45,7 +46,9 @@ private[core] object cloneSupertype { case (elt1: Bool, elt2: Bool) => elt1 case (elt1: Bool, elt2: UInt) => elt2 // TODO: what happens with zero width UInts? case (elt1: UInt, elt2: Bool) => elt1 // TODO: what happens with zero width UInts? - case (elt1: UInt, elt2: UInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 // TODO: perhaps redefine Widths to allow >= op? + case (elt1: UInt, elt2: UInt) => + // TODO: perhaps redefine Widths to allow >= op? + if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 case (elt1: SInt, elt2: SInt) => if (elt1.width == (elt1.width max elt2.width)) elt1 else elt2 case (elt1: FixedPoint, elt2: FixedPoint) => { (elt1.binaryPoint, elt2.binaryPoint, elt1.width, elt2.width) match { @@ -59,13 +62,17 @@ private[core] object cloneSupertype { } } case (elt1, elt2) => - throw new AssertionError(s"can't create $createdType with heterogeneous Bits types ${elt1.getClass} and ${elt2.getClass}") + throw new AssertionError( + s"can't create $createdType with heterogeneous Bits types ${elt1.getClass} and ${elt2.getClass}") }).asInstanceOf[T] } model.chiselCloneType - } else { + } + else { for (elt <- elts.tail) { - require(elt.getClass == elts.head.getClass, s"can't create $createdType with heterogeneous types ${elts.head.getClass} and ${elt.getClass}") - require(elt typeEquivalent elts.head, s"can't create $createdType with non-equivalent types ${elts.head} and ${elt}") + require(elt.getClass == elts.head.getClass, + s"can't create $createdType with heterogeneous types ${elts.head.getClass} and ${elt.getClass}") + require(elt typeEquivalent elts.head, + s"can't create $createdType with non-equivalent types ${elts.head} and ${elt}") } elts.head.chiselCloneType } diff --git a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala index c7b59d96..02382e57 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/SeqUtils.scala @@ -2,10 +2,13 @@ package chisel3.core -import scala.language.experimental.macros +import chisel3.internal.throwException +import scala.language.experimental.macros import chisel3.internal.sourceinfo._ +//scalastyle:off method.name + private[chisel3] object SeqUtils { /** Concatenates the data elements of the input sequence, in sequence order, together. * The first element of the sequence forms the least significant bits, while the last element @@ -39,7 +42,8 @@ private[chisel3] object SeqUtils { */ def priorityMux[T <: Data](in: Seq[(Bool, T)]): T = macro CompileOptionsTransform.inArg - def do_priorityMux[T <: Data](in: Seq[(Bool, T)])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { + def do_priorityMux[T <: Data](in: Seq[(Bool, T)]) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { if (in.size == 1) { in.head._2 } else { @@ -48,18 +52,70 @@ private[chisel3] object SeqUtils { } /** Returns the data value corresponding to the lone true predicate. + * This is elaborated to firrtl using a structure that should be optimized into and and/or tree. * * @note assumes exactly one true predicate, results undefined otherwise + * FixedPoint values or aggregates containing FixedPoint values cause this optimized structure to be lost */ def oneHotMux[T <: Data](in: Iterable[(Bool, T)]): T = macro CompileOptionsTransform.inArg - def do_oneHotMux[T <: Data](in: Iterable[(Bool, T)])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { + //scalastyle:off method.length cyclomatic.complexity + def do_oneHotMux[T <: Data](in: Iterable[(Bool, T)]) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): T = { if (in.tail.isEmpty) { in.head._2 - } else { - val masked = for ((s, i) <- in) yield Mux(s, i.asUInt, 0.U) - val output = cloneSupertype(in.toSeq map {_._2}, "oneHotMux") - output.fromBits(masked.reduceLeft(_|_)) + } + else { + val output = cloneSupertype(in.toSeq map { _._2}, "oneHotMux") + + def buildAndOrMultiplexor[TT <: Data](inputs: Iterable[(Bool, TT)]): T = { + val masked = for ((s, i) <- inputs) yield Mux(s, i.asUInt(), 0.U) + output.fromBits(masked.reduceLeft(_ | _)) + } + + output match { + case _: SInt => + // SInt's have to be managed carefully so sign extension works + + val sInts: Iterable[(Bool, SInt)] = in.collect { case (s: Bool, f: SInt) => + (s, f.asTypeOf(output).asInstanceOf[SInt]) + } + + val masked = for ((s, i) <- sInts) yield Mux(s, i, 0.S) + output.fromBits(masked.reduceLeft(_ | _)) + + case _: FixedPoint => + val (sels, possibleOuts) = in.toSeq.unzip + + val (intWidths, binaryPoints) = in.toSeq.map { case (_, o) => + val fo = o.asInstanceOf[FixedPoint] + require(fo.binaryPoint.known, "Mux1H requires width/binary points to be defined") + (fo.getWidth - fo.binaryPoint.get, fo.binaryPoint.get) + }.unzip + + if (intWidths.distinct.length == 1 && binaryPoints.distinct.length == 1) { + buildAndOrMultiplexor(in) + } + else { + val maxIntWidth = intWidths.max + val maxBP = binaryPoints.max + val inWidthMatched = Seq.fill(intWidths.length)(Wire(FixedPoint((maxIntWidth + maxBP).W, maxBP.BP))) + inWidthMatched.zipWithIndex foreach { case (e, idx) => e := possibleOuts(idx).asInstanceOf[FixedPoint] } + buildAndOrMultiplexor(sels.zip(inWidthMatched)) + } + + case _: Aggregate => + val allDefineWidth = in.forall { case (_, element) => element.widthOption.isDefined } + if(allDefineWidth) { + buildAndOrMultiplexor(in) + } + else { + throwException(s"Cannot Mux1H with aggregates with inferred widths") + } + + case _ => + buildAndOrMultiplexor(in) + } } } } -- cgit v1.2.3