summaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/Chisel/Core.scala66
1 files changed, 42 insertions, 24 deletions
diff --git a/src/main/scala/Chisel/Core.scala b/src/main/scala/Chisel/Core.scala
index 46b1d4eb..6f104573 100644
--- a/src/main/scala/Chisel/Core.scala
+++ b/src/main/scala/Chisel/Core.scala
@@ -292,6 +292,7 @@ abstract class Data(dirArg: Direction) extends Id {
private[Chisel] def lref: Alias = Alias(this)
private[Chisel] def ref: Arg = if (isLit) litArg() else lref
def cloneType: this.type
+ def cloneTypeWidth(width: Int): this.type
def name = getRefForId(this).name
def debugName = mod.debugName + "." + getRefForId(this).debugName
def litArg(): LitArg = null
@@ -300,9 +301,17 @@ abstract class Data(dirArg: Direction) extends Id {
def floLitValue: Float = intBitsToFloat(litValue().toInt)
def dblLitValue: Double = longBitsToDouble(litValue().toLong)
def getWidth: Int = flatten.map(_.getWidth).reduce(_ + _)
- def maxWidth(other: Data, amt: BigInt): Int = -1
- def sumWidth(amt: BigInt): Int = -1
- def sumWidth(other: Data, amt: BigInt): Int = -1
+ def knownWidth: Boolean = flatten.forall(_.knownWidth)
+
+ def maxWidth(other: Data, amt: Int): Int =
+ if (knownWidth && other.knownWidth) ((getWidth max other.getWidth) + amt) else -1
+ def sumWidth(amt: Int): Int = if (knownWidth) (getWidth + amt).toInt else -1
+ def sumWidth(other: Data, amt: Int): Int =
+ if (knownWidth && other.knownWidth) (getWidth + other.getWidth + amt).toInt else -1
+ def sumPow2Width(other: Data): Int =
+ if (knownWidth && other.knownWidth) (getWidth + (1 << other.getWidth)).toInt else -1
+ def rshWidth(amt: Int): Int = if (knownWidth) (0 max (getWidth - amt)) else -1
+
def flatten: IndexedSeq[Bits]
def fromBits(n: Bits): this.type = {
var i = 0
@@ -526,7 +535,11 @@ class BitPat(val value: BigInt, val mask: BigInt, width: Int) {
}
abstract class Element(dirArg: Direction, val width: Int) extends Data(dirArg) {
- override def getWidth: Int = width
+ override def knownWidth: Boolean = width >= 0
+ override def getWidth: Int = {
+ require(knownWidth)
+ width
+ }
}
object Clock {
@@ -535,6 +548,7 @@ object Clock {
sealed class Clock(dirArg: Direction) extends Element(dirArg, 1) {
def cloneType: this.type = Clock(dirArg).asInstanceOf[this.type]
+ def cloneTypeWidth(width: Int): this.type = cloneType
def flatten: IndexedSeq[Bits] = throwException("Clock.flatten")
def toType: Kind = ClockType(isFlipVar)
}
@@ -544,7 +558,6 @@ sealed abstract class Bits(dirArg: Direction, width: Int, lit: Option[LitArg]) e
override def isLit(): Boolean = lit.isDefined
override def litValue(): BigInt = lit.get.num
def fromInt(x: BigInt): this.type = makeLit(x, -1)
- def cloneTypeWidth(width: Int): this.type
def cloneType: this.type = cloneTypeWidth(width)
override def flatten: IndexedSeq[Bits] = IndexedSeq(this)
@@ -576,14 +589,6 @@ sealed abstract class Bits(dirArg: Direction, width: Int, lit: Option[LitArg]) e
final def apply(x: UInt, y: UInt): UInt =
apply(x.litValue(), y.litValue())
- def maxWidth(other: Bits, amt: Int): Int =
- if (getWidth >= 0 && other.getWidth >= 0) ((getWidth max other.getWidth) + amt) else -1
- override def sumWidth(amt: BigInt): Int = if (getWidth >= 0) (getWidth + amt).toInt else -1
- def sumWidth(other: Bits, amt: BigInt): Int =
- if (getWidth >= 0 && other.getWidth >= 0) (getWidth + other.getWidth + amt).toInt else -1
- def sumPow2Width(other: Bits): Int =
- if (getWidth >= 0 && other.getWidth >= 0) (getWidth + (1 << other.getWidth)).toInt else -1
-
def :=(other: Bits) =
pushCommand(Connect(this.lref, other.ref))
@@ -698,10 +703,10 @@ sealed class UInt(dir: Direction, width: Int, lit: Option[ULit] = None) extends
def <= (other: UInt): Bool = compop(LessEqOp, other)
def >= (other: UInt): Bool = compop(GreaterEqOp, other)
- def << (other: BigInt): UInt = binop(ShiftLeftOp, other, sumWidth(other))
+ def << (other: BigInt): UInt = binop(ShiftLeftOp, other, sumWidth(other.toInt))
def << (other: Int): UInt = this << BigInt(other)
def << (other: UInt): UInt = binop(DynamicShiftLeftOp, other, sumPow2Width(other))
- def >> (other: BigInt): UInt = binop(ShiftRightOp, other, sumWidth(-other))
+ def >> (other: BigInt): UInt = binop(ShiftRightOp, other, rshWidth(other.toInt))
def >> (other: Int): UInt = this >> BigInt(other)
def >> (other: UInt): UInt = binop(DynamicShiftRightOp, other, sumWidth(0))
@@ -714,13 +719,13 @@ sealed class UInt(dir: Direction, width: Int, lit: Option[ULit] = None) extends
def != (that: BitPat): Bool = that != this
def zext(): SInt = {
- val x = SInt(width = getWidth + 1)
+ val x = SInt(width = sumWidth(1))
pushCommand(DefPrim(x, x.toType, ConvertOp, Seq(ref), NoLits))
x
}
def asSInt(): SInt = {
- val x = SInt(width = getWidth)
+ val x = SInt(width = sumWidth(0))
pushCommand(DefPrim(x, x.toType, AsSIntOp, Seq(ref), NoLits))
x
}
@@ -782,15 +787,15 @@ sealed class SInt(dir: Direction, width: Int, lit: Option[SLit] = None) extends
def >= (other: SInt): Bool = compop(GreaterEqOp, other)
def abs(): UInt = Mux(this < SInt(0), (-this).toUInt, this.toUInt)
- def << (other: BigInt): SInt = binop(ShiftLeftOp, other, sumWidth(other))
+ def << (other: BigInt): SInt = binop(ShiftLeftOp, other, sumWidth(other.toInt))
def << (other: Int): SInt = this << BigInt(other)
def << (other: UInt): SInt = binop(DynamicShiftLeftOp, other, sumPow2Width(other))
- def >> (other: BigInt): SInt = binop(ShiftRightOp, other, sumWidth(-other))
+ def >> (other: BigInt): SInt = binop(ShiftRightOp, other, rshWidth(other.toInt))
def >> (other: Int): SInt = this >> BigInt(other)
def >> (other: UInt): SInt = binop(DynamicShiftRightOp, other, sumWidth(0))
def asUInt(): UInt = {
- val x = UInt(width = getWidth)
+ val x = UInt(width = sumWidth(0))
pushCommand(DefPrim(x, x.toType, AsUIntOp, Seq(ref), NoLits))
x
}
@@ -838,12 +843,25 @@ object Bool {
}
object Mux {
+ private def multiplex[T <: Data](cond: Bool, con: T, alt: T): T = {
+ val d = alt.cloneTypeWidth(con.maxWidth(alt, 0))
+ if (con.getClass != alt.getClass) // TODO: Figure out a better way to do this
+ throwException(s"Can't Mux between types ${con.getClass.getName} and ${alt.getClass.getName}")
+ pushCommand(DefPrim(d, d.toType, MultiplexOp, Seq(cond.ref, con.ref, alt.ref), NoLits))
+ d
+ }
def apply[T <: Data](cond: Bool, con: T, alt: T): T = {
- val w = Wire(alt, init = alt)
- when (cond) {
- w := con
+ if (con.getClass == alt.getClass && con.isInstanceOf[Bits]) {
+ // TODO: figure out a better way to dispatch this
+ multiplex(cond, con, alt)
+ } else {
+ // TODO: this shouldn't return an lvalue!
+ val w = Wire(alt, init = alt)
+ when (cond) {
+ w := con
+ }
+ w
}
- w
}
}