diff options
17 files changed, 2899 insertions, 329 deletions
diff --git a/chiselFrontend/src/main/scala/chisel3/Bits.scala b/chiselFrontend/src/main/scala/chisel3/Bits.scala index ef9a752b..28d1690d 100644 --- a/chiselFrontend/src/main/scala/chisel3/Bits.scala +++ b/chiselFrontend/src/main/scala/chisel3/Bits.scala @@ -4,13 +4,15 @@ package chisel3 import scala.language.experimental.macros -import chisel3.experimental.FixedPoint +import chisel3.experimental.{FixedPoint, Interval} import chisel3.internal._ import chisel3.internal.Builder.pushOp import chisel3.internal.firrtl._ import chisel3.internal.sourceinfo.{SourceInfo, SourceInfoTransform, SourceInfoWhiteboxTransform, UIntTransform} import chisel3.internal.firrtl.PrimOp._ +import _root_.firrtl.{ir => firrtlir} +import _root_.firrtl.{constraint => firrtlconstraint} // scalastyle:off method.name line.size.limit file.size.limit @@ -349,6 +351,18 @@ sealed abstract class Bits(private[chisel3] val width: Width) extends Element wi throwException(s"Cannot call .asFixedPoint on $this") } + /** Reinterpret cast as a Interval. + * + * @note value not guaranteed to be preserved: for example, an UInt of width + * 3 and value 7 (0b111) would become a FixedInt with value -1, the interpretation + * of the number is also affected by the specified binary point. Caution advised + */ + final def asInterval(that: IntervalRange): Interval = macro SourceInfoTransform.thatArg + + def do_asInterval(that: IntervalRange)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + throwException(s"Cannot call .asInterval on $this") + } + final def do_asBool(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = { width match { case KnownWidth(1) => this(0) @@ -670,6 +684,27 @@ sealed class UInt private[chisel3] (width: Width) extends Bits(width) with Num[U } } + override def do_asInterval(range: IntervalRange = IntervalRange.Unknown) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + (range.lower, range.upper, range.binaryPoint) match { + case (lx: firrtlconstraint.IsKnown, ux: firrtlconstraint.IsKnown, KnownBinaryPoint(bp)) => + // No mechanism to pass open/close to firrtl so need to handle directly + val l = lx match { + case firrtlir.Open(x) => x + BigDecimal(1) / BigDecimal(BigInt(1) << bp) + case firrtlir.Closed(x) => x + } + val u = ux match { + case firrtlir.Open(x) => x - BigDecimal(1) / BigDecimal(BigInt(1) << bp) + case firrtlir.Closed(x) => x + } + val minBI = (l * BigDecimal(BigInt(1) << bp)).setScale(0, BigDecimal.RoundingMode.FLOOR).toBigIntExact.get + val maxBI = (u * BigDecimal(BigInt(1) << bp)).setScale(0, BigDecimal.RoundingMode.FLOOR).toBigIntExact.get + pushOp(DefPrim(sourceInfo, Interval(range), AsIntervalOp, ref, ILit(minBI), ILit(maxBI), ILit(bp))) + case _ => + throwException( + s"cannot call $this.asInterval($range), you must specify a known binaryPoint and range") + } + } private[chisel3] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Unit = { this := that.asUInt @@ -903,13 +938,35 @@ sealed class SInt private[chisel3] (width: Width) extends Bits(width) with Num[S } } + override def do_asInterval(range: IntervalRange = IntervalRange.Unknown) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + (range.lower, range.upper, range.binaryPoint) match { + case (lx: firrtlconstraint.IsKnown, ux: firrtlconstraint.IsKnown, KnownBinaryPoint(bp)) => + // No mechanism to pass open/close to firrtl so need to handle directly + val l = lx match { + case firrtlir.Open(x) => x + BigDecimal(1) / BigDecimal(BigInt(1) << bp) + case firrtlir.Closed(x) => x + } + val u = ux match { + case firrtlir.Open(x) => x - BigDecimal(1) / BigDecimal(BigInt(1) << bp) + case firrtlir.Closed(x) => x + } + //TODO: (chick) Need to determine, what asInterval needs, and why it might need min and max as args -- CAN IT BE UNKNOWN? + // Angie's operation: Decimal -> Int -> Decimal loses information. Need to be conservative here? + val minBI = (l * BigDecimal(BigInt(1) << bp)).setScale(0, BigDecimal.RoundingMode.FLOOR).toBigIntExact.get + val maxBI = (u * BigDecimal(BigInt(1) << bp)).setScale(0, BigDecimal.RoundingMode.FLOOR).toBigIntExact.get + pushOp(DefPrim(sourceInfo, Interval(range), AsIntervalOp, ref, ILit(minBI), ILit(maxBI), ILit(bp))) + case _ => + throwException( + s"cannot call $this.asInterval($range), you must specify a known binaryPoint and range") + } + } + private[chisel3] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) { this := that.asSInt } } -object SInt extends SIntFactory - sealed trait Reset extends Element with ToBoolable { /** Casts this $coll to an [[AsyncReset]] */ final def asAsyncReset(): AsyncReset = macro SourceInfoWhiteboxTransform.noArg @@ -1118,9 +1175,10 @@ sealed class Bool() extends UInt(1.W) with Reset { pushOp(DefPrim(sourceInfo, AsyncReset(), AsAsyncResetOp, ref)) } -object Bool extends BoolFactory - package experimental { + + import chisel3.internal.firrtl.BinaryPoint + //scalastyle:off number.of.methods /** A sealed class representing a fixed point number that has a bit width and a binary point The width and binary point * may be inferred. @@ -1138,7 +1196,6 @@ package experimental { */ sealed class FixedPoint private(width: Width, val binaryPoint: BinaryPoint) extends Bits(width) with Num[FixedPoint] { - import FixedPoint.Implicits._ override def toString: String = { val bindingString = litToDoubleOption match { @@ -1356,7 +1413,6 @@ package experimental { def do_unary_~ (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint = throwException(s"Not is illegal on $this") - // TODO(chick): Consider comparison with UInt and SInt override def do_< (that: FixedPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that) override def do_> (that: FixedPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that) override def do_<= (that: FixedPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that) @@ -1419,6 +1475,32 @@ package experimental { } } + def do_asInterval(binaryPoint: BinaryPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + throwException(s"cannot call $this.asInterval(binaryPoint=$binaryPoint), you must specify a range") + } + + override def do_asInterval(range: IntervalRange = IntervalRange.Unknown) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + (range.lower, range.upper, range.binaryPoint) match { + case (lx: firrtlconstraint.IsKnown, ux: firrtlconstraint.IsKnown, KnownBinaryPoint(bp)) => + // No mechanism to pass open/close to firrtl so need to handle directly + val l = lx match { + case firrtlir.Open(x) => x + BigDecimal(1) / BigDecimal(BigInt(1) << bp) + case firrtlir.Closed(x) => x + } + val u = ux match { + case firrtlir.Open(x) => x - BigDecimal(1) / BigDecimal(BigInt(1) << bp) + case firrtlir.Closed(x) => x + } + val minBI = (l * BigDecimal(BigInt(1) << bp)).setScale(0, BigDecimal.RoundingMode.FLOOR).toBigIntExact.get + val maxBI = (u * BigDecimal(BigInt(1) << bp)).setScale(0, BigDecimal.RoundingMode.FLOOR).toBigIntExact.get + pushOp(DefPrim(sourceInfo, Interval(range), AsIntervalOp, ref, ILit(minBI), ILit(maxBI), ILit(bp))) + case _ => + throwException( + s"cannot call $this.asInterval($range), you must specify a known binaryPoint and range") + } + } + private[chisel3] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) { // TODO: redefine as just asFixedPoint on that, where FixedPoint.asFixedPoint just works. this := (that match { @@ -1426,7 +1508,6 @@ package experimental { case _ => that.asFixedPoint(this.binaryPoint) }) } - //TODO(chick): Consider "convert" as an arithmetic conversion to UInt/SInt } /** Use PrivateObject to force users to specify width and binaryPoint by name @@ -1441,6 +1522,7 @@ package experimental { object FixedPoint { import FixedPoint.Implicits._ + /** Create an FixedPoint type with inferred width. */ def apply(): FixedPoint = apply(Width(), BinaryPoint()) @@ -1510,6 +1592,7 @@ package experimental { result } + object Implicits { // implicit class fromDoubleToLiteral(val double: Double) extends AnyVal { @@ -1522,12 +1605,715 @@ package experimental { FixedPoint.fromDouble(double, width, binaryPoint) } } + } + } + + //scalastyle:off number.of.methods cyclomatic.complexity + /** + * A sealed class representing a fixed point number that has a range, an additional + * parameter that can determine a minimum and maximum supported value. + * The range can be used to reduce the required widths particularly in primitive + * operations with other Intervals, the canonical example being + * {{{ + * val one = 1.I + * val six = Seq.fill(6)(one).reduce(_ + _) + * }}} + * A UInt computed in this way would require a [[Width]] + * binary point + * The width and binary point may be inferred. + * + * IMPORTANT: The API provided here is experimental and may change in the future. + * + * @param range a range specifies min, max and binary point + */ + sealed class Interval private[chisel3] (val range: chisel3.internal.firrtl.IntervalRange) + extends Bits(range.getWidth) with Num[Interval] { + + override def toString: String = { + val bindingString = litOption match { + case Some(value) => s"($value)" + case _ => bindingToString + } + s"Interval$width$bindingString" + } + + private[chisel3] override def cloneTypeWidth(w: Width): this.type = + new Interval(range).asInstanceOf[this.type] + + //scalastyle:off cyclomatic.complexity + def toType: String = { + val zdec1 = """([+\-]?[0-9]\d*)(\.[0-9]*[1-9])(0*)""".r + val zdec2 = """([+\-]?[0-9]\d*)(\.0*)""".r + val dec = """([+\-]?[0-9]\d*)(\.[0-9]\d*)""".r + val int = """([+\-]?[0-9]\d*)""".r + def dec2string(v: BigDecimal): String = v.toString match { + case zdec1(x, y, z) => x + y + case zdec2(x, y) => x + case other => other + } + + val lowerString = range.lower match { + case firrtlir.Open(l) => s"(${dec2string(l)}, " + case firrtlir.Closed(l) => s"[${dec2string(l)}, " + case firrtlir.UnknownBound => s"[?, " + case _ => s"[?, " + } + val upperString = range.upper match { + case firrtlir.Open(u) => s"${dec2string(u)})" + case firrtlir.Closed(u) => s"${dec2string(u)}]" + case firrtlir.UnknownBound => s"?]" + case _ => s"?]" + } + val bounds = lowerString + upperString + + val pointString = range.binaryPoint match { + case KnownBinaryPoint(i) => "." + i.toString + case _ => "" + } + "Interval" + bounds + pointString + } + + private[chisel3] override def typeEquivalent(that: Data): Boolean = + that.isInstanceOf[Interval] && this.width == that.width + + def binaryPoint: BinaryPoint = range.binaryPoint + + override def connect(that: Data)(implicit sourceInfo: SourceInfo, connectCompileOptions: CompileOptions): Unit = { + that match { + case _: Interval|DontCare => super.connect(that) + case _ => this badConnect that + } + } + + final def unary_-(): Interval = macro SourceInfoTransform.noArg + final def unary_-%(): Interval = macro SourceInfoTransform.noArg + + def unary_-(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + Interval.Zero - this + } + def unary_-%(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + Interval.Zero -% this + } + + /** add (default - growing) operator */ + override def do_+(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + this +& that + /** subtract (default - growing) operator */ + override def do_-(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + this -& that + override def do_*(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + binop(sourceInfo, Interval(this.range * that.range), TimesOp, that) + + override def do_/(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + throwException(s"division is illegal on Interval types") + override def do_%(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + throwException(s"mod is illegal on Interval types") + + /** add (width +1) operator */ + final def +&(that: Interval): Interval = macro SourceInfoTransform.thatArg + /** add (no growth) operator */ + final def +%(that: Interval): Interval = macro SourceInfoTransform.thatArg + /** subtract (width +1) operator */ + final def -&(that: Interval): Interval = macro SourceInfoTransform.thatArg + /** subtract (no growth) operator */ + final def -%(that: Interval): Interval = macro SourceInfoTransform.thatArg + + def do_+&(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + binop(sourceInfo, Interval(this.range +& that.range), AddOp, that) + } + + def do_+%(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + throwException(s"Non-growing addition is not supported on Intervals: ${sourceInfo}") + } + + def do_-&(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + binop(sourceInfo, Interval(this.range -& that.range), SubOp, that) + } + + def do_-%(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + throwException(s"Non-growing subtraction is not supported on Intervals: ${sourceInfo}, try squeeze") + } + + final def &(that: Interval): Interval = macro SourceInfoTransform.thatArg + final def |(that: Interval): Interval = macro SourceInfoTransform.thatArg + final def ^(that: Interval): Interval = macro SourceInfoTransform.thatArg + + def do_&(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + throwException(s"And is illegal between $this and $that") + def do_|(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + throwException(s"Or is illegal between $this and $that") + def do_^(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + throwException(s"Xor is illegal between $this and $that") + + final def setPrecision(that: Int): Interval = macro SourceInfoTransform.thatArg + + // Precision change changes range -- see firrtl PrimOps (requires floor) + // aaa.bbb -> aaa.bb for sbp(2) + def do_setPrecision(that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + val newBinaryPoint = BinaryPoint(that) + val newIntervalRange = this.range.setPrecision(newBinaryPoint) + binop(sourceInfo, Interval(newIntervalRange), SetBinaryPoint, that) + } + + /** Increase the precision of this Interval, moves the binary point to the left. + * aaa.bbb -> aaa.bbb00 + * @param that how many bits to shift binary point + * @return + */ + final def increasePrecision(that: Int): Interval = macro SourceInfoTransform.thatArg + + def do_increasePrecision(that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + assert(that > 0, s"Must increase precision by an integer greater than zero.") + val newBinaryPoint = BinaryPoint(that) + val newIntervalRange = this.range.incPrecision(newBinaryPoint) + binop(sourceInfo, Interval(newIntervalRange), IncreasePrecision, that) + } + + /** Decrease the precision of this Interval, moves the binary point to the right. + * aaa.bbb -> aaa.b + * + * @param that number of bits to move binary point + * @return + */ + final def decreasePrecision(that: Int): Interval = macro SourceInfoTransform.thatArg + + def do_decreasePrecision(that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + assert(that > 0, s"Must decrease precision by an integer greater than zero.") + val newBinaryPoint = BinaryPoint(that) + val newIntervalRange = this.range.decPrecision(newBinaryPoint) + binop(sourceInfo, Interval(newIntervalRange), DecreasePrecision, that) + } + + /** Returns this wire bitwise-inverted. */ + def do_unary_~ (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + throwException(s"Not is illegal on $this") + + override def do_< (that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that) + override def do_> (that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that) + override def do_<= (that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that) + override def do_>= (that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that) + + final def != (that: Interval): Bool = macro SourceInfoTransform.thatArg + final def =/= (that: Interval): Bool = macro SourceInfoTransform.thatArg + final def === (that: Interval): Bool = macro SourceInfoTransform.thatArg + + def do_!= (that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that) + def do_=/= (that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that) + def do_=== (that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that) + + // final def abs(): UInt = macro SourceInfoTransform.noArg + + def do_abs(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + Mux(this < Interval.Zero, (Interval.Zero - this), this) + } + + override def do_<< (that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + binop(sourceInfo, Interval(this.range << that), ShiftLeftOp, that) + + override def do_<< (that: BigInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + do_<<(that.toInt) + + override def do_<< (that: UInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + binop(sourceInfo, Interval(this.range << that), DynamicShiftLeftOp, that) + } + + override def do_>> (that: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + binop(sourceInfo, Interval(this.range >> that), ShiftRightOp, that) + } + + override def do_>> (that: BigInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = + do_>>(that.toInt) + + override def do_>> (that: UInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + binop(sourceInfo, Interval(this.range >> that), DynamicShiftRightOp, that) + } + + /** + * Squeeze returns the intersection of the ranges this interval and that Interval + * Ignores binary point of argument + * Treat as an unsafe cast; gives undefined behavior if this signal's value is outside of the resulting range + * Adds no additional hardware; this strictly an unsafe type conversion to use at your own risk + * @param that + * @return + */ + final def squeeze(that: Interval): Interval = macro SourceInfoTransform.thatArg + def do_squeeze(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + val other = that + requireIsHardware(this, s"'this' ($this)") + requireIsHardware(other, s"'other' ($other)") + pushOp(DefPrim(sourceInfo, Interval(this.range.squeeze(that.range)), SqueezeOp, this.ref, other.ref)) + } + + /** + * Squeeze returns the intersection of the ranges this interval and that UInt + * Currently, that must have a defined width + * Treat as an unsafe cast; gives undefined behavior if this signal's value is outside of the resulting range + * Adds no additional hardware; this strictly an unsafe type conversion to use at your own risk + * @param that an UInt whose properties determine the squeezing + * @return + */ + final def squeeze(that: UInt): Interval = macro SourceInfoTransform.thatArg + def do_squeeze(that: UInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + that.widthOption match { + case Some(w) => + do_squeeze(Wire(Interval(IntervalRange(that.width, BinaryPoint(0))))) + case _ => + throwException(s"$this.squeeze($that) requires an UInt argument with a known width") + } + } - // implicit class fromIntToBinaryPoint(val int: Int) extends AnyVal { - implicit class fromIntToBinaryPoint(int: Int) { - def BP: BinaryPoint = BinaryPoint(int) // scalastyle:ignore method.name + /** + * Squeeze returns the intersection of the ranges this interval and that SInt + * Currently, that must have a defined width + * Treat as an unsafe cast; gives undefined behavior if this signal's value is outside of the resulting range + * Adds no additional hardware; this strictly an unsafe type conversion to use at your own risk + * @param that an SInt whose properties determine the squeezing + * @return + */ + final def squeeze(that: SInt): Interval = macro SourceInfoTransform.thatArg + def do_squeeze(that: SInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + that.widthOption match { + case Some(w) => + do_squeeze(Wire(Interval(IntervalRange(that.width, BinaryPoint(0))))) + case _ => + throwException(s"$this.squeeze($that) requires an SInt argument with a known width") } + } + + /** + * Squeeze returns the intersection of the ranges this interval and that IntervalRange + * Ignores binary point of argument + * Treat as an unsafe cast; gives undefined behavior if this signal's value is outside of the resulting range + * Adds no additional hardware; this strictly an unsafe type conversion to use at your own risk + * @param that an Interval whose properties determine the squeezing + * @return + */ + final def squeeze(that: IntervalRange): Interval = macro SourceInfoTransform.thatArg + def do_squeeze(that: IntervalRange)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + val intervalLitOpt = Interval.getSmallestLegalLit(that) + val intervalLit = intervalLitOpt.getOrElse( + throwException(s"$this.squeeze($that) requires an Interval range with known lower and upper bounds") + ) + do_squeeze(intervalLit) + } + + + /** + * Wrap the value of this [[Interval]] into the range of a different Interval with a presumably smaller range. + * Ignores binary point of argument + * Errors if requires wrapping more than once + * @param that + * @return + */ + final def wrap(that: Interval): Interval = macro SourceInfoTransform.thatArg + + def do_wrap(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + val other = that + requireIsHardware(this, s"'this' ($this)") + requireIsHardware(other, s"'other' ($other)") + pushOp(DefPrim(sourceInfo, Interval(this.range.wrap(that.range)), WrapOp, this.ref, other.ref)) + } + /** + * Wrap this interval into the range determined by that UInt + * Errors if requires wrapping more than once + * @param that an UInt whose properties determine the wrap + * @return + */ + final def wrap(that: UInt): Interval = macro SourceInfoTransform.thatArg + def do_wrap(that: UInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + that.widthOption match { + case Some(w) => + val u = BigDecimal(BigInt(1) << w) - 1 + do_wrap(0.U.asInterval(IntervalRange(firrtlir.Closed(0), firrtlir.Closed(u), BinaryPoint(0)))) + case _ => + throwException(s"$this.wrap($that) requires UInt with known width") + } + } + + /** + * Wrap this interval into the range determined by an SInt + * Errors if requires wrapping more than once + * @param that an SInt whose properties determine the bounds of the wrap + * @return + */ + final def wrap(that: SInt): Interval = macro SourceInfoTransform.thatArg + def do_wrap(that: SInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + that.widthOption match { + case Some(w) => + val l = -BigDecimal(BigInt(1) << (that.getWidth - 1)) + val u = BigDecimal(BigInt(1) << (that.getWidth - 1)) - 1 + do_wrap(Wire(Interval(IntervalRange(firrtlir.Closed(l), firrtlir.Closed(u), BinaryPoint(0))))) + case _ => + throwException(s"$this.wrap($that) requires SInt with known width") + } + } + + /** + * Wrap this interval into the range determined by an IntervalRange + * Adds hardware to change values outside of wrapped range to be at the boundary + * Errors if requires wrapping more than once + * Ignores binary point of argument + * @param that an Interval whose properties determine the bounds of the wrap + * @return + */ + final def wrap(that: IntervalRange): Interval = macro SourceInfoTransform.thatArg + def do_wrap(that: IntervalRange)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + (that.lowerBound, that.upperBound) match { + case (lower: firrtlconstraint.IsKnown, upperBound: firrtlconstraint.IsKnown) => + do_wrap(0.U.asInterval(IntervalRange(that.lowerBound, that.upperBound, BinaryPoint(0)))) + case _ => + throwException(s"$this.wrap($that) requires Interval argument with known lower and upper bounds") + } + } + + /** + * Clip this interval into the range determined by argument's range + * Adds hardware to change values outside of clipped range to be at the boundary + * Ignores binary point of argument + * @param that an Interval whose properties determine the clipping + * @return + */ + final def clip(that: Interval): Interval = macro SourceInfoTransform.thatArg + def do_clip(that: Interval)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + binop(sourceInfo, Interval(this.range.clip(that.range)), ClipOp, that) + } + + /** + * Clip this interval into the range determined by argument's range + * Adds hardware to change values outside of clipped range to be at the boundary + * @param that an UInt whose width determines the clipping + * @return + */ + final def clip(that: UInt): Interval = macro SourceInfoTransform.thatArg + def do_clip(that: UInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + require(that.widthKnown, "UInt clip width must be known") + val u = BigDecimal(BigInt(1) << that.getWidth) - 1 + do_clip(Wire(Interval(IntervalRange(firrtlir.Closed(0), firrtlir.Closed(u), BinaryPoint(0))))) + } + + /** + * Clip this interval into the range determined by argument's range + * Adds hardware to move values outside of clipped range to the boundary + * @param that an SInt whose width determines the clipping + * @return + */ + final def clip(that: SInt): Interval = macro SourceInfoTransform.thatArg + def do_clip(that: SInt)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + require(that.widthKnown, "SInt clip width must be known") + val l = -BigDecimal(BigInt(1) << (that.getWidth - 1)) + val u = BigDecimal(BigInt(1) << (that.getWidth - 1)) - 1 + do_clip(Wire(Interval(IntervalRange(firrtlir.Closed(l), firrtlir.Closed(u), BinaryPoint(0))))) + } + + /** + * Clip this interval into the range determined by argument's range + * Adds hardware to move values outside of clipped range to the boundary + * Ignores binary point of argument + * @param that an SInt whose width determines the clipping + * @return + */ + final def clip(that: IntervalRange): Interval = macro SourceInfoTransform.thatArg + def do_clip(that: IntervalRange)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + (that.lowerBound, that.upperBound) match { + case (lower: firrtlconstraint.IsKnown, upperBound: firrtlconstraint.IsKnown) => + do_clip(0.U.asInterval(IntervalRange(that.lowerBound, that.upperBound, BinaryPoint(0)))) + case _ => + throwException(s"$this.clip($that) requires Interval argument with known lower and upper bounds") + } + } + + override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = { + pushOp(DefPrim(sourceInfo, UInt(this.width), AsUIntOp, ref)) + } + override def do_asSInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): SInt = { + pushOp(DefPrim(sourceInfo, SInt(this.width), AsSIntOp, ref)) + } + + override def do_asFixedPoint(binaryPoint: BinaryPoint)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): FixedPoint = { + binaryPoint match { + case KnownBinaryPoint(value) => + val iLit = ILit(value) + pushOp(DefPrim(sourceInfo, FixedPoint(width, binaryPoint), AsFixedPointOp, ref, iLit)) + case _ => + throwException( + s"cannot call $this.asFixedPoint(binaryPoint=$binaryPoint), you must specify a known binaryPoint") + } + } + + // TODO: intervals chick INVALID -- not enough args + def do_asInterval(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Interval = { + pushOp(DefPrim(sourceInfo, Interval(this.range), AsIntervalOp, ref)) + throwException(s"($this).asInterval must specify arguments INVALID") + } + + // TODO:(chick) intervals chick looks like this is wrong and only for FP? + def do_fromBits(that: Bits)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { + /*val res = Wire(this, null).asInstanceOf[this.type] + res := (that match { + case fp: FixedPoint => fp.asSInt.asFixedPoint(this.binaryPoint) + case _ => that.asFixedPoint(this.binaryPoint) + }) + res*/ + throwException("fromBits INVALID for intervals") + } + + private[chisel3] override def connectFromBits(that: Bits) + (implicit sourceInfo: SourceInfo, compileOptions: CompileOptions) { + this := that.asInterval(this.range) + } + } + + /** Use PrivateObject to force users to specify width and binaryPoint by name + */ + + /** + * Factory and convenience methods for the Interval class + * IMPORTANT: The API provided here is experimental and may change in the future. + */ + object Interval { + /** Create an Interval type with inferred width and binary point. */ + def apply(): Interval = Interval(range"[?,?]") + + /** Create an Interval type with specified width. */ + def apply(binaryPoint: BinaryPoint): Interval = { + val binaryPointString = binaryPoint match { + case KnownBinaryPoint(value) => s"$value" + case _ => s"" + } + Interval(range"[?,?].$binaryPointString") + } + + /** Create an Interval type with specified width. */ + def apply(width: Width): Interval = Interval(width, 0.BP) + + /** Create an Interval type with specified width and binary point */ + def apply(width: Width, binaryPoint: BinaryPoint): Interval = { + Interval(IntervalRange(width, binaryPoint)) + } + + /** Create an Interval type with specified range. + * @param range defines the properties + */ + def apply(range: IntervalRange): Interval = { + new Interval(range) + } + + /** Creates a Interval connected to a Interval literal with the value zero */ + def Zero: Interval = Lit(0, 1.W, 0.BP) + + /** Creates an Interval zero that supports the given range + * Useful for creating a Interval register that has a desired number of bits + * {{{ + * val myRegister = RegInit(Interval.Zero(r"[0,12]") + * }}} + * @param range + * @return + */ + def Zero(range: IntervalRange): Interval = Lit(0, range) + + /** Make an interval from this BigInt, the BigInt is treated as bits + * So lower binaryPoint number of bits will treated as mantissa + * + * @param value + * @param width + * @param binaryPoint + * @return + */ + def fromBigInt(value: BigInt, width: Width = Width(), binaryPoint: BinaryPoint = 0.BP): Interval = { + Interval.Lit(value, Width(), binaryPoint) + } + + /** Create an Interval literal with inferred width from Double. + * Use PrivateObject to force users to specify width and binaryPoint by name + */ + def fromDouble(value: Double, dummy: PrivateType = PrivateObject, + width: Width, binaryPoint: BinaryPoint): Interval = { + fromBigInt( + toBigInt(value, binaryPoint), width = width, binaryPoint = binaryPoint + ) + } + + /** Create an Interval literal with inferred width from Double. + * Use PrivateObject to force users to specify width and binaryPoint by name + */ + def fromBigDecimal(value: Double, dummy: PrivateType = PrivateObject, + width: Width, binaryPoint: BinaryPoint): Interval = { + fromBigInt( + toBigInt(value, binaryPoint), width = width, binaryPoint = binaryPoint + ) + } + + protected[chisel3] def Lit(value: BigInt, width: Width, binaryPoint: BinaryPoint): Interval = { + width match { + case KnownWidth(w) => + if(value >= 0 && value.bitLength >= w || value < 0 && value.bitLength > w) { + throw new ChiselException( + s"Error literal interval value $value is too many bits for specified width $w" + ) + } + case _ => + } + val lit = IntervalLit(value, width, binaryPoint) + val bound = firrtlir.Closed(Interval.toDouble(value, binaryPoint.asInstanceOf[KnownBinaryPoint].value)) + val result = new Interval(IntervalRange(bound, bound, binaryPoint)) + lit.bindLitArg(result) + } + + protected[chisel3] def Lit(value: BigInt, range: IntervalRange): Interval = { + val lit = IntervalLit(value, range.getWidth, range.binaryPoint) + val bigDecimal = BigDecimal(value) + val inRange = (range.lowerBound, range.upperBound) match { + case (firrtlir.Closed(l), firrtlir.Closed(u)) => l <= bigDecimal && bigDecimal <= u + case (firrtlir.Closed(l), firrtlir.Open(u)) => l <= bigDecimal && bigDecimal <= u + case (firrtlir.Open(l), firrtlir.Closed(u)) => l <= bigDecimal && bigDecimal <= u + case (firrtlir.Open(l), firrtlir.Open(u)) => l <= bigDecimal && bigDecimal <= u + } + if(! inRange) { + throw new ChiselException( + s"Error literal interval value $value is not contained in specified range $range" + ) + } + val result = Interval(range) + lit.bindLitArg(result) + } + + /** How to create a BigInt from a double with a specific binaryPoint + * + * @param x a double value + * @param binaryPoint a binaryPoint that you would like to use + * @return + */ + def toBigInt(x: Double, binaryPoint: BinaryPoint): BigInt = { + val intBinaryPoint = binaryPoint match { + case KnownBinaryPoint(n) => n + case b => + throw new ChiselException(s"Error converting Double $x to BigInt, binary point must be known, not $b") + } + val multiplier = BigInt(1) << intBinaryPoint + val result = BigInt(math.round(x * multiplier.doubleValue)) + result + + } + + /** + * How to create a BigInt from a BigDecimal with a specific binaryPoint + * + * @param b a BigDecimal value + * @param binaryPoint a binaryPoint that you would like to use + * @return + */ + def toBigInt(b: BigDecimal, binaryPoint: BinaryPoint): BigInt = { + val bp = binaryPoint match { + case KnownBinaryPoint(n) => n + case x => + throw new ChiselException(s"Error converting BigDecimal $b to BigInt, binary point must be known, not $x") + } + (b * math.pow(2.0, bp.toDouble)).toBigInt + } + + /** + * converts a bigInt with the given binaryPoint into the double representation + * + * @param i a BigInt + * @param binaryPoint the implied binaryPoint of @i + * @return + */ + def toDouble(i: BigInt, binaryPoint: Int): Double = { + val multiplier = BigInt(1) << binaryPoint + val result = i.toDouble / multiplier.doubleValue + result + } + + /** + * This returns the smallest number that can legally fit in range, if possible + * If the lower bound or binary point is not known then return None + * + * @param range use to figure low number + * @return + */ + def getSmallestLegalLit(range: IntervalRange): Option[Interval] = { + val bp = range.binaryPoint + range.lowerBound match { + case firrtlir.Closed(lowerBound) => + Some(Interval.Lit(toBigInt(lowerBound.toDouble, bp), width = range.getWidth, bp)) + case firrtlir.Open(lowerBound) => + Some(Interval.Lit(toBigInt(lowerBound.toDouble, bp) + BigInt(1), width = range.getWidth, bp)) + case _ => + None + } + } + + /** + * This returns the largest number that can legally fit in range, if possible + * If the upper bound or binary point is not known then return None + * + * @param range use to figure low number + * @return + */ + def getLargestLegalLit(range: IntervalRange): Option[Interval] = { + val bp = range.binaryPoint + range.upperBound match { + case firrtlir.Closed(upperBound) => + Some(Interval.Lit(toBigInt(upperBound.toDouble, bp), width = range.getWidth, bp)) + case firrtlir.Open(upperBound) => + Some(Interval.Lit(toBigInt(upperBound.toDouble, bp) - BigInt(1), width = range.getWidth, bp)) + case _ => + None + } + } + + /** Contains the implicit classes used to provide the .I methods to create intervals + * from the standard numberic types. + * {{{ + * val x = 7.I + * val y = 7.5.I(4.BP) + * }}} + */ + object Implicits { + implicit class fromBigIntToLiteralInterval(bigInt: BigInt) { + def I: Interval = { + Interval.Lit(bigInt, width = Width(), 0.BP) + } + + def I(binaryPoint: BinaryPoint): Interval = { + Interval.Lit(bigInt, width = Width(), binaryPoint = binaryPoint) + } + + def I(width: Width, binaryPoint: BinaryPoint): Interval = { + Interval.Lit(bigInt, width, binaryPoint) + } + + def I(range: IntervalRange): Interval = { + Interval.Lit(bigInt, range) + } + } + + implicit class fromIntToLiteralInterval(int: Int) extends fromBigIntToLiteralInterval(int) + implicit class fromLongToLiteralInterval(long: Long) extends fromBigIntToLiteralInterval(long) + + implicit class fromBigDecimalToLiteralInterval(bigDecimal: BigDecimal) { + def I: Interval = { + Interval.Lit(Interval.toBigInt(bigDecimal, 0.BP), width = Width(), 0.BP) + } + + def I(binaryPoint: BinaryPoint): Interval = { + Interval.Lit(Interval.toBigInt(bigDecimal, binaryPoint), width = Width(), binaryPoint = binaryPoint) + } + + def I(width: Width, binaryPoint: BinaryPoint): Interval = { + Interval.Lit(Interval.toBigInt(bigDecimal, binaryPoint), width, binaryPoint) + } + + def I(range: IntervalRange): Interval = { + Interval.Lit(Interval.toBigInt(bigDecimal, range.binaryPoint), range) + } + } + + implicit class fromDoubleToLiteralInterval(double: Double) + extends fromBigDecimalToLiteralInterval(BigDecimal(double)) } } } + + diff --git a/chiselFrontend/src/main/scala/chisel3/Data.scala b/chiselFrontend/src/main/scala/chisel3/Data.scala index 1a931135..59348dcd 100644 --- a/chiselFrontend/src/main/scala/chisel3/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/Data.scala @@ -3,12 +3,11 @@ package chisel3 import scala.language.experimental.macros - -import chisel3.experimental.{Analog, DataMirror, FixedPoint} +import chisel3.experimental.{Analog, DataMirror, FixedPoint, Interval} import chisel3.internal.Builder.pushCommand import chisel3.internal._ import chisel3.internal.firrtl._ -import chisel3.internal.sourceinfo.{SourceInfo, SourceInfoTransform, UnlocatableSourceInfo, DeprecatedSourceInfo} +import chisel3.internal.sourceinfo.{DeprecatedSourceInfo, SourceInfo, SourceInfoTransform, UnlocatableSourceInfo} /** User-specified directions. */ @@ -195,6 +194,9 @@ private[chisel3] object cloneSupertype { case _ => FixedPoint() } } + case (elt1: Interval, elt2: Interval) => + val range = if(elt1.range.width == elt1.range.width.max(elt2.range.width)) elt1.range else elt2.range + Interval(range) case (elt1, elt2) => throw new AssertionError( s"can't create $createdType with heterogeneous types ${elt1.getClass} and ${elt2.getClass}") diff --git a/chiselFrontend/src/main/scala/chisel3/SIntFactory.scala b/chiselFrontend/src/main/scala/chisel3/SIntFactory.scala index 607e2e35..c1c6b1db 100644 --- a/chiselFrontend/src/main/scala/chisel3/SIntFactory.scala +++ b/chiselFrontend/src/main/scala/chisel3/SIntFactory.scala @@ -2,9 +2,7 @@ package chisel3 -import chisel3.internal.firrtl.{KnownSIntRange, NumericBound, Range, SLit, Width} - -// scalastyle:off method.name +import chisel3.internal.firrtl.{IntervalRange, SLit, Width} trait SIntFactory { /** Create an SInt type with inferred width. */ @@ -13,15 +11,12 @@ trait SIntFactory { def apply(width: Width): SInt = new SInt(width) /** Create a SInt with the specified range */ - def apply(range: Range): SInt = { + def apply(range: IntervalRange): SInt = { apply(range.getWidth) } - /** Create a SInt with the specified range */ - def apply(range: (NumericBound[Int], NumericBound[Int])): SInt = { - apply(KnownSIntRange(range._1, range._2)) - } - /** Create an SInt literal with specified width. */ + /** Create an SInt literal with specified width. */ + // scalastyle:off method.name protected[chisel3] def Lit(value: BigInt, width: Width): SInt = { val lit = SLit(value, width) val result = new SInt(lit.width) diff --git a/chiselFrontend/src/main/scala/chisel3/UIntFactory.scala b/chiselFrontend/src/main/scala/chisel3/UIntFactory.scala index a62aa493..3868962b 100644 --- a/chiselFrontend/src/main/scala/chisel3/UIntFactory.scala +++ b/chiselFrontend/src/main/scala/chisel3/UIntFactory.scala @@ -2,9 +2,10 @@ package chisel3 -import chisel3.internal.firrtl.{KnownUIntRange, NumericBound, Range, ULit, Width} - -// scalastyle:off method.name +import chisel3.internal.firrtl.{IntervalRange, KnownWidth, ULit, UnknownWidth, Width} +import firrtl.Utils +import firrtl.constraint.IsKnown +import firrtl.ir.{Closed, IntWidth, Open} // This is currently a factory because both Bits and UInt inherit it. trait UIntFactory { @@ -13,20 +14,34 @@ trait UIntFactory { /** Create a UInt port with specified width. */ def apply(width: Width): UInt = new UInt(width) - /** Create a UInt literal with specified width. */ + /** Create a UInt literal with specified width. */ + // scalastyle:off method.name protected[chisel3] def Lit(value: BigInt, width: Width): UInt = { val lit = ULit(value, width) val result = new UInt(lit.width) // Bind result to being an Literal lit.bindLitArg(result) } + /** Create a UInt with the specified range, validate that range is effectively > 0 + */ + //scalastyle:off cyclomatic.complexity + def apply(range: IntervalRange): UInt = { + // Check is only done against lower bound because range will already insist that range high >= low + range.lowerBound match { + case Closed(bound) if bound < 0 => + throw new ChiselException(s"Attempt to create UInt with closed lower bound of $bound, must be > 0") + case Open(bound) if bound < -1 => + throw new ChiselException(s"Attempt to create UInt with open lower bound of $bound, must be > -1") + case _ => + } - /** Create a UInt with the specified range */ - def apply(range: Range): UInt = { - apply(range.getWidth) - } - /** Create a UInt with the specified range */ - def apply(range: (NumericBound[Int], NumericBound[Int])): UInt = { - apply(KnownUIntRange(range._1, range._2)) + // because this is a UInt we don't have to take into account the lower bound + val newWidth = if(range.upperBound.isInstanceOf[IsKnown]) { + KnownWidth(Utils.getUIntWidth(range.maxAdjusted.get).max(1)) // max(1) handles range"[0,0]" + } else { + UnknownWidth() + } + + apply(newWidth) } } diff --git a/chiselFrontend/src/main/scala/chisel3/core/package.scala b/chiselFrontend/src/main/scala/chisel3/core/package.scala index 2c60ce85..92c4617b 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/package.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/package.scala @@ -231,8 +231,7 @@ package object core { @deprecated("Use the version in chisel3.experimental._", "3.2") implicit class fromDoubleToLiteral(double: Double) extends experimental.FixedPoint.Implicits.fromDoubleToLiteral(double) @deprecated("Use the version in chisel3.experimental._", "3.2") - implicit class fromIntToBinaryPoint(int: Int) extends experimental.FixedPoint.Implicits.fromIntToBinaryPoint(int) - + implicit class fromIntToBinaryPoint(int: Int) extends chisel3.fromIntToBinaryPoint(int) @deprecated("Use the version in chisel3.experimental._", "3.2") type RunFirrtlTransform = chisel3.experimental.RunFirrtlTransform diff --git a/chiselFrontend/src/main/scala/chisel3/experimental/package.scala b/chiselFrontend/src/main/scala/chisel3/experimental/package.scala index 2ce3a1c6..7ade2cb3 100644 --- a/chiselFrontend/src/main/scala/chisel3/experimental/package.scala +++ b/chiselFrontend/src/main/scala/chisel3/experimental/package.scala @@ -55,7 +55,6 @@ package object experimental { // scalastyle:ignore object.name val Direction = ActualDirection implicit class ChiselRange(val sc: StringContext) extends AnyVal { - import chisel3.internal.firrtl.NumericBound import scala.language.experimental.macros @@ -67,7 +66,7 @@ package object experimental { // scalastyle:ignore object.name * UInt(range"[0, \${myInt + 2})") * }}} */ - def range(args: Any*): (NumericBound[Int], NumericBound[Int]) = macro chisel3.internal.RangeTransform.apply + def range(args: Any*): chisel3.internal.firrtl.IntervalRange = macro chisel3.internal.RangeTransform.apply } class dump extends chisel3.internal.naming.dump // scalastyle:ignore class.name @@ -76,6 +75,7 @@ package object experimental { // scalastyle:ignore object.name object BundleLiterals { implicit class AddBundleLiteralConstructor[T <: Bundle](x: T) { + //scalastyle:off method.name def Lit(elems: (T => (Data, Data))*): T = { x._makeLit(elems: _*) } diff --git a/chiselFrontend/src/main/scala/chisel3/internal/MonoConnect.scala b/chiselFrontend/src/main/scala/chisel3/internal/MonoConnect.scala index 1c001183..41402021 100644 --- a/chiselFrontend/src/main/scala/chisel3/internal/MonoConnect.scala +++ b/chiselFrontend/src/main/scala/chisel3/internal/MonoConnect.scala @@ -3,7 +3,7 @@ package chisel3.internal import chisel3._ -import chisel3.experimental.{Analog, BaseModule, EnumType, FixedPoint, UnsafeEnum} +import chisel3.experimental.{Analog, BaseModule, EnumType, FixedPoint, Interval, UnsafeEnum} import chisel3.internal.Builder.pushCommand import chisel3.internal.firrtl.{Connect, DefInvalid} import scala.language.experimental.macros @@ -85,6 +85,8 @@ private[chisel3] object MonoConnect { elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: FixedPoint, source_e: FixedPoint) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: Interval, source_e: Interval) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: Clock, source_e: Clock) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: AsyncReset, source_e: AsyncReset) => diff --git a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/Converter.scala b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/Converter.scala index 5309609b..548ed294 100644 --- a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/Converter.scala +++ b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/Converter.scala @@ -71,6 +71,11 @@ private[chisel3] object Converter { val uint = convert(ULit(unsigned, fplit.width), ctx) val lit = bp.asInstanceOf[KnownBinaryPoint].value fir.DoPrim(firrtl.PrimOps.AsFixedPoint, Seq(uint), Seq(lit), fir.UnknownType) + case intervalLit @ IntervalLit(n, w, bp) => + val unsigned = if (n < 0) (BigInt(1) << intervalLit.width.get) + n else n + val uint = convert(ULit(unsigned, intervalLit.width), ctx) + val lit = bp.asInstanceOf[KnownBinaryPoint].value + fir.DoPrim(firrtl.PrimOps.AsInterval, Seq(uint), Seq(n, n, lit), fir.UnknownType) case lit: ILit => throwException(s"Internal Error! Unexpected ILit: $lit") } @@ -220,6 +225,7 @@ private[chisel3] object Converter { case d: UInt => fir.UIntType(convert(d.width)) case d: SInt => fir.SIntType(convert(d.width)) case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint)) + case d: Interval => fir.IntervalType(d.range.lowerBound, d.range.upperBound, d.range.firrtlBinaryPoint) case d: Analog => fir.AnalogType(convert(d.width)) case d: Vec[_] => fir.VectorType(extractType(d.sample_element, clearDir), d.length) case d: Record => diff --git a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala index 4643f66c..bc662ddb 100644 --- a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala +++ b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala @@ -5,11 +5,15 @@ package chisel3.internal.firrtl import chisel3._ import chisel3.internal._ import chisel3.internal.sourceinfo.SourceInfo -import chisel3.experimental.{BaseModule, ChiselAnnotation, Param} +import chisel3.experimental._ +import _root_.firrtl.{ir => firrtlir} +import _root_.firrtl.PrimOps + +import scala.math.BigDecimal.RoundingMode // scalastyle:off number.of.types -case class PrimOp(val name: String) { +case class PrimOp(name: String) { override def toString: String = name } @@ -45,7 +49,13 @@ object PrimOp { val AsUIntOp = PrimOp("asUInt") val AsSIntOp = PrimOp("asSInt") val AsFixedPointOp = PrimOp("asFixedPoint") - val SetBinaryPoint = PrimOp("bpset") + val AsIntervalOp = PrimOp("asInterval") + val WrapOp = PrimOp("wrap") + val SqueezeOp = PrimOp("squz") + val ClipOp = PrimOp("clip") + val SetBinaryPoint = PrimOp("setp") + val IncreasePrecision = PrimOp("incp") + val DecreasePrecision = PrimOp("decp") val AsClockOp = PrimOp("asClock") val AsAsyncResetOp = PrimOp("asAsyncReset") } @@ -111,6 +121,18 @@ case class FPLit(n: BigInt, w: Width, binaryPoint: BinaryPoint) extends LitArg(n def minWidth: Int = 1 + n.bitLength } +case class IntervalLit(n: BigInt, w: Width, binaryPoint: BinaryPoint) extends LitArg(n, w) { + def name: String = { + val unsigned = if (n < 0) (BigInt(1) << width.get) + n else n + s"asInterval(${ULit(unsigned, width).name}, ${n}, ${n}, ${binaryPoint.asInstanceOf[KnownBinaryPoint].value})" + } + val range: IntervalRange = { + new IntervalRange(IntervalRange.getBound(isClosed = true, BigDecimal(n)), + IntervalRange.getBound(isClosed = true, BigDecimal(n)), IntervalRange.getRangeWidth(binaryPoint)) + } + def minWidth: Int = 1 + n.bitLength +} + case class Ref(name: String) extends Arg case class ModuleIO(mod: BaseModule, name: String) extends Arg { override def fullName(ctx: Component): String = @@ -125,54 +147,6 @@ case class Index(imm: Arg, value: Arg) extends Arg { override def fullName(ctx: Component): String = s"${imm.fullName(ctx)}[${value.fullName(ctx)}]" } -sealed trait Bound -sealed trait NumericBound[T] extends Bound { - val value: T -} -sealed case class Open[T](value: T) extends NumericBound[T] -sealed case class Closed[T](value: T) extends NumericBound[T] - -sealed trait Range { - val min: Bound - val max: Bound - def getWidth: Width -} - -sealed trait KnownIntRange extends Range { - val min: NumericBound[Int] - val max: NumericBound[Int] - - require( (min, max) match { - case (Open(low_val), Open(high_val)) => low_val < high_val - 1 - case (Closed(low_val), Open(high_val)) => low_val < high_val - case (Open(low_val), Closed(high_val)) => low_val < high_val - case (Closed(low_val), Closed(high_val)) => low_val <= high_val - }) -} - -sealed case class KnownUIntRange(min: NumericBound[Int], max: NumericBound[Int]) extends KnownIntRange { - require (min.value >= 0) - - def getWidth: Width = max match { - case Open(v) => Width(BigInt(v - 1).bitLength.max(1)) - case Closed(v) => Width(BigInt(v).bitLength.max(1)) - } -} - -sealed case class KnownSIntRange(min: NumericBound[Int], max: NumericBound[Int]) extends KnownIntRange { - - val maxWidth = max match { - case Open(v) => Width(BigInt(v - 1).bitLength + 1) - case Closed(v) => Width(BigInt(v).bitLength + 1) - } - val minWidth = min match { - case Open(v) => Width(BigInt(v + 1).bitLength + 1) - case Closed(v) => Width(BigInt(v).bitLength + 1) - } - def getWidth: Width = maxWidth.max(minWidth) - -} - object Width { def apply(x: Int): Width = KnownWidth(x) def apply(): Width = UnknownWidth() @@ -257,6 +231,481 @@ object MemPortDirection { object INFER extends MemPortDirection("infer") } +sealed trait RangeType { + def getWidth: Width + + def * (that: IntervalRange): IntervalRange + def +& (that: IntervalRange): IntervalRange + def -& (that: IntervalRange): IntervalRange + def << (that: Int): IntervalRange + def >> (that: Int): IntervalRange + def << (that: KnownWidth): IntervalRange + def >> (that: KnownWidth): IntervalRange + def merge(that: IntervalRange): IntervalRange +} + +object IntervalRange { + /** Creates an IntervalRange, this is used primarily by the range interpolator macro + * @param lower lower bound + * @param upper upper bound + * @param firrtlBinaryPoint binary point firrtl style + * @return + */ + def apply(lower: firrtlir.Bound, upper: firrtlir.Bound, firrtlBinaryPoint: firrtlir.Width): IntervalRange = { + new IntervalRange(lower, upper, firrtlBinaryPoint) + } + + def apply(lower: firrtlir.Bound, upper: firrtlir.Bound, binaryPoint: BinaryPoint): IntervalRange = { + new IntervalRange(lower, upper, IntervalRange.getBinaryPoint(binaryPoint)) + } + + def apply(lower: firrtlir.Bound, upper: firrtlir.Bound, binaryPoint: Int): IntervalRange = { + IntervalRange(lower, upper, BinaryPoint(binaryPoint)) + } + + /** Returns an IntervalRange appropriate for a signed value of the given width + * @param binaryPoint number of bits of mantissa + * @return + */ + def apply(binaryPoint: BinaryPoint): IntervalRange = { + IntervalRange(firrtlir.UnknownBound, firrtlir.UnknownBound, binaryPoint) + } + + /** Returns an IntervalRange appropriate for a signed value of the given width + * @param width number of bits to have in the interval + * @param binaryPoint number of bits of mantissa + * @return + */ + def apply(width: Width, binaryPoint: BinaryPoint = 0.BP): IntervalRange = { + val range = width match { + case KnownWidth(w) => + val nearestPowerOf2 = BigInt("1" + ("0" * (w - 1)), 2) + IntervalRange( + firrtlir.Closed(BigDecimal(-nearestPowerOf2)), firrtlir.Closed(BigDecimal(nearestPowerOf2 - 1)), binaryPoint + ) + case _ => + IntervalRange(firrtlir.UnknownBound, firrtlir.UnknownBound, binaryPoint) + } + range + } + + def unapply(arg: IntervalRange): Option[(firrtlir.Bound, firrtlir.Bound, BinaryPoint)] = { + return Some((arg.lower, arg.upper, arg.binaryPoint)) + } + + def getBound(isClosed: Boolean, value: String): firrtlir.Bound = { + if(value == "?") { + firrtlir.UnknownBound + } + else if(isClosed) { + firrtlir.Closed(BigDecimal(value)) + } + else { + firrtlir.Open(BigDecimal(value)) + } + } + + def getBound(isClosed: Boolean, value: BigDecimal): firrtlir.Bound = { + if(isClosed) { + firrtlir.Closed(value) + } + else { + firrtlir.Open(value) + } + } + + def getBound(isClosed: Boolean, value: Int): firrtlir.Bound = { + getBound(isClosed, (BigDecimal(value))) + } + + def getBinaryPoint(s: String): firrtlir.Width = { + firrtlir.UnknownWidth + } + + def getBinaryPoint(n: Int): firrtlir.Width = { + if(n < 0) { + firrtlir.UnknownWidth + } + else { + firrtlir.IntWidth(n) + } + } + def getBinaryPoint(n: BinaryPoint): firrtlir.Width = { + n match { + case UnknownBinaryPoint => firrtlir.UnknownWidth + case KnownBinaryPoint(w) => firrtlir.IntWidth(w) + } + } + + def getRangeWidth(w: Width): firrtlir.Width = { + if(w.known) { + firrtlir.IntWidth(w.get) + } + else { + firrtlir.UnknownWidth + } + } + def getRangeWidth(binaryPoint: BinaryPoint): firrtlir.Width = { + if(binaryPoint.known) { + firrtlir.IntWidth(binaryPoint.get) + } + else { + firrtlir.UnknownWidth + } + } + + //scalastyle:off method.name + def Unknown: IntervalRange = range"[?,?].?" +} + + +sealed class IntervalRange( + val lowerBound: firrtlir.Bound, + val upperBound: firrtlir.Bound, + private[chisel3] val firrtlBinaryPoint: firrtlir.Width) + extends firrtlir.IntervalType(lowerBound, upperBound, firrtlBinaryPoint) + with RangeType { + + (lowerBound, upperBound) match { + case (firrtlir.Open(begin), firrtlir.Open(end)) => + if(begin >= end) throw new ChiselException(s"Invalid range with ${serialize}") + binaryPoint match { + case KnownBinaryPoint(bp) => + if(begin >= end - (BigDecimal(1) / BigDecimal(BigInt(1) << bp))) { + throw new ChiselException(s"Invalid range with ${serialize}") + } + case _ => + } + case (firrtlir.Open(begin), firrtlir.Closed(end)) => + if(begin >= end) throw new ChiselException(s"Invalid range with ${serialize}") + case (firrtlir.Closed(begin), firrtlir.Open(end)) => + if(begin >= end) throw new ChiselException(s"Invalid range with ${serialize}") + case (firrtlir.Closed(begin), firrtlir.Closed(end)) => + if(begin > end) throw new ChiselException(s"Invalid range with ${serialize}") + case _ => + } + + //scalastyle:off cyclomatic.complexity + override def toString: String = { + val binaryPoint = firrtlBinaryPoint match { + case firrtlir.IntWidth(n) => s"$n" + case _ => "?" + } + val lowerBoundString = lowerBound match { + case firrtlir.Closed(l) => s"[$l" + case firrtlir.Open(l) => s"($l" + case firrtlir.UnknownBound => s"[?" + } + val upperBoundString = upperBound match { + case firrtlir.Closed(l) => s"$l]" + case firrtlir.Open(l) => s"$l)" + case firrtlir.UnknownBound => s"?]" + } + s"""range"$lowerBoundString,$upperBoundString.$binaryPoint"""" + } + + val increment: Option[BigDecimal] = firrtlBinaryPoint match { + case firrtlir.IntWidth(bp) => + Some(BigDecimal(math.pow(2, -bp.doubleValue))) + case _ => None + } + + /** If possible returns the lowest possible value for this Interval + * @return + */ + val getLowestPossibleValue: Option[BigDecimal] = { + increment match { + case Some(inc) => + lower match { + case firrtlir.Closed(n) => Some(n) + case firrtlir.Open(n) => Some(n + inc) + case _ => None + } + case _ => + None + } + } + + /** If possible returns the highest possible value for this Interval + * @return + */ + val getHighestPossibleValue: Option[BigDecimal] = { + increment match { + case Some(inc) => + lower match { + case firrtlir.Closed(n) => Some(n) + case firrtlir.Open(n) => Some(n - inc) + case _ => None + } + case _ => + None + } + } + + /** Return a Seq of the possible values for this range + * Mostly to be used for testing + * @return + */ + def getPossibleValues: Seq[BigDecimal] = { + (getLowestPossibleValue, getHighestPossibleValue, increment) match { + case (Some(low), Some(high), Some(inc)) => (low to high by inc) + case (_, _, None) => + throw new ChiselException(s"BinaryPoint unknown. Cannot get possible values from IntervalRange $toString") + case _ => + throw new ChiselException(s"Unknown Bound. Cannot get possible values from IntervalRange $toString") + + } + } + + override def getWidth: Width = { + width match { + case firrtlir.IntWidth(n) => KnownWidth(n.toInt) + case firrtlir.UnknownWidth => UnknownWidth() + } + } + + private def doFirrtlOp(op: firrtlir.PrimOp, that: IntervalRange): IntervalRange = { + PrimOps.set_primop_type( + firrtlir.DoPrim(op, + Seq(firrtlir.Reference("a", this), firrtlir.Reference("b", that)), Nil,firrtlir.UnknownType) + ).tpe match { + case i: firrtlir.IntervalType => IntervalRange(i.lower, i.upper, i.point) + case other => sys.error("BAD!") + } + } + + private def doFirrtlDynamicShift(that: UInt, isLeft: Boolean): IntervalRange = { + val uinttpe = that.widthOption match { + case None => firrtlir.UIntType(firrtlir.UnknownWidth) + case Some(w) => firrtlir.UIntType(firrtlir.IntWidth(w)) + } + val op = if(isLeft) PrimOps.Dshl else PrimOps.Dshr + PrimOps.set_primop_type( + firrtlir.DoPrim(op, + Seq(firrtlir.Reference("a", this), firrtlir.Reference("b", uinttpe)), Nil,firrtlir.UnknownType) + ).tpe match { + case i: firrtlir.IntervalType => IntervalRange(i.lower, i.upper, i.point) + case other => sys.error("BAD!") + } + } + + private def doFirrtlOp(op: firrtlir.PrimOp, that: Int): IntervalRange = { + PrimOps.set_primop_type( + firrtlir.DoPrim(op, + Seq(firrtlir.Reference("a", this)), Seq(BigInt(that)), firrtlir.UnknownType) + ).tpe match { + case i: firrtlir.IntervalType => IntervalRange(i.lower, i.upper, i.point) + case other => sys.error("BAD!") + } + } + + /** Multiply this by that, here we return a fully unknown range, + * firrtl's range inference can figure this out + * @param that + * @return + */ + override def *(that: IntervalRange): IntervalRange = { + doFirrtlOp(PrimOps.Mul, that) + } + + /** Add that to this, here we return a fully unknown range, + * firrtl's range inference can figure this out + * @param that + * @return + */ + override def +&(that: IntervalRange): IntervalRange = { + doFirrtlOp(PrimOps.Add, that) + } + + /** Subtract that from this, here we return a fully unknown range, + * firrtl's range inference can figure this out + * @param that + * @return + */ + override def -&(that: IntervalRange): IntervalRange = { + doFirrtlOp(PrimOps.Sub, that) + } + + private def adjustBoundValue(value: BigDecimal, binaryPointValue: Int): BigDecimal = { + if(binaryPointValue >= 0) { + val maskFactor = BigDecimal(1 << binaryPointValue) + val a = (value * maskFactor) + val b = a.setScale(0, RoundingMode.DOWN) + val c = b / maskFactor + c + } else { + value + } + } + + private def adjustBound(bound: firrtlir.Bound, binaryPoint: BinaryPoint): firrtlir.Bound = { + binaryPoint match { + case KnownBinaryPoint(binaryPointValue) => + bound match { + case firrtlir.Open(value) => firrtlir.Open(adjustBoundValue(value, binaryPointValue)) + case firrtlir.Closed(value) => firrtlir.Closed(adjustBoundValue(value, binaryPointValue)) + case _ => bound + } + case _ => firrtlir.UnknownBound + } + } + + /** Creates a new range with the increased precision + * + * @param newBinaryPoint + * @return + */ + def incPrecision(newBinaryPoint: BinaryPoint): IntervalRange = { + newBinaryPoint match { + case KnownBinaryPoint(that) => + doFirrtlOp(PrimOps.IncP, that) + case _ => + throwException(s"$this.incPrecision(newBinaryPoint = $newBinaryPoint) error, newBinaryPoint must be know") + } + } + + /** Creates a new range with the decreased precision + * + * @param newBinaryPoint + * @return + */ + def decPrecision(newBinaryPoint: BinaryPoint): IntervalRange = { + newBinaryPoint match { + case KnownBinaryPoint(that) => + doFirrtlOp(PrimOps.DecP, that) + case _ => + throwException(s"$this.decPrecision(newBinaryPoint = $newBinaryPoint) error, newBinaryPoint must be know") + } + } + + /** Creates a new range with the given binary point, adjusting precision + * on bounds as necessary + * + * @param newBinaryPoint + * @return + */ + def setPrecision(newBinaryPoint: BinaryPoint): IntervalRange = { + newBinaryPoint match { + case KnownBinaryPoint(that) => + doFirrtlOp(PrimOps.SetP, that) + case _ => + throwException(s"$this.setPrecision(newBinaryPoint = $newBinaryPoint) error, newBinaryPoint must be know") + } + } + + /** Shift this range left, i.e. shifts the min and max by the specified amount + * @param that + * @return + */ + override def <<(that: Int): IntervalRange = { + doFirrtlOp(PrimOps.Shl, that) + } + + /** Shift this range left, i.e. shifts the min and max by the known width + * @param that + * @return + */ + override def <<(that: KnownWidth): IntervalRange = { + <<(that.value) + } + + /** Shift this range left, i.e. shifts the min and max by value + * @param that + * @return + */ + def <<(that: UInt): IntervalRange = { + doFirrtlDynamicShift(that, isLeft = true) + } + + /** Shift this range right, i.e. shifts the min and max by the specified amount + * @param that + * @return + */ + override def >>(that: Int): IntervalRange = { + doFirrtlOp(PrimOps.Shr, that) + } + + /** Shift this range right, i.e. shifts the min and max by the known width + * @param that + * @return + */ + override def >>(that: KnownWidth): IntervalRange = { + >>(that.value) + } + + /** Shift this range right, i.e. shifts the min and max by value + * @param that + * @return + */ + def >>(that: UInt): IntervalRange = { + doFirrtlDynamicShift(that, isLeft = false) + } + + /** + * Squeeze returns the intersection of the ranges this interval and that Interval + * @param that + * @return + */ + def squeeze(that: IntervalRange): IntervalRange = { + doFirrtlOp(PrimOps.Squeeze, that) + } + + /** + * Wrap the value of this [[Interval]] into the range of a different Interval with a presumably smaller range. + * @param that + * @return + */ + def wrap(that: IntervalRange): IntervalRange = { + doFirrtlOp(PrimOps.Wrap, that) + } + + /** + * Clip the value of this [[Interval]] into the range of a different Interval with a presumably smaller range. + * @param that + * @return + */ + def clip(that: IntervalRange): IntervalRange = { + doFirrtlOp(PrimOps.Clip, that) + } + + /** merges the ranges of this and that, basically takes lowest low, highest high and biggest bp + * set unknown if any of this or that's value of above is unknown + * Like an union but will slurp up points in between the two ranges that were part of neither + * @param that + * @return + */ + override def merge(that: IntervalRange): IntervalRange = { + val lowest = (this.getLowestPossibleValue, that.getLowestPossibleValue) match { + case (Some(l1), Some(l2)) => + if(l1 < l2) { this.lower } else { that.lower } + case _ => + firrtlir.UnknownBound + } + val highest = (this.getHighestPossibleValue, that.getHighestPossibleValue) match { + case (Some(l1), Some(l2)) => + if(l1 >= l2) { this.lower } else { that.lower } + case _ => + firrtlir.UnknownBound + } + val newBinaryPoint = (this.firrtlBinaryPoint, that.firrtlBinaryPoint) match { + case (firrtlir.IntWidth(b1), firrtlir.IntWidth(b2)) => + if(b1 > b2) { firrtlir.IntWidth(b1)} else { firrtlir.IntWidth(b2) } + case _ => + firrtlir.UnknownWidth + } + IntervalRange(lowest, highest, newBinaryPoint) + } + + def binaryPoint: BinaryPoint = { + firrtlBinaryPoint match { + case firrtlir.IntWidth(n) => + assert(n < Int.MaxValue, s"binary point value $n is out of range") + KnownBinaryPoint(n.toInt) + case _ => UnknownBinaryPoint + } + } +} + abstract class Command { def sourceInfo: SourceInfo } diff --git a/chiselFrontend/src/main/scala/chisel3/package.scala b/chiselFrontend/src/main/scala/chisel3/package.scala index 51bcf1fe..3af21d57 100644 --- a/chiselFrontend/src/main/scala/chisel3/package.scala +++ b/chiselFrontend/src/main/scala/chisel3/package.scala @@ -1,118 +1,140 @@ // See LICENSE for license details. +import chisel3.internal.firrtl.BinaryPoint + /** This package contains the main chisel3 API. */ package object chisel3 { // scalastyle:ignore package.object.name import internal.firrtl.{Port, Width} - import internal.sourceinfo.{SourceInfo, VecTransform} - import internal.{Builder, chiselRuntimeDeprecated} + import internal.Builder import scala.language.implicitConversions - /** - * These implicit classes allow one to convert scala.Int|scala.BigInt to - * Chisel.UInt|Chisel.SInt by calling .asUInt|.asSInt on them, respectively. - * The versions .asUInt(width)|.asSInt(width) are also available to explicitly - * mark a width for the new literal. - * - * Also provides .asBool to scala.Boolean and .asUInt to String - * - * Note that, for stylistic reasons, one should avoid extracting immediately - * after this call using apply, ie. 0.asUInt(1)(0) due to potential for - * confusion (the 1 is a bit length and the 0 is a bit extraction position). - * Prefer storing the result and then extracting from it. - * - * Implementation note: the empty parameter list (like `U()`) is necessary to prevent - * interpreting calls that have a non-Width parameter as a chained apply, otherwise things like - * `0.asUInt(16)` (instead of `16.W`) compile without error and produce undesired results. - */ - implicit class fromBigIntToLiteral(bigint: BigInt) { - /** Int to Bool conversion, allowing compact syntax like 1.B and 0.B - */ - def B: Bool = bigint match { // scalastyle:ignore method.name - case bigint if bigint == 0 => Bool.Lit(false) - case bigint if bigint == 1 => Bool.Lit(true) - case bigint => Builder.error(s"Cannot convert $bigint to Bool, must be 0 or 1"); Bool.Lit(false) - } - /** Int to UInt conversion, recommended style for constants. - */ - def U: UInt = UInt.Lit(bigint, Width()) // scalastyle:ignore method.name - /** Int to SInt conversion, recommended style for constants. - */ - def S: SInt = SInt.Lit(bigint, Width()) // scalastyle:ignore method.name - /** Int to UInt conversion with specified width, recommended style for constants. - */ - def U(width: Width): UInt = UInt.Lit(bigint, width) // scalastyle:ignore method.name - /** Int to SInt conversion with specified width, recommended style for constants. - */ - def S(width: Width): SInt = SInt.Lit(bigint, width) // scalastyle:ignore method.name - - /** Int to UInt conversion, recommended style for variables. - */ - def asUInt(): UInt = UInt.Lit(bigint, Width()) - /** Int to SInt conversion, recommended style for variables. - */ - def asSInt(): SInt = SInt.Lit(bigint, Width()) - /** Int to UInt conversion with specified width, recommended style for variables. - */ - def asUInt(width: Width): UInt = UInt.Lit(bigint, width) - /** Int to SInt conversion with specified width, recommended style for variables. - */ - def asSInt(width: Width): SInt = SInt.Lit(bigint, width) - } + /** + * These implicit classes allow one to convert scala.Int|scala.BigInt to + * Chisel.UInt|Chisel.SInt by calling .asUInt|.asSInt on them, respectively. + * The versions .asUInt(width)|.asSInt(width) are also available to explicitly + * mark a width for the new literal. + * + * Also provides .asBool to scala.Boolean and .asUInt to String + * + * Note that, for stylistic reasons, one should avoid extracting immediately + * after this call using apply, ie. 0.asUInt(1)(0) due to potential for + * confusion (the 1 is a bit length and the 0 is a bit extraction position). + * Prefer storing the result and then extracting from it. + * + * Implementation note: the empty parameter list (like `U()`) is necessary to prevent + * interpreting calls that have a non-Width parameter as a chained apply, otherwise things like + * `0.asUInt(16)` (instead of `16.W`) compile without error and produce undesired results. + */ + implicit class fromBigIntToLiteral(bigint: BigInt) { + /** Int to Bool conversion, allowing compact syntax like 1.B and 0.B + */ + def B: Bool = bigint match { // scalastyle:ignore method.name + case bigint if bigint == 0 => Bool.Lit(false) + case bigint if bigint == 1 => Bool.Lit(true) + case bigint => Builder.error(s"Cannot convert $bigint to Bool, must be 0 or 1"); Bool.Lit(false) + } + /** Int to UInt conversion, recommended style for constants. + */ + def U: UInt = UInt.Lit(bigint, Width()) // scalastyle:ignore method.name + /** Int to SInt conversion, recommended style for constants. + */ + def S: SInt = SInt.Lit(bigint, Width()) // scalastyle:ignore method.name + /** Int to UInt conversion with specified width, recommended style for constants. + */ + def U(width: Width): UInt = UInt.Lit(bigint, width) // scalastyle:ignore method.name + /** Int to SInt conversion with specified width, recommended style for constants. + */ + def S(width: Width): SInt = SInt.Lit(bigint, width) // scalastyle:ignore method.name - implicit class fromIntToLiteral(int: Int) extends fromBigIntToLiteral(int) - implicit class fromLongToLiteral(long: Long) extends fromBigIntToLiteral(long) - - implicit class fromStringToLiteral(str: String) { - /** String to UInt parse, recommended style for constants. - */ - def U: UInt = str.asUInt() // scalastyle:ignore method.name - /** String to UInt parse with specified width, recommended style for constants. - */ - def U(width: Width): UInt = str.asUInt(width) // scalastyle:ignore method.name - - /** String to UInt parse, recommended style for variables. - */ - def asUInt(): UInt = { - val bigInt = parse(str) - UInt.Lit(bigInt, Width(bigInt.bitLength max 1)) - } - /** String to UInt parse with specified width, recommended style for variables. - */ - def asUInt(width: Width): UInt = UInt.Lit(parse(str), width) - - protected def parse(n: String) = { - val (base, num) = n.splitAt(1) - val radix = base match { - case "x" | "h" => 16 - case "d" => 10 - case "o" => 8 - case "b" => 2 - case _ => Builder.error(s"Invalid base $base"); 2 - } - BigInt(num.filterNot(_ == '_'), radix) - } - } + /** Int to UInt conversion, recommended style for variables. + */ + def asUInt(): UInt = UInt.Lit(bigint, Width()) + /** Int to SInt conversion, recommended style for variables. + */ + def asSInt(): SInt = SInt.Lit(bigint, Width()) + /** Int to UInt conversion with specified width, recommended style for variables. + */ + def asUInt(width: Width): UInt = UInt.Lit(bigint, width) + /** Int to SInt conversion with specified width, recommended style for variables. + */ + def asSInt(width: Width): SInt = SInt.Lit(bigint, width) + } - implicit class fromBooleanToLiteral(boolean: Boolean) { - /** Boolean to Bool conversion, recommended style for constants. - */ - def B: Bool = Bool.Lit(boolean) // scalastyle:ignore method.name + implicit class fromIntToLiteral(int: Int) extends fromBigIntToLiteral(int) + implicit class fromLongToLiteral(long: Long) extends fromBigIntToLiteral(long) - /** Boolean to Bool conversion, recommended style for variables. - */ - def asBool(): Bool = Bool.Lit(boolean) + implicit class fromStringToLiteral(str: String) { + /** String to UInt parse, recommended style for constants. + */ + def U: UInt = str.asUInt() // scalastyle:ignore method.name + /** String to UInt parse with specified width, recommended style for constants. + */ + def U(width: Width): UInt = str.asUInt(width) // scalastyle:ignore method.name + + /** String to UInt parse, recommended style for variables. + */ + def asUInt(): UInt = { + val bigInt = parse(str) + UInt.Lit(bigInt, Width(bigInt.bitLength max 1)) + } + /** String to UInt parse with specified width, recommended style for variables. + */ + def asUInt(width: Width): UInt = UInt.Lit(parse(str), width) + + protected def parse(n: String): BigInt = { + val (base, num) = n.splitAt(1) + val radix = base match { + case "x" | "h" => 16 + case "d" => 10 + case "o" => 8 + case "b" => 2 + case _ => Builder.error(s"Invalid base $base"); 2 } + BigInt(num.filterNot(_ == '_'), radix) + } + } - // Fixed Point is experimental for now, but we alias the implicit conversion classes here - // to minimize disruption with existing code. - implicit class fromDoubleToLiteral(double: Double) extends experimental.FixedPoint.Implicits.fromDoubleToLiteral(double) - implicit class fromIntToBinaryPoint(int: Int) extends experimental.FixedPoint.Implicits.fromIntToBinaryPoint(int) + implicit class fromIntToBinaryPoint(int: Int) { + def BP: BinaryPoint = BinaryPoint(int) // scalastyle:ignore method.name + } - implicit class fromIntToWidth(int: Int) { - def W: Width = Width(int) // scalastyle:ignore method.name - } + implicit class fromBooleanToLiteral(boolean: Boolean) { + /** Boolean to Bool conversion, recommended style for constants. + */ + def B: Bool = Bool.Lit(boolean) // scalastyle:ignore method.name + + /** Boolean to Bool conversion, recommended style for variables. + */ + def asBool(): Bool = Bool.Lit(boolean) + } + + // Fixed Point is experimental for now, but we alias the implicit conversion classes here + // to minimize disruption with existing code. + implicit class fromDoubleToLiteral(double: Double) + extends experimental.FixedPoint.Implicits.fromDoubleToLiteral(double) + + // Interval is experimental for now, but we alias the implicit conversion classes here + // to minimize disruption with existing code. + implicit class fromIntToLiteralInterval(int: Int) + extends experimental.Interval.Implicits.fromIntToLiteralInterval(int) + + implicit class fromLongToLiteralInterval(long: Long) + extends experimental.Interval.Implicits.fromLongToLiteralInterval(long) + + implicit class fromBigIntToLiteralInterval(bigInt: BigInt) + extends experimental.Interval.Implicits.fromBigIntToLiteralInterval(bigInt) + + implicit class fromDoubleToLiteralInterval(double: Double) + extends experimental.Interval.Implicits.fromDoubleToLiteralInterval(double) + + implicit class fromBigDecimalToLiteralInterval(bigDecimal: BigDecimal) + extends experimental.Interval.Implicits.fromBigDecimalToLiteralInterval(bigDecimal) + + implicit class fromIntToWidth(int: Int) { + def W: Width = Width(int) // scalastyle:ignore method.name + } val WireInit = WireDefault diff --git a/coreMacros/src/main/scala/chisel3/internal/RangeTransform.scala b/coreMacros/src/main/scala/chisel3/internal/RangeTransform.scala index e61ddc6a..0fdbff81 100644 --- a/coreMacros/src/main/scala/chisel3/internal/RangeTransform.scala +++ b/coreMacros/src/main/scala/chisel3/internal/RangeTransform.scala @@ -6,32 +6,56 @@ package chisel3.internal import scala.language.experimental.macros -import scala.reflect.macros.blackbox.Context +import scala.reflect.macros.blackbox +import scala.util.matching.Regex // Workaround for https://github.com/sbt/sbt/issues/3966 -object RangeTransform -class RangeTransform(val c: Context) { +object RangeTransform { + val UnspecifiedNumber: Regex = """(\?).*""".r + val IntegerNumber: Regex = """(-?\d+).*""".r + val DecimalNumber: Regex = """(-?\d+\.\d+).*""".r +} + +/** Convert the string to IntervalRange, with unknown, open or closed endpoints and a binary point + * ranges looks like + * range"[0,4].1" range starts at 0 inclusive ends at 4.inclusively with a binary point of 1 + * range"(0,4).1" range starts at 0 exclusive ends at 4.exclusively with a binary point of 1 + * + * the min and max of the range are the actually min and max values, thus the binary point + * becomes a sort of multiplier for the number of bits. + * E.g. range"[0,3].2" will require at least 4 bits two provide the two decimal places + * + * @param c contains the string context to be parsed + */ +//scalastyle:off cyclomatic.complexity method.length +class RangeTransform(val c: blackbox.Context) { import c.universe._ - // scalastyle:off method.length line.size.limit def apply(args: c.Tree*): c.Tree = { val stringTrees = c.prefix.tree match { case q"$_(scala.StringContext.apply(..$strings))" => strings - case _ => c.abort(c.enclosingPosition, s"Range macro unable to parse StringContext, got: ${showCode(c.prefix.tree)}") + case _ => + c.abort( + c.enclosingPosition, + s"Range macro unable to parse StringContext, got: ${showCode(c.prefix.tree)}" + ) } - val strings = stringTrees.map { tree => tree match { + val strings = stringTrees.map { case Literal(Constant(string: String)) => string - case _ => c.abort(c.enclosingPosition, s"Range macro unable to parse StringContext element, got: ${showRaw(tree)}") - } } - // scalastyle:on line.size.limit + case tree => + c.abort( + c.enclosingPosition, + s"Range macro unable to parse StringContext element, got: ${showRaw(tree)}" + ) + } var nextStringIndex: Int = 1 var nextArgIndex: Int = 0 - var currString: String = strings(0) + var currString: String = strings.head /** Mutably gets the next numeric value in the range specifier. */ - def getNextValue(): c.Tree = { - currString = currString.dropWhile(_ == ' ') // allow whitespace + def computeNextValue(): c.Tree = { + currString = currString.dropWhile(_ == ' ') // allow whitespace if (currString.isEmpty) { if (nextArgIndex >= args.length) { c.abort(c.enclosingPosition, s"Incomplete range specifier") @@ -47,59 +71,127 @@ class RangeTransform(val c: Context) { nextArg } else { - val nextStringVal = currString.takeWhile(!Set('[', '(', ' ', ',', ')', ']').contains(_)) + val nextStringVal = currString match { + case RangeTransform.DecimalNumber(numberString) => numberString + case RangeTransform.IntegerNumber(numberString) => numberString + case RangeTransform.UnspecifiedNumber(_) => "?" + case _ => + c.abort( + c.enclosingPosition, + s"Bad number or unspecified bound $currString" + ) + } currString = currString.substring(nextStringVal.length) - if (currString.isEmpty) { - c.abort(c.enclosingPosition, s"Incomplete range specifier") + + if (nextStringVal == "?") { + Literal(Constant("?")) + } else { + c.parse(nextStringVal) } - c.parse(nextStringVal) } } // Currently, not allowed to have the end stops (inclusive / exclusive) be interpolated. currString = currString.dropWhile(_ == ' ') - val startInclusive = currString(0) match { - case '[' => true - case '(' => false - case other => c.abort(c.enclosingPosition, s"Unknown start inclusive/exclusive specifier, got: '$other'") + val startInclusive = currString.headOption match { + case Some('[') => true + case Some('(') => false + case Some('?') => + c.abort( + c.enclosingPosition, + s"start of range as unknown s must be '[?' or '(?' not '?'" + ) + case Some(other) => + c.abort( + c.enclosingPosition, + s"Unknown start inclusive/exclusive specifier, got: '$other'" + ) + case None => + c.abort( + c.enclosingPosition, + s"No initial inclusive/exclusive specifier" + ) } - currString = currString.substring(1) // eat the inclusive/exclusive specifier - val minArg = getNextValue() + + currString = currString.substring(1) // eat the inclusive/exclusive specifier + val minArg = computeNextValue() currString = currString.dropWhile(_ == ' ') if (currString(0) != ',') { c.abort(c.enclosingPosition, s"Incomplete range specifier, expected ','") } - currString = currString.substring(1) // eat the comma - val maxArg = getNextValue() + if (currString.head != ',') { + c.abort( + c.enclosingPosition, + s"Incomplete range specifier, expected ',', got $currString" + ) + } + + currString = currString.substring(1) // eat the comma + + val maxArg = computeNextValue() currString = currString.dropWhile(_ == ' ') - val endInclusive = currString(0) match { - case ']' => true - case ')' => false - case other => c.abort(c.enclosingPosition, s"Unknown end inclusive/exclusive specifier, got: '$other'") + + val endInclusive = currString.headOption match { + case Some(']') => true + case Some(')') => false + case Some('?') => + c.abort( + c.enclosingPosition, + s"start of range as unknown s must be '[?' or '(?' not '?'" + ) + case Some(other) => + c.abort( + c.enclosingPosition, + s"Unknown end inclusive/exclusive specifier, got: '$other' expecting ')' or ']'" + ) + case None => + c.abort( + c.enclosingPosition, + s"Incomplete range specifier, missing end inclusive/exclusive specifier" + ) } - currString = currString.substring(1) // eat the inclusive/exclusive specifier + currString = currString.substring(1) // eat the inclusive/exclusive specifier currString = currString.dropWhile(_ == ' ') + val binaryPointString = currString.headOption match { + case Some('.') => + currString = currString.substring(1) + computeNextValue() + case Some(other) => + c.abort( + c.enclosingPosition, + s"Unknown end binary point prefix, got: '$other' was expecting '.'" + ) + case None => + Literal(Constant(0)) + } + if (nextArgIndex < args.length) { val unused = args.mkString("") - c.abort(c.enclosingPosition, s"Unused interpolated values in range specifier: '$unused'") + c.abort( + c.enclosingPosition, + s"Unused interpolated values in range specifier: '$unused'" + ) } if (!currString.isEmpty || nextStringIndex < strings.length) { - val unused = currString + strings.slice(nextStringIndex, strings.length).mkString(", ") - c.abort(c.enclosingPosition, s"Unused characters in range specifier: '$unused'") + val unused = currString + strings + .slice(nextStringIndex, strings.length) + .mkString(", ") + c.abort( + c.enclosingPosition, + s"Unused characters in range specifier: '$unused'" + ) } - val startBound = if (startInclusive) { - q"_root_.chisel3.internal.firrtl.Closed($minArg)" - } else { - q"_root_.chisel3.internal.firrtl.Open($minArg)" - } - val endBound = if (endInclusive) { - q"_root_.chisel3.internal.firrtl.Closed($maxArg)" - } else { - q"_root_.chisel3.internal.firrtl.Open($maxArg)" - } + val startBound = + q"_root_.chisel3.internal.firrtl.IntervalRange.getBound($startInclusive, $minArg)" + + val endBound = + q"_root_.chisel3.internal.firrtl.IntervalRange.getBound($endInclusive, $maxArg)" + + val binaryPoint = + q"_root_.chisel3.internal.firrtl.IntervalRange.getBinaryPoint($binaryPointString)" - q"($startBound, $endBound)" + q"_root_.chisel3.internal.firrtl.IntervalRange($startBound, $endBound, $binaryPoint)" } } diff --git a/src/main/scala/chisel3/internal/firrtl/Emitter.scala b/src/main/scala/chisel3/internal/firrtl/Emitter.scala index 3d10670e..3409ce94 100644 --- a/src/main/scala/chisel3/internal/firrtl/Emitter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Emitter.scala @@ -2,7 +2,7 @@ package chisel3.internal.firrtl import chisel3._ -import chisel3.experimental._ +import chisel3.experimental.{Interval, _} import chisel3.internal.BaseBlackBox private[chisel3] object Emitter { @@ -33,6 +33,12 @@ private class Emitter(circuit: Circuit) { case d: UInt => s"UInt${d.width}" case d: SInt => s"SInt${d.width}" case d: FixedPoint => s"Fixed${d.width}${d.binaryPoint}" + case d: Interval => + val binaryPointString = d.binaryPoint match { + case UnknownBinaryPoint => "" + case KnownBinaryPoint(value) => s".$value" + } + d.toType case d: Analog => s"Analog${d.width}" case d: Vec[_] => s"${emitType(d.sample_element, clearDir)}[${d.length}]" case d: Record => { diff --git a/src/test/scala/chiselTests/IntervalRangeSpec.scala b/src/test/scala/chiselTests/IntervalRangeSpec.scala new file mode 100644 index 00000000..c152e72d --- /dev/null +++ b/src/test/scala/chiselTests/IntervalRangeSpec.scala @@ -0,0 +1,221 @@ +// See README.md for license details. + +package chiselTests + +import chisel3._ +import chisel3.experimental._ +import _root_.firrtl.{ir => firrtlir} +import chisel3.internal.firrtl.{BinaryPoint, IntervalRange, KnownBinaryPoint, UnknownBinaryPoint} +import org.scalatest.{FreeSpec, Matchers} + +//scalastyle:off method.name magic.number +class IntervalRangeSpec extends FreeSpec with Matchers { + + "IntervalRanges" - { + def C(b: BigDecimal): firrtlir.Bound = firrtlir.Closed(b) + + def O(b: BigDecimal): firrtlir.Bound = firrtlir.Open(b) + + def U(): firrtlir.Bound = firrtlir.UnknownBound + + def UBP(): BinaryPoint = UnknownBinaryPoint + + def checkRange(r: IntervalRange, l: firrtlir.Bound, u: firrtlir.Bound, b: BinaryPoint): Unit = { + r.lowerBound should be(l) + r.upperBound should be(u) + r.binaryPoint should be(b) + } + + def checkBinaryPoint(r: IntervalRange, b: BinaryPoint): Unit = { + r.binaryPoint should be(b) + } + + "IntervalRange describes the range of values of the Interval Type" - { + "Factory methods can create IntervalRanges" - { + "ranges can start or end open or closed, default binary point is none" in { + checkRange(range"[0,10]", C(0), C(10), 0.BP) + checkRange(range"[-1,10)", C(-1), O(10), 0.BP) + checkRange(range"(11,12]", O(11), C(12), 0.BP) + checkRange(range"(-21,-10)", O(-21), O(-10), 0.BP) + } + + "ranges can have unknown bounds" in { + checkRange(range"[?,10]", U(), C(10), 0.BP) + checkRange(range"(?,10]", U(), C(10), 0.BP) + checkRange(range"[-1,?]", C(-1), U(), 0.BP) + checkRange(range"[-1,?)", C(-1), U(), 0.BP) + checkRange(range"[?,?]", U(), U(), 0.BP) + checkRange(range"[?,?].?", U(), U(), UBP()) + } + + "binary points can be specified" in { + checkBinaryPoint(range"[?,10].0", 0.BP) + checkBinaryPoint(range"[?,10].2", 2.BP) + checkBinaryPoint(range"[?,10].?", UBP()) + } + "malformed ranges will throw ChiselException or are compile time errors" in { + // must be a cleverer way to show this + intercept[ChiselException] { + range"[19,5]" + } + assertDoesNotCompile(""" range"?,10] """) + assertDoesNotCompile(""" range"?,? """) + } + } + } + + "Ranges can be specified for UInt, SInt, and FixedPoint" - { + "invalid range specifiers should fail at compile time" in { + assertDoesNotCompile(""" range"" """) + assertDoesNotCompile(""" range"[]" """) + assertDoesNotCompile(""" range"0" """) + assertDoesNotCompile(""" range"[0]" """) + assertDoesNotCompile(""" range"[0, 1" """) + assertDoesNotCompile(""" range"0, 1]" """) + assertDoesNotCompile(""" range"[0, 1, 2]" """) + assertDoesNotCompile(""" range"[a]" """) + assertDoesNotCompile(""" range"[a, b]" """) + assertCompiles(""" range"[0, 1]" """) // syntax sanity check + } + + "range macros should allow open and closed bounds" in { + range"[-1, 1)" should be(range"[-1,1).0") + range"[-1, 1)" should be(IntervalRange(C(-1), O(1), 0.BP)) + range"[-1, 1]" should be(IntervalRange(C(-1), C(1), 0.BP)) + range"(-1, 1]" should be(IntervalRange(O(-1), C(1), 0.BP)) + range"(-1, 1)" should be(IntervalRange(O(-1), O(1), 0.BP)) + } + + "range specifiers should be whitespace tolerant" in { + range"[-1,1)" should be(IntervalRange(C(-1), O(1), 0.BP)) + range" [-1,1) " should be(IntervalRange(C(-1), O(1), 0.BP)) + range" [ -1 , 1 ) " should be(IntervalRange(C(-1), O(1), 0.BP)) + range" [ -1 , 1 ) " should be(IntervalRange(C(-1), O(1), 0.BP)) + } + + "range macros should work with interpolated variables" in { + val a = 10 + val b = -3 + + range"[$b, $a)" should be(IntervalRange(C(b), O(a), 0.BP)) + range"[${a + b}, $a)" should be(IntervalRange(C(a + b), O(a), 0.BP)) + range"[${-3 - 7}, ${-3 + a})" should be(IntervalRange(C(-10), O(-3 + a), 0.BP)) + + def number(n: Int): Int = n + + range"[${number(1)}, ${number(3)})" should be(IntervalRange(C(1), O(3), 0.BP)) + } + + "UInt should get the correct width from a range" in { + UInt(range"[0, 8)").getWidth should be(3) + UInt(range"[0, 8]").getWidth should be(4) + UInt(range"[0, 0]").getWidth should be(1) + } + + "SInt should get the correct width from a range" in { + SInt(range"[0, 8)").getWidth should be(4) + SInt(range"[0, 8]").getWidth should be(5) + SInt(range"[-4, 4)").getWidth should be(3) + SInt(range"[0, 0]").getWidth should be(1) + } + + "UInt should check that the range is valid" in { + an[ChiselException] should be thrownBy { + UInt(range"[1, 0]") + } + an[ChiselException] should be thrownBy { + UInt(range"[-1, 1]") + } + an[ChiselException] should be thrownBy { + UInt(range"(0,0]") + } + an[ChiselException] should be thrownBy { + UInt(range"[0,0)") + } + an[ChiselException] should be thrownBy { + UInt(range"(0,0)") + } + an[ChiselException] should be thrownBy { + UInt(range"(0,1)") + } + } + + "SInt should check that the range is valid" in { + an[ChiselException] should be thrownBy { + SInt(range"[1, 0]") + } + an[ChiselException] should be thrownBy { + SInt(range"(0,0]") + } + an[ChiselException] should be thrownBy { + SInt(range"[0,0)") + } + an[ChiselException] should be thrownBy { + SInt(range"(0,0)") + } + an[ChiselException] should be thrownBy { + SInt(range"(0,1)") + } + } + } + + "shift operations should work on ranges" - { + "<<, shiftLeft affects the bounds but not the binary point" in { + checkRange(range"[0,7].1", C(0), C(7), 1.BP) + checkRange(range"[0,7].1" << 1, C(0), C(14), 1.BP) + + checkRange(range"[2,7].2", C(2), C(7), 2.BP) + checkRange(range"[2,7].2" << 1, C(4), C(14), 2.BP) + } + + ">>, shiftRight affects the bounds but not the binary point" in { + checkRange(range"[0,7].0", C(0), C(7), 0.BP) + checkRange(range"[0,7].0" >> 1, C(0), C(3), 0.BP) + + checkRange(range"[0,7].1", C(0), C(7), 1.BP) + checkRange(range"[0,7].1" >> 1, C(0), C(3.5), 1.BP) + + checkRange(range"[2,7].2", C(2), C(7), 2.BP) + checkRange(range"[2,7].2" >> 1, C(1), C(3.5), 2.BP) + + checkRange(range"[2,7].2", C(2), C(7), 2.BP) + checkRange(range"[2,7].2" >> 2, C(0.5), C(1.75), 2.BP) + + // the 7(b111) >> 3 => 0.875(b0.111) but since + // binary point is two, lopping must occur so 0.875 becomes 0.75 + checkRange(range"[-8,7].2", C(-8), C(7), 2.BP) + checkRange(range"[-8,7].2" >> 3, C(-1), C(0.75), 2.BP) + + + checkRange(range"(0,7).0", O(0), O(7), 0.BP) + checkRange(range"(0,7).0" >> 1, O(0), O(3), 0.BP) + + checkRange(range"(0,7).1", O(0), O(7), 1.BP) + checkRange(range"(0,7).1" >> 1, O(0), O(3.5), 1.BP) + + checkRange(range"(2,7).2", O(2), O(7), 2.BP) + checkRange(range"(2,7).2" >> 1, O(1), O(3.5), 2.BP) + + checkRange(range"(2,7).2", O(2), O(7), 2.BP) + checkRange(range"(2,7).2" >> 2, O(0.5), O(1.75), 2.BP) + + // the 7(b111) >> 3 => 0.875(b0.111) but since + // binary point is two, lopping must occur so 0.875 becomes 0.75 + checkRange(range"(-8,7).2", O(-8), O(7), 2.BP) + checkRange(range"(-8,7).2" >> 3, O(-1), O(0.75), 2.BP) + } + + "set precision can change the bounds due to precision loss, direction of change is always to lower value" in { + intercept[ChiselException] { + checkRange(range"[-7.875,7.875].3".setPrecision(UnknownBinaryPoint), C(-7.875), C(7.875), 5.BP) + } + + checkRange(range"[-7.875,7.875].3", C(-7.875), C(7.875), 3.BP) + checkRange(range"[1.25,2].2".setPrecision(1.BP), C(1.0), C(2), 1.BP) + checkRange(range"[-7.875,7.875].3".setPrecision(5.BP), C(-7.875), C(7.875), 5.BP) + checkRange(range"[-7.875,7.875].3".setPrecision(1.BP), C(-8.0), C(7.5), 1.BP) + } + } + } + +} diff --git a/src/test/scala/chiselTests/IntervalSpec.scala b/src/test/scala/chiselTests/IntervalSpec.scala new file mode 100644 index 00000000..863771a3 --- /dev/null +++ b/src/test/scala/chiselTests/IntervalSpec.scala @@ -0,0 +1,919 @@ +// See LICENSE for license details. + +package chiselTests + +import scala.language.reflectiveCalls +import _root_.firrtl.ir.{Closed, Open} +import chisel3._ +import chisel3.internal.firrtl.{IntervalRange, KnownBinaryPoint} +import chisel3.internal.sourceinfo.{SourceInfo, UnlocatableSourceInfo} +import chisel3.stage.{ChiselGeneratorAnnotation, ChiselStage} +import chisel3.testers.BasicTester +import cookbook.CookbookTester +import firrtl.options.TargetDirAnnotation +import firrtl.passes.CheckTypes.InvalidConnect +import firrtl.passes.CheckWidths.{DisjointSqueeze, InvalidRange} +import firrtl.passes.{PassExceptions, WrapWithRemainder} +import firrtl.stage.{CompilerAnnotation, FirrtlCircuitAnnotation} +import firrtl.{FIRRTLException, HighFirrtlCompiler, LowFirrtlCompiler, MiddleFirrtlCompiler, MinimumVerilogCompiler, NoneCompiler, SystemVerilogCompiler, VerilogCompiler} +import org.scalatest.{FreeSpec, Matchers} + +//scalastyle:off magic.number +//noinspection TypeAnnotation + +object IntervalTestHelper { + + /** Compiles a Chisel Module to Verilog + * NOTE: This uses the "test_run_dir" as the default directory for generated code. + * @param compilerName the generator for the module + * @param gen the generator for the module + * @return the Verilog code as a string. + */ + //scalastyle:off cyclomatic.complexity + def makeFirrtl[T <: RawModule](compilerName: String)(gen: () => T): String = { + val c = compilerName match { + case "none" => new NoneCompiler() + case "high" => new HighFirrtlCompiler() + case "lo" => new LowFirrtlCompiler() + case "low" => new LowFirrtlCompiler() + case "middle" => new MiddleFirrtlCompiler() + case "verilog" => new VerilogCompiler() + case "mverilog" => new MinimumVerilogCompiler() + case "sverilog" => new SystemVerilogCompiler() + case _ => + throw new Exception( + s"Unknown compiler name '$compilerName'! (Did you misspell it?)" + ) + } + val compiler = CompilerAnnotation(c) + val annotations = Seq(new ChiselGeneratorAnnotation(gen), TargetDirAnnotation("test_run_dir/IntervalSpec"), compiler) + val processed = (new ChiselStage).run(annotations) + processed.collectFirst { case FirrtlCircuitAnnotation(source) => source } match { + case Some(circuit) => circuit.serialize + case _ => + throw new Exception( + s"makeFirrtl($compilerName) failed to generate firrtl circuit" + ) + } + } +} + +import chiselTests.IntervalTestHelper.makeFirrtl +import chisel3.experimental._ +import chisel3.experimental.Interval + +class IntervalTest1 extends Module { + val io = IO(new Bundle { + val in1 = Input(Interval(range"[0,4]")) + val in2 = Input(Interval(range"[0,4].3")) + val out = Output(Interval(range"[0,8].3")) + }) + + io.out := io.in1 + io.in2 +} + +class IntervalTester extends CookbookTester(10) { + + val dut = Module(new IntervalTest1) + + dut.io.in1 := BigInt(4).I + dut.io.in2 := 4.I + assert(dut.io.out === 8.I) + + val i = Interval(range"[0,10)") + stop() +} + +class IntervalTest2 extends Module { + val io = IO(new Bundle { + val p = Input(Bool()) + val in1 = Input(Interval(range"[0,4]")) + val in2 = Input(Interval(range"[0,6]")) + val out = Output(Interval()) + }) + + io.out := Mux(io.p, io.in1, io.in2) +} + +class IntervalTester2 extends CookbookTester(10) { + + val dut = Module(new IntervalTest2) + + dut.io.p := 1.U + dut.io.in1 := 4.I + dut.io.in2 := 5.I + assert(dut.io.out === 4.I) + + stop() +} + +class IntervalAddTester extends BasicTester { + + val in1 = Wire(Interval(range"[0,4]")) + val in2 = Wire(Interval(range"[0,4]")) + + in1 := 2.I + in2 := 2.I + + 5.U + + val result = in1 +& in2 + + assert(result === 4.I) + + stop() + +} + +class IntervalSetBinaryPointTester extends BasicTester { + implicit val sourceinfo: SourceInfo = UnlocatableSourceInfo + val in1 = Wire(Interval(range"[0,4].4")) + val in2 = in1.setPrecision(2) + + assert(in2.binaryPoint == KnownBinaryPoint(2)) + + in1 := 2.I + + val shiftedLeft = in1.increasePrecision(2) + + assert( + shiftedLeft.binaryPoint == KnownBinaryPoint(6), + s"Error: increasePrecision result ${shiftedLeft.range} expected bt = 2" + ) + + val shiftedRight = in1.decreasePrecision(2) + + assert( + shiftedRight.binaryPoint == KnownBinaryPoint(2), + s"Error: increasePrecision result ${shiftedRight.range} expected bt = 2" + ) + + stop() +} + +class MoreIntervalShiftTester extends BasicTester { + implicit val sourceinfo: SourceInfo = UnlocatableSourceInfo + + val in1 = Wire(Interval(range"[0,4].4")) + val in2 = in1.setPrecision(2) + + assert(in2.binaryPoint == KnownBinaryPoint(2)) + + val toShiftLeft = Wire(Interval(range"[0,4].4")) + val shiftedLeft = in1.increasePrecision(2) + + assert( + shiftedLeft.binaryPoint == KnownBinaryPoint(2), + s"Error: decreasePrecision result ${shiftedLeft.range} expected bt = 2" + ) + + val toShiftRight = Wire(Interval(range"[0,4].4")) + val shiftedRight = in1.decreasePrecision(2) + + assert( + shiftedRight.binaryPoint == KnownBinaryPoint(6), + s"Error: decreasePrecision result ${shiftedRight.range} expected bt = 2" + ) + + stop() +} + +/** + * This is a reality check not a test. Makes it easier to figure out + * what is going on in other places + * @param range a range for inputs + * @param targetRange a range for outputs + * @param startNum start here + * @param endNum end here + * @param incNum increment by this + */ +class ClipSqueezeWrapDemo(range: IntervalRange, + targetRange: IntervalRange, + startNum: Double, + endNum: Double, + incNum: Double) + extends BasicTester { + + val binaryPointAsInt = range.binaryPoint.asInstanceOf[KnownBinaryPoint].value +// val startValue = Interval.fromDouble(startNum, binaryPoint = binaryPointAsInt) +// val increment = Interval.fromDouble(incNum, binaryPoint = binaryPointAsInt) +// val endValue = Interval.fromDouble(endNum, binaryPoint = binaryPointAsInt) + val startValue = startNum.I(range.binaryPoint) + val increment = incNum.I(range.binaryPoint) + val endValue = endNum.I(range.binaryPoint) + + val counter = RegInit(Interval(range), startValue) + + counter := (counter + increment).squeeze(counter) + when(counter > endValue) { + stop() + } + + val clipped = counter.clip(0.U.asInterval(targetRange)) + val squeezed = counter.squeeze(0.U.asInterval(targetRange)) + val wrapped = counter.wrap(0.U.asInterval(targetRange)) + + when(counter === startValue) { + printf(s"Target range is $range\n") + printf("value clip squeeze wrap\n") + } + + printf( + " %d %d %d %d\n", + counter.asSInt(), + clipped.asSInt(), + squeezed.asSInt(), + wrapped.asSInt() + ) +} + +class SqueezeFunctionalityTester(range: IntervalRange, + startNum: BigDecimal, + endNum: BigDecimal, + increment: BigDecimal) + extends BasicTester { + + val counter = RegInit(0.U(10.W)) + counter := counter + 1.U + when(counter > 10.U) { + stop() + } + + val squeezeInterval = Wire(Interval(range)) + squeezeInterval := 0.I + + val squeezeTemplate = Wire(Interval(range)) + + val ss = WireInit(Interval(range), (-10).S.asInterval(range)) + + val toSqueeze = counter.asInterval(range) - ss + + squeezeTemplate := toSqueeze.squeeze(squeezeInterval) + + printf( + s"SqueezeTest %d %d.squeeze($range) => %d\n", + counter, + toSqueeze.asSInt(), + squeezeTemplate.asSInt() + ) +} + +/** + * Demonstrate a simple counter register with an Interval type + */ +class IntervalRegisterTester extends BasicTester { + + val range = range"[-2,5]" + val counter = RegInit(Interval(range), (-1).I) + counter := (counter + 1.I) + .squeeze(counter) // this works with other types, why not Interval + when(counter > 4.I) { + stop() + } +} + +//noinspection ScalaStyle +class IntervalWrapTester extends BasicTester { + + val t1 = Wire(Interval(range"[-2, 12]")) + t1 := (-2).I + val u1 = 0.U(3.W) + val r1 = RegInit(u1) + r1 := u1 + val t2 = t1.wrap(u1) + val t3 = t1.wrap(r1) + + assert( + t2.range.upper == Closed(7), + s"t1 upper ${t2.range.upper} expected ${Closed(7)}" + ) + assert( + t3.range.upper == Closed(7), + s"t1 upper ${t3.range.upper} expected ${Closed(7)}" + ) + + val in1 = WireInit(Interval(range"[0,9].6"), 0.I) + val in2 = WireInit(Interval(range"[1,6).4"), 2.I) + val in3 = in1.wrap(in2) + + assert( + in3.range.lower == Closed(1), + s"in3 lower ${in3.range.lower} expected ${Closed(1)}" + ) + assert( + in3.range.upper == Open(6), + s"in3 upper ${in3.range.upper} expected ${Open(6)}" + ) + assert( + in3.binaryPoint == KnownBinaryPoint(6), + s"in3 binaryPoint ${in3.binaryPoint} expected ${KnownBinaryPoint(2)}" + ) + + val enclosedRange = range"[-2, 5]" + val base = Wire(Interval(range"[-4, 6]")) + val enclosed = WireInit(Interval(enclosedRange), 0.I) + val enclosing = WireInit(Interval(range"[-6, 8]"), 0.I) + val overlapLeft = WireInit(Interval(range"[-10,-2]"), (-3).I) + val overlapRight = WireInit(Interval(range"[-1,10]"), 0.I) + + val w1 = base.wrap(enclosed) + val w2 = base.wrap(enclosing) + val w3 = base.wrap(overlapLeft) + val w4 = base.wrap(overlapRight) + val w7 = base.wrap(enclosedRange) + + base := 6.I + + assert(w1 === (-2).I) + assert(w2 === 6.I) + assert(w3 === (-3).I) + assert(w4 === 6.I) + assert(w7 === (-2).I) + + stop() +} + +class IntervalClipTester extends BasicTester { + + val enclosedRange = range"[-2, 5]" + val base = Wire(Interval(range"[-4, 6]")) + val enclosed = Wire(Interval(enclosedRange)) + val enclosing = Wire(Interval(range"[-6, 8]")) + val overlapLeft = Wire(Interval(range"[-10,-2]")) + val overlapRight = Wire(Interval(range"[-1,10]")) + val disjointLeft = Wire(Interval(range"[-14,-7]")) + val disjointRight = Wire(Interval(range"[7,11]")) + + enclosed := DontCare + enclosing := DontCare + overlapLeft := DontCare + overlapRight := DontCare + disjointLeft := DontCare + disjointRight := DontCare + + val enclosedResult = base.clip(enclosed) + val enclosingResult = base.clip(enclosing) + val overlapLeftResult = base.clip(overlapLeft) + val overlapRightResult = base.clip(overlapRight) + val disjointLeftResult = base.clip(disjointLeft) + val disjointRightResult = base.clip(disjointRight) + val enclosedViaRangeString = base.clip(enclosedRange) + + base := 6.I + + assert(enclosedResult === 5.I) + assert(enclosingResult === 6.I) + assert(overlapLeftResult === (-2).I) + assert(overlapRightResult === 6.I) + assert(disjointLeftResult === (-7).I) + assert(disjointRightResult === 7.I) + + assert(enclosedViaRangeString === 5.I) + + stop() +} + +class IntervalChainedAddTester extends BasicTester { + + val intervalResult = Wire(Interval()) + val uintResult = Wire(UInt()) + + intervalResult := 1.I + 1.I + 1.I + 1.I + 1.I + 1.I + 1.I + uintResult := 1.U +& 1.U +& 1.U +& 1.U +& 1.U +& 1.U +& 1.U + + assert(intervalResult === 7.I) + assert(uintResult === 7.U) + stop() +} + +class IntervalChainedMulTester extends BasicTester { + + val intervalResult = Wire(Interval()) + val uintResult = Wire(UInt()) + + intervalResult := 2.I * 2.I * 2.I * 2.I * 2.I * 2.I * 2.I + uintResult := 2.U * 2.U * 2.U * 2.U * 2.U * 2.U * 2.U + + assert(intervalResult === 128.I) + assert(uintResult === 128.U) + stop() +} + +class IntervalChainedSubTester extends BasicTester { + val intervalResult1 = Wire(Interval()) + val intervalResult2 = Wire(Interval()) + val uIntResult = Wire(UInt()) + val sIntResult = Wire(SInt()) + val fixedResult = Wire(FixedPoint()) + + intervalResult1 := 17.I - 2.I - 2.I - 2.I - 2.I - 2.I - 2.I // gives same result as -& operand version below + intervalResult2 := 17.I -& 2.I -& 2.I -& 2.I -& 2.I -& 2.I -& 2.I + uIntResult := 17.U -& 2.U -& 2.U -& 2.U -& 2.U -& 2.U -& 2.U + fixedResult := 17.0.F(0.BP) -& 2.0.F(0.BP) -& 2.0.F(0.BP) -& 2.0.F(0.BP) -& 2.0 + .F(0.BP) -& 2.0.F(0.BP) -& 2.0.F(0.BP) + sIntResult := 17.S -& 2.S -& 2.S -& 2.S -& 2.S -& 2.S -& 2.S + + assert(uIntResult === 5.U) + assert(sIntResult === 5.S) + assert(fixedResult.asUInt === 5.U) + assert(intervalResult1 === 5.I) + assert(intervalResult2 === 5.I) + + stop() +} + +//TODO: need tests for dynamic shifts on intervals +class IntervalSpec extends FreeSpec with Matchers with ChiselRunners { + + type TempFirrtlException = Exception + + "Test a simple interval add" in { + assertTesterPasses { new IntervalAddTester } + } + "Intervals can be created" in { + assertTesterPasses { new IntervalTester } + } + "Test a simple interval mux" in { + assertTesterPasses { new IntervalTester2 } + } + "Intervals can have binary points set" in { + assertTesterPasses { new IntervalSetBinaryPointTester } + } + "Interval literals that don't fit in explicit ranges are caught by chisel" - { + "case 1: does not fit in specified width" in { + intercept[ChiselException] { + ChiselGeneratorAnnotation( + () => + new BasicTester { + val x = 5.I(3.W, 0.BP) + } + ).elaborate + } + } + "case 2: doesn't fit in specified range" in { + intercept[ChiselException] { + ChiselGeneratorAnnotation( + () => + new BasicTester { + val x = 5.I(range"[0,4]") + } + ).elaborate + } + } + } + + "Let's take a look at the results of squeeze over small range" in { + assertTesterPasses { + new ClipSqueezeWrapDemo( + range = range"[-10,33].0", + targetRange = range"[-4,17].0", + startNum = -4.0, + endNum = 30.0, + incNum = 1.0 + ) + } + assertTesterPasses { + new ClipSqueezeWrapDemo( + range = range"[-2,5].1", + targetRange = range"[-1,3].1", + startNum = -2.0, + endNum = 5.0, + incNum = 0.5 + ) + } + } + "Intervals can be squeezed into another intervals range" in { + assertTesterPasses { + new SqueezeFunctionalityTester( + range"[-2,5]", + BigDecimal(-10), + BigDecimal(10), + BigDecimal(1.0) + ) + } + } + "Intervals can be wrapped with wrap operator" in { + assertTesterPasses { new IntervalWrapTester } + } + + "Interval compile pathologies: clip, wrap, and squeeze have different behavior" - { + "wrap target range is completely left of source" in { + intercept[TempFirrtlException] { + assertTesterPasses(new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + val disjointLeft = WireInit(Interval(range"[-7,-5]"), (-6).I) + val w5 = base.wrap(disjointLeft) + stop() + }) + } + } + "wrap target range is completely right of source" in { + intercept[TempFirrtlException] { + assertTesterPasses(new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.wrap(disjointLeft) + stop() + }) + } + } + "clip target range is completely left of source" in { + assertTesterPasses(new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + val disjointLeft = WireInit(Interval(range"[-7,-5]"), (-6).I) + val w5 = base.clip(disjointLeft) + chisel3.assert(w5 === (-5).I) + stop() + }) + } + "clip target range is completely right of source" in { + assertTesterPasses(new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.clip(disjointLeft) + chisel3.assert(w5.asSInt === 7.S) + stop() + }) + } + "squeeze target range is completely right of source" in { + intercept[TempFirrtlException] { + assertTesterPasses(new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.squeeze(disjointLeft) + chisel3.assert(w5.asSInt === 6.S) + stop() + }) + } + } + "squeeze target range is completely left of source" in { + intercept[TempFirrtlException] { + assertTesterPasses(new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + val disjointLeft = WireInit(Interval(range"[-7, -5]"), 8.I) + val w5 = base.squeeze(disjointLeft) + stop() + }) + } + } + + def makeCircuit(operation: String, + sourceRange: IntervalRange, + targetRange: IntervalRange): () => RawModule = { () => + new Module { + val io = IO(new Bundle { val out = Output(Interval()) }) + val base = Wire(Interval(sourceRange)) + base := 6.I + + val disjointLeft = WireInit(Interval(targetRange), 8.I) + val w5 = operation match { + case "clip" => base.clip(disjointLeft) + case "wrap" => base.wrap(disjointLeft) + case "squeeze" => base.squeeze(disjointLeft) + } + io.out := w5 + } + } + + "disjoint ranges should error when used with clip, wrap and squeeze" - { + + def mustGetException(disjointLeft: Boolean, + operation: String): Boolean = { + val (rangeA, rangeB) = if (disjointLeft) { + (range"[-4, 6]", range"[7,10]") + } else { + (range"[7,10]", range"[-4, 6]") + } + try { + makeFirrtl("low")(makeCircuit(operation, rangeA, rangeB)) + false + } catch { + case _: InvalidConnect | _: PassExceptions | _: InvalidRange | _: WrapWithRemainder | _: DisjointSqueeze => + true + case _: Throwable => + false + } + } + + "Range A disjoint left, operation clip should generate useful error" in { + mustGetException(disjointLeft = true, "clip") should be(false) + } + "Range A largely out of bounds left, operation wrap should generate useful error" in { + mustGetException(disjointLeft = true, "wrap") should be(true) + } + "Range A disjoint left, operation squeeze should generate useful error" in { + mustGetException(disjointLeft = true, "squeeze") should be(true) + } + "Range A disjoint right, operation clip should generate useful error" in { + mustGetException(disjointLeft = false, "clip") should be(true) + } + "Range A disjoint right, operation wrap should generate useful error" in { + mustGetException(disjointLeft = false, "wrap") should be(true) + } + "Range A disjoint right, operation squeeze should generate useful error" in { + mustGetException(disjointLeft = false, "squeeze") should be(true) + } + } + + "Errors are sometimes inconsistent or incorrectly labelled as Firrtl Internal Error" - { + "squeeze disjoint is not internal error when defined in BasicTester" in { + intercept[DisjointSqueeze] { + makeFirrtl("low")( + () => + new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + val base2 = Wire(Interval(range"[-4, 6]")) + base := 6.I + base2 := 5.I + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.squeeze(disjointLeft) + stop() + } + ) + } + } + "wrap disjoint is not internal error when defined in BasicTester" in { + intercept[DisjointSqueeze] { + makeFirrtl("low")( + () => + new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + val base2 = Wire(Interval(range"[-4, 6]")) + base := 6.I + base2 := 5.I + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.squeeze(disjointLeft) + stop() + } + ) + } + } + "squeeze disjoint from Module gives exception" in { + intercept[DisjointSqueeze] { + makeFirrtl("lo")( + () => + new Module { + val io = IO(new Bundle { + val out = Output(Interval()) + }) + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.squeeze(disjointLeft) + io.out := w5 + } + ) + } + } + "clip disjoint from Module gives no error" in { + makeFirrtl("lo")( + () => + new Module { + val io = IO(new Bundle { + val out = Output(Interval()) + }) + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.clip(disjointLeft) + io.out := w5 + } + ) + } + "wrap disjoint from Module wrap with remainder" in { + intercept[WrapWithRemainder] { + makeFirrtl("lo")( + () => + new Module { + val io = IO(new Bundle { + val out = Output(Interval()) + }) + val base = Wire(Interval(range"[-4, 6]")) + base := 6.I + + val disjointLeft = WireInit(Interval(range"[7,10]"), 8.I) + val w5 = base.wrap(disjointLeft) + io.out := w5 + } + ) + } + } + } + + "assign literal out of range of interval" in { + intercept[firrtl.passes.CheckTypes.InvalidConnect] { + assertTesterPasses(new BasicTester { + val base = Wire(Interval(range"[-4, 6]")) + base := (-8).I + }) + } + } + } + + "Intervals should catch assignment of literals outside of range" - { + "when literal is too small" in { + intercept[InvalidConnect] { + makeFirrtl("lo")( + () => + new Module { + val io = IO(new Bundle { val out = Output(Interval()) }) + val base = Wire(Interval(range"[-4, 6]")) + base := (-7).I + io.out := base + } + ) + } + } + "when literal is too big" in { + intercept[InvalidConnect] { + makeFirrtl("low")( + () => + new Module { + val io = IO(new Bundle { val out = Output(Interval()) }) + val base = Wire(Interval(range"[-4, 6]")) + base := 9.I + io.out := base + } + ) + } + } + } + + "Intervals can be shifted left" in { + assertTesterPasses(new BasicTester { + val i1 = 3.0.I(range"[0,4]") + val shifted1 = i1 << 2 + val shiftUInt = WireInit(1.U(8.W)) + val shifted2 = i1 << shiftUInt + + chisel3.assert(shifted1 === 12.I, "shifted 1 should be 12, it wasn't") + chisel3.assert(shifted2 === 6.I, "shifted 2 should be 6 it wasn't") + stop() + }) + } + + "Intervals can be shifted right" in { + assertTesterPasses(new BasicTester { + val i1 = 12.0.I(range"[0,15]") + val shifted1 = i1 >> 2 + val shiftUInt = 1.U + val shifted2 = i1 >> shiftUInt + + chisel3.assert(shifted1 === 3.I) + chisel3.assert(shifted2 === 6.I) + stop() + }) + } + + "Intervals can be used to construct registers" in { + assertTesterPasses { new IntervalRegisterTester } + } + "Intervals can be clipped with clip (saturate) operator" in { + assertTesterPasses { new IntervalClipTester } + } + "Intervals adds same answer as UInt" in { + assertTesterPasses { new IntervalChainedAddTester } + } + "Intervals should produce canonically smaller ranges via inference" in { + val loFirrtl = makeFirrtl("low")( + () => + new Module { + val io = IO(new Bundle { + val in = Input(Interval(range"[0,1]")) + val out = Output(Interval()) + }) + + val intervalResult = Wire(Interval()) + + intervalResult := 1.I + 1.I + 1.I + 1.I + 1.I + 1.I + 1.I + io.out := intervalResult + } + ) + loFirrtl.contains("output io_out : SInt<4>") should be(true) + + } + "Intervals multiplication same answer as UInt" in { + assertTesterPasses { new IntervalChainedMulTester } + } + "Intervals subs same answer as UInt" in { + assertTesterPasses { new IntervalChainedSubTester } + } + "Test clip, wrap and a variety of ranges" - { + """range"[0.0,10.0].2" => range"[2,6].2"""" in { + assertTesterPasses(new BasicTester { + + val sourceRange = range"[0.0,10.0].2" + val targetRange = range"[2,6].2" + + val sourceSimulator = ScalaIntervalSimulator(sourceRange) + val targetSimulator = ScalaIntervalSimulator(targetRange) + + for (sourceValue <- sourceSimulator.allValues) { + val clippedValue = Wire(Interval(targetRange)) + clippedValue := sourceSimulator + .makeLit(sourceValue) + .clip(clippedValue) + + val goldClippedValue = + targetSimulator.makeLit(targetSimulator.clip(sourceValue)) + + // Useful for debugging + // printf(s"source value $sourceValue clipped gold value %d compare to clipped value %d\n", + // goldClippedValue.asSInt(), clippedValue.asSInt()) + + chisel3.assert(goldClippedValue === clippedValue) + + val wrappedValue = Wire(Interval(targetRange)) + wrappedValue := sourceSimulator + .makeLit(sourceValue) + .wrap(wrappedValue) + + val goldWrappedValue = + targetSimulator.makeLit(targetSimulator.wrap(sourceValue)) + + // Useful for debugging + // printf(s"source value $sourceValue wrapped gold value %d compare to wrapped value %d\n", + // goldWrappedValue.asSInt(), wrappedValue.asSInt()) + + chisel3.assert(goldWrappedValue === wrappedValue) + } + + stop() + }) + } + } + + "Test squeeze over a variety of ranges" - { + """range"[2,6].2""" in { + assertTesterPasses(new BasicTester { + + val sourceRange = range"[0.0,10.0].2" + val targetRange = range"[2,6].3" + + val sourceSimulator = ScalaIntervalSimulator(sourceRange) + val targetSimulator = ScalaIntervalSimulator(targetRange) + + for (sourceValue <- sourceSimulator.allValues) { + val squeezedValue = Wire(Interval(targetRange)) + squeezedValue := sourceSimulator + .makeLit(sourceValue) + .clip(squeezedValue) + + val goldSqueezedValue = + targetSimulator.makeLit(targetSimulator.clip(sourceValue)) + + // Useful for debugging + // printf(s"source value $sourceValue squeezed gold value %d compare to squeezed value %d\n", + // goldSqueezedValue.asSInt(), squeezedValue.asSInt()) + + chisel3.assert(goldSqueezedValue === squeezedValue) + } + + stop() + }) + } + } + + "test asInterval" - { + "use with UInt" in { + assertTesterPasses(new BasicTester { + val u1 = Wire(UInt(5.W)) + u1 := 7.U + val i1 = u1.asInterval(range"[0,15]") + val i2 = u1.asInterval(range"[0,15].2") + printf("i1 %d\n", i1.asUInt) + chisel3.assert(i1 === 7.I, "i1") + stop() + }) + } + "use with SInt" in { + assertTesterPasses(new BasicTester { + val s1 = Wire(SInt(5.W)) + s1 := 7.S + val s2 = Wire(SInt(5.W)) + s2 := 7.S + val i1 = s1.asInterval(range"[-16,15]") + val i2 = s1.asInterval(range"[-16,15].1") + printf("i1 %d\n", i1.asSInt) + printf("i2 %d\n", i2.asSInt) + chisel3.assert(i1 === 7.I, "i1 is wrong") + chisel3.assert(i2 === (3.5).I(binaryPoint = 1.BP), "i2 is wrong") + stop() + }) + } + "more SInt tests" in { + assertTesterPasses(new BasicTester { + chisel3.assert(7.S.asInterval(range"[-16,15].1") === 3.5.I(binaryPoint = 1.BP), "adding binary point") + stop() + }) + } + } +} diff --git a/src/test/scala/chiselTests/RangeSpec.scala b/src/test/scala/chiselTests/RangeSpec.scala index e2313f34..e85a477d 100644 --- a/src/test/scala/chiselTests/RangeSpec.scala +++ b/src/test/scala/chiselTests/RangeSpec.scala @@ -4,101 +4,9 @@ package chiselTests import chisel3._ import chisel3.experimental.ChiselRange - -import chisel3.internal.firrtl.{Open, Closed} -import org.scalatest.{Matchers, FreeSpec} +import chisel3.internal.firrtl._ +import firrtl.ir.{Closed, Open} +import org.scalatest.{FreeSpec, Matchers} class RangeSpec extends FreeSpec with Matchers { - "Ranges can be specified for UInt, SInt, and FixedPoint" - { - "invalid range specifiers should fail at compile time" in { - assertDoesNotCompile(""" range"" """) - assertDoesNotCompile(""" range"[]" """) - assertDoesNotCompile(""" range"0" """) - assertDoesNotCompile(""" range"[0]" """) - assertDoesNotCompile(""" range"[0, 1" """) - assertDoesNotCompile(""" range"0, 1]" """) - assertDoesNotCompile(""" range"[0, 1, 2]" """) - assertDoesNotCompile(""" range"[a]" """) - assertDoesNotCompile(""" range"[a, b]" """) - assertCompiles(""" range"[0, 1]" """) // syntax sanity check - } - - "range macros should allow open and closed bounds" in { - range"[-1, 1)" should be( (Closed(-1), Open(1)) ) - range"[-1, 1]" should be( (Closed(-1), Closed(1)) ) - range"(-1, 1]" should be( (Open(-1), Closed(1)) ) - range"(-1, 1)" should be( (Open(-1), Open(1)) ) - } - - "range specifiers should be whitespace tolerant" in { - range"[-1,1)" should be( (Closed(-1), Open(1)) ) - range" [-1,1) " should be( (Closed(-1), Open(1)) ) - range" [ -1 , 1 ) " should be( (Closed(-1), Open(1)) ) - range" [ -1 , 1 ) " should be( (Closed(-1), Open(1)) ) - } - - "range macros should work with interpolated variables" in { - val a = 10 - val b = -3 - - range"[$b, $a)" should be( (Closed(b), Open(a)) ) - range"[${a + b}, $a)" should be( (Closed(a + b), Open(a)) ) - range"[${-3 - 7}, ${-3 + a})" should be( (Closed(-10), Open(-3 + a)) ) - - def number(n: Int): Int = n - range"[${number(1)}, ${number(3)})" should be( (Closed(1), Open(3)) ) - } - - "UInt should get the correct width from a range" in { - UInt(range"[0, 8)").getWidth should be (3) - UInt(range"[0, 8]").getWidth should be (4) - UInt(range"[0, 0]").getWidth should be (1) - } - - "SInt should get the correct width from a range" in { - SInt(range"[0, 8)").getWidth should be (4) - SInt(range"[0, 8]").getWidth should be (5) - SInt(range"[-4, 4)").getWidth should be (3) - SInt(range"[0, 0]").getWidth should be (1) - } - - "UInt should check that the range is valid" in { - an [IllegalArgumentException] should be thrownBy { - UInt(range"[1, 0]") - } - an [IllegalArgumentException] should be thrownBy { - UInt(range"[-1, 1]") - } - an [IllegalArgumentException] should be thrownBy { - UInt(range"(0,0]") - } - an [IllegalArgumentException] should be thrownBy { - UInt(range"[0,0)") - } - an [IllegalArgumentException] should be thrownBy { - UInt(range"(0,0)") - } - an [IllegalArgumentException] should be thrownBy { - UInt(range"(0,1)") - } - } - - "SInt should check that the range is valid" in { - an [IllegalArgumentException] should be thrownBy { - SInt(range"[1, 0]") - } - an [IllegalArgumentException] should be thrownBy { - SInt(range"(0,0]") - } - an [IllegalArgumentException] should be thrownBy { - SInt(range"[0,0)") - } - an [IllegalArgumentException] should be thrownBy { - SInt(range"(0,0)") - } - an [IllegalArgumentException] should be thrownBy { - SInt(range"(0,1)") - } - } - } } diff --git a/src/test/scala/chiselTests/ScalaIntervalSimulatorTest.scala b/src/test/scala/chiselTests/ScalaIntervalSimulatorTest.scala new file mode 100644 index 00000000..0bf8741e --- /dev/null +++ b/src/test/scala/chiselTests/ScalaIntervalSimulatorTest.scala @@ -0,0 +1,96 @@ +// See README.md for license details. + +package chiselTests + +import chisel3._ +import chisel3.experimental._ +import org.scalatest.{FreeSpec, Matchers} + +class ScalaIntervalSimulatorSpec extends FreeSpec with Matchers { + "clip tests" - { + "Should work for closed ranges" in { + val sim = ScalaIntervalSimulator(range"[2,4]") + sim.clip(BigDecimal(1.0)) should be (2.0) + sim.clip(BigDecimal(2.0)) should be (2.0) + sim.clip(BigDecimal(3.0)) should be (3.0) + sim.clip(BigDecimal(4.0)) should be (4.0) + sim.clip(BigDecimal(5.0)) should be (4.0) + } + "Should work for closed ranges with binary point" in { + val sim = ScalaIntervalSimulator(range"[2,6].2") + sim.clip(BigDecimal(1.75)) should be (2.0) + sim.clip(BigDecimal(2.0)) should be (2.0) + sim.clip(BigDecimal(2.25)) should be (2.25) + sim.clip(BigDecimal(2.5)) should be (2.5) + sim.clip(BigDecimal(5.75)) should be (5.75) + sim.clip(BigDecimal(6.0)) should be (6.0) + sim.clip(BigDecimal(6.25)) should be (6.0) + sim.clip(BigDecimal(6.5)) should be (6.0) + sim.clip(BigDecimal(8.5)) should be (6.0) + } + "Should work for open ranges" in { + val sim = ScalaIntervalSimulator(range"(2,4)") + sim.clip(BigDecimal(1.0)) should be (3.0) + sim.clip(BigDecimal(2.0)) should be (3.0) + sim.clip(BigDecimal(3.0)) should be (3.0) + sim.clip(BigDecimal(4.0)) should be (3.0) + sim.clip(BigDecimal(5.0)) should be (3.0) + } + "Should work for open ranges with binary point" in { + val sim = ScalaIntervalSimulator(range"(2,6).2") + sim.clip(BigDecimal(1.75)) should be (2.25) + sim.clip(BigDecimal(2.0)) should be (2.25) + sim.clip(BigDecimal(2.25)) should be (2.25) + sim.clip(BigDecimal(2.5)) should be (2.5) + sim.clip(BigDecimal(5.75)) should be (5.75) + sim.clip(BigDecimal(6.0)) should be (5.75) + sim.clip(BigDecimal(6.25)) should be (5.75) + sim.clip(BigDecimal(6.5)) should be (5.75) + sim.clip(BigDecimal(8.5)) should be (5.75) + } + } + "wrap tests" - { + "Should work for closed ranges" in { + val sim = ScalaIntervalSimulator(range"[2,6]") + sim.wrap(BigDecimal(1.0)) should be (6.0) + sim.wrap(BigDecimal(2.0)) should be (2.0) + sim.wrap(BigDecimal(3.0)) should be (3.0) + sim.wrap(BigDecimal(4.0)) should be (4.0) + sim.wrap(BigDecimal(5.0)) should be (5.0) + sim.wrap(BigDecimal(6.0)) should be (6.0) + sim.wrap(BigDecimal(7.0)) should be (2.0) + } + "Should work for closed ranges with binary point" in { + val sim = ScalaIntervalSimulator(range"[2,6].2") + sim.wrap(BigDecimal(1.75)) should be (6.0) + sim.wrap(BigDecimal(2.0)) should be (2.0) + sim.wrap(BigDecimal(2.25)) should be (2.25) + sim.wrap(BigDecimal(2.5)) should be (2.5) + sim.wrap(BigDecimal(5.75)) should be (5.75) + sim.wrap(BigDecimal(6.0)) should be (6.0) + sim.wrap(BigDecimal(6.25)) should be (2.0) + sim.wrap(BigDecimal(6.5)) should be (2.25) + } + "Should work for open ranges" in { + val sim = ScalaIntervalSimulator(range"(2,6)") + sim.wrap(BigDecimal(1.0)) should be (4.0) + sim.wrap(BigDecimal(2.0)) should be (5.0) + sim.wrap(BigDecimal(3.0)) should be (3.0) + sim.wrap(BigDecimal(4.0)) should be (4.0) + sim.wrap(BigDecimal(5.0)) should be (5.0) + sim.wrap(BigDecimal(6.0)) should be (3.0) + sim.wrap(BigDecimal(7.0)) should be (4.0) + } + "Should work for open ranges with binary point" in { + val sim = ScalaIntervalSimulator(range"(2,6).2") + sim.wrap(BigDecimal(1.75)) should be (5.5) + sim.wrap(BigDecimal(2.0)) should be (5.75) + sim.wrap(BigDecimal(2.25)) should be (2.25) + sim.wrap(BigDecimal(2.5)) should be (2.5) + sim.wrap(BigDecimal(5.75)) should be (5.75) + sim.wrap(BigDecimal(6.0)) should be (2.25) + sim.wrap(BigDecimal(6.25)) should be (2.5) + sim.wrap(BigDecimal(7.0)) should be (3.25) + } + } +} diff --git a/src/test/scala/chiselTests/Util.scala b/src/test/scala/chiselTests/Util.scala index f71cd7f3..8c9bc4ea 100644 --- a/src/test/scala/chiselTests/Util.scala +++ b/src/test/scala/chiselTests/Util.scala @@ -5,6 +5,9 @@ package chiselTests import chisel3._ +import chisel3.experimental.Interval +import chisel3.internal.firrtl.{IntervalRange, KnownBinaryPoint, Width} +import _root_.firrtl.{ir => firrtlir} class PassthroughModuleIO extends Bundle { val in = Input(UInt(32.W)) @@ -20,4 +23,53 @@ class PassthroughModule extends Module with AbstractPassthroughModule class PassthroughMultiIOModule extends MultiIOModule with AbstractPassthroughModule class PassthroughRawModule extends RawModule with AbstractPassthroughModule +case class ScalaIntervalSimulator(intervalRange: IntervalRange) { + val binaryPoint: Int = intervalRange.binaryPoint.asInstanceOf[KnownBinaryPoint].value + val epsilon: Double = 1.0 / math.pow(2.0, binaryPoint.toDouble) + + val (lower, upper) = (intervalRange.lowerBound, intervalRange.upperBound) match { + + case (firrtlir.Closed(lower1), firrtlir.Closed(upper1)) => (lower1, upper1) + case (firrtlir.Closed(lower1), firrtlir.Open(upper1)) => (lower1, upper1 - epsilon) + case (firrtlir.Open(lower1), firrtlir.Closed(upper1)) => (lower1 + epsilon, upper1) + case (firrtlir.Open(lower1), firrtlir.Open(upper1)) => (lower1 + epsilon, upper1 - epsilon) + case _ => + throw new Exception(s"lower and upper bounds must be defined, range here is $intervalRange") + } + + def clip(value: BigDecimal): BigDecimal = { + + if (value < lower) { + lower + } + else if (value > upper) { + upper + } + else { + value + } + } + + def wrap(value: BigDecimal): BigDecimal = { + + if (value < lower) { + upper + (value - lower) + epsilon + } + else if (value > upper) { + ((value - upper) - epsilon) + lower + } + else { + value + } + } + + def allValues: Iterator[BigDecimal] = { + (lower to upper by epsilon).toIterator + } + + def makeLit(value: BigDecimal): Interval = { + Interval.fromDouble(value.toDouble, width = Width(), binaryPoint = binaryPoint.BP) + } +} + |
