summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/Chisel/Core.scala10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/main/scala/Chisel/Core.scala b/src/main/scala/Chisel/Core.scala
index 1ab84681..6a3758f0 100644
--- a/src/main/scala/Chisel/Core.scala
+++ b/src/main/scala/Chisel/Core.scala
@@ -853,18 +853,24 @@ object Bool {
object Mux {
def apply[T <: Data](cond: Bool, con: T, alt: T): T = (con, alt) match {
+ // Handle Mux(cond, UInt, Bool) carefully so that the concrete type is UInt
+ case (c: Bool, a: Bool) => doMux(cond, c, a).asInstanceOf[T]
+ case (c: UInt, a: Bool) => doMux(cond, c, a << 0).asInstanceOf[T]
+ case (c: Bool, a: UInt) => doMux(cond, c << 0, a).asInstanceOf[T]
case (c: Bits, a: Bits) => doMux(cond, c, a).asInstanceOf[T]
+ // FIRRTL doesn't support Mux for aggregates, so use a when instead
case _ => doWhen(cond, con, alt)
}
- // These implementations are type-unsafe and rely on FIRRTL for type checking
private def doMux[T <: Bits](cond: Bool, con: T, alt: T): T = {
+ require(con.getClass == alt.getClass, s"can't Mux between ${con.getClass} and ${alt.getClass}")
val d = alt.cloneTypeWidth(con.width max alt.width)
pushOp(DefPrim(d, MultiplexOp, cond.ref, con.ref, alt.ref))
}
// This returns an lvalue, which it most definitely should not
private def doWhen[T <: Data](cond: Bool, con: T, alt: T): T = {
- val res = Wire(alt, init = alt)
+ require(con.getClass == alt.getClass, s"can't Mux between ${con.getClass} and ${alt.getClass}")
+ val res = Wire(t = alt.cloneTypeWidth(con.width max alt.width), init = alt)
when (cond) { res := con }
res
}