From b8fc7dd6a7e7d65639d43474fd938c6c65cb1d32 Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Fri, 31 Jul 2015 14:38:45 -0700 Subject: Implement getWidth more completely and less buggily --- src/main/scala/Chisel/Core.scala | 66 +++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/src/main/scala/Chisel/Core.scala b/src/main/scala/Chisel/Core.scala index 46b1d4eb..6f104573 100644 --- a/src/main/scala/Chisel/Core.scala +++ b/src/main/scala/Chisel/Core.scala @@ -292,6 +292,7 @@ abstract class Data(dirArg: Direction) extends Id { private[Chisel] def lref: Alias = Alias(this) private[Chisel] def ref: Arg = if (isLit) litArg() else lref def cloneType: this.type + def cloneTypeWidth(width: Int): this.type def name = getRefForId(this).name def debugName = mod.debugName + "." + getRefForId(this).debugName def litArg(): LitArg = null @@ -300,9 +301,17 @@ abstract class Data(dirArg: Direction) extends Id { def floLitValue: Float = intBitsToFloat(litValue().toInt) def dblLitValue: Double = longBitsToDouble(litValue().toLong) def getWidth: Int = flatten.map(_.getWidth).reduce(_ + _) - def maxWidth(other: Data, amt: BigInt): Int = -1 - def sumWidth(amt: BigInt): Int = -1 - def sumWidth(other: Data, amt: BigInt): Int = -1 + def knownWidth: Boolean = flatten.forall(_.knownWidth) + + def maxWidth(other: Data, amt: Int): Int = + if (knownWidth && other.knownWidth) ((getWidth max other.getWidth) + amt) else -1 + def sumWidth(amt: Int): Int = if (knownWidth) (getWidth + amt).toInt else -1 + def sumWidth(other: Data, amt: Int): Int = + if (knownWidth && other.knownWidth) (getWidth + other.getWidth + amt).toInt else -1 + def sumPow2Width(other: Data): Int = + if (knownWidth && other.knownWidth) (getWidth + (1 << other.getWidth)).toInt else -1 + def rshWidth(amt: Int): Int = if (knownWidth) (0 max (getWidth - amt)) else -1 + def flatten: IndexedSeq[Bits] def fromBits(n: Bits): this.type = { var i = 0 @@ -526,7 +535,11 @@ class BitPat(val value: BigInt, val mask: BigInt, width: Int) { } abstract class Element(dirArg: Direction, val width: Int) extends Data(dirArg) { - override def getWidth: Int = width + override def knownWidth: Boolean = width >= 0 + override def getWidth: Int = { + require(knownWidth) + width + } } object Clock { @@ -535,6 +548,7 @@ object Clock { sealed class Clock(dirArg: Direction) extends Element(dirArg, 1) { def cloneType: this.type = Clock(dirArg).asInstanceOf[this.type] + def cloneTypeWidth(width: Int): this.type = cloneType def flatten: IndexedSeq[Bits] = throwException("Clock.flatten") def toType: Kind = ClockType(isFlipVar) } @@ -544,7 +558,6 @@ sealed abstract class Bits(dirArg: Direction, width: Int, lit: Option[LitArg]) e override def isLit(): Boolean = lit.isDefined override def litValue(): BigInt = lit.get.num def fromInt(x: BigInt): this.type = makeLit(x, -1) - def cloneTypeWidth(width: Int): this.type def cloneType: this.type = cloneTypeWidth(width) override def flatten: IndexedSeq[Bits] = IndexedSeq(this) @@ -576,14 +589,6 @@ sealed abstract class Bits(dirArg: Direction, width: Int, lit: Option[LitArg]) e final def apply(x: UInt, y: UInt): UInt = apply(x.litValue(), y.litValue()) - def maxWidth(other: Bits, amt: Int): Int = - if (getWidth >= 0 && other.getWidth >= 0) ((getWidth max other.getWidth) + amt) else -1 - override def sumWidth(amt: BigInt): Int = if (getWidth >= 0) (getWidth + amt).toInt else -1 - def sumWidth(other: Bits, amt: BigInt): Int = - if (getWidth >= 0 && other.getWidth >= 0) (getWidth + other.getWidth + amt).toInt else -1 - def sumPow2Width(other: Bits): Int = - if (getWidth >= 0 && other.getWidth >= 0) (getWidth + (1 << other.getWidth)).toInt else -1 - def :=(other: Bits) = pushCommand(Connect(this.lref, other.ref)) @@ -698,10 +703,10 @@ sealed class UInt(dir: Direction, width: Int, lit: Option[ULit] = None) extends def <= (other: UInt): Bool = compop(LessEqOp, other) def >= (other: UInt): Bool = compop(GreaterEqOp, other) - def << (other: BigInt): UInt = binop(ShiftLeftOp, other, sumWidth(other)) + def << (other: BigInt): UInt = binop(ShiftLeftOp, other, sumWidth(other.toInt)) def << (other: Int): UInt = this << BigInt(other) def << (other: UInt): UInt = binop(DynamicShiftLeftOp, other, sumPow2Width(other)) - def >> (other: BigInt): UInt = binop(ShiftRightOp, other, sumWidth(-other)) + def >> (other: BigInt): UInt = binop(ShiftRightOp, other, rshWidth(other.toInt)) def >> (other: Int): UInt = this >> BigInt(other) def >> (other: UInt): UInt = binop(DynamicShiftRightOp, other, sumWidth(0)) @@ -714,13 +719,13 @@ sealed class UInt(dir: Direction, width: Int, lit: Option[ULit] = None) extends def != (that: BitPat): Bool = that != this def zext(): SInt = { - val x = SInt(width = getWidth + 1) + val x = SInt(width = sumWidth(1)) pushCommand(DefPrim(x, x.toType, ConvertOp, Seq(ref), NoLits)) x } def asSInt(): SInt = { - val x = SInt(width = getWidth) + val x = SInt(width = sumWidth(0)) pushCommand(DefPrim(x, x.toType, AsSIntOp, Seq(ref), NoLits)) x } @@ -782,15 +787,15 @@ sealed class SInt(dir: Direction, width: Int, lit: Option[SLit] = None) extends def >= (other: SInt): Bool = compop(GreaterEqOp, other) def abs(): UInt = Mux(this < SInt(0), (-this).toUInt, this.toUInt) - def << (other: BigInt): SInt = binop(ShiftLeftOp, other, sumWidth(other)) + def << (other: BigInt): SInt = binop(ShiftLeftOp, other, sumWidth(other.toInt)) def << (other: Int): SInt = this << BigInt(other) def << (other: UInt): SInt = binop(DynamicShiftLeftOp, other, sumPow2Width(other)) - def >> (other: BigInt): SInt = binop(ShiftRightOp, other, sumWidth(-other)) + def >> (other: BigInt): SInt = binop(ShiftRightOp, other, rshWidth(other.toInt)) def >> (other: Int): SInt = this >> BigInt(other) def >> (other: UInt): SInt = binop(DynamicShiftRightOp, other, sumWidth(0)) def asUInt(): UInt = { - val x = UInt(width = getWidth) + val x = UInt(width = sumWidth(0)) pushCommand(DefPrim(x, x.toType, AsUIntOp, Seq(ref), NoLits)) x } @@ -838,12 +843,25 @@ object Bool { } object Mux { + private def multiplex[T <: Data](cond: Bool, con: T, alt: T): T = { + val d = alt.cloneTypeWidth(con.maxWidth(alt, 0)) + if (con.getClass != alt.getClass) // TODO: Figure out a better way to do this + throwException(s"Can't Mux between types ${con.getClass.getName} and ${alt.getClass.getName}") + pushCommand(DefPrim(d, d.toType, MultiplexOp, Seq(cond.ref, con.ref, alt.ref), NoLits)) + d + } def apply[T <: Data](cond: Bool, con: T, alt: T): T = { - val w = Wire(alt, init = alt) - when (cond) { - w := con + if (con.getClass == alt.getClass && con.isInstanceOf[Bits]) { + // TODO: figure out a better way to dispatch this + multiplex(cond, con, alt) + } else { + // TODO: this shouldn't return an lvalue! + val w = Wire(alt, init = alt) + when (cond) { + w := con + } + w } - w } } -- cgit v1.2.3