diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/Chisel/Core.scala | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/src/main/scala/Chisel/Core.scala b/src/main/scala/Chisel/Core.scala index 1ab84681..6a3758f0 100644 --- a/src/main/scala/Chisel/Core.scala +++ b/src/main/scala/Chisel/Core.scala @@ -853,18 +853,24 @@ object Bool { object Mux { def apply[T <: Data](cond: Bool, con: T, alt: T): T = (con, alt) match { + // Handle Mux(cond, UInt, Bool) carefully so that the concrete type is UInt + case (c: Bool, a: Bool) => doMux(cond, c, a).asInstanceOf[T] + case (c: UInt, a: Bool) => doMux(cond, c, a << 0).asInstanceOf[T] + case (c: Bool, a: UInt) => doMux(cond, c << 0, a).asInstanceOf[T] case (c: Bits, a: Bits) => doMux(cond, c, a).asInstanceOf[T] + // FIRRTL doesn't support Mux for aggregates, so use a when instead case _ => doWhen(cond, con, alt) } - // These implementations are type-unsafe and rely on FIRRTL for type checking private def doMux[T <: Bits](cond: Bool, con: T, alt: T): T = { + require(con.getClass == alt.getClass, s"can't Mux between ${con.getClass} and ${alt.getClass}") val d = alt.cloneTypeWidth(con.width max alt.width) pushOp(DefPrim(d, MultiplexOp, cond.ref, con.ref, alt.ref)) } // This returns an lvalue, which it most definitely should not private def doWhen[T <: Data](cond: Bool, con: T, alt: T): T = { - val res = Wire(alt, init = alt) + require(con.getClass == alt.getClass, s"can't Mux between ${con.getClass} and ${alt.getClass}") + val res = Wire(t = alt.cloneTypeWidth(con.width max alt.width), init = alt) when (cond) { res := con } res } |
