package firrtl.passes

import firrtl.ir._
import firrtl.Mappers._
import firrtl.options.PreservesAll
import firrtl.{Transform, WSubAccess, WSubField, WSubIndex}

/** Moves [[Mux]] and [[ValidIf]] expressions out from under sub-field, sub-index and
  * sub-access expressions, so the accessor is applied to each data operand instead of to
  * the multiplexer. For example, `mux(c, a, b).x` becomes `mux(c, a.x, b.x)` and
  * `validif(c, v)[2]` becomes `validif(c, v[2])`.
  *
  * Nested multiplexers are rewritten all the way down. For example,
  * `mux(c1, mux(c2, a, b), d).x` becomes `mux(c1, mux(c2, a.x, b.x), d.x)`.
  */
object PullMuxes extends Pass with PreservesAll[Transform] {

  override def prerequisites = firrtl.stage.Forms.Deduped

  /** True if `e` is a [[Mux]] or [[ValidIf]], i.e. something an accessor can be pushed through. */
  private def isMuxLike(e: Expression): Boolean = e match {
    case _: Mux | _: ValidIf => true
    case _                   => false
  }

  /** Applies `access` to every data leaf of the Mux/ValidIf tree rooted at `target`.
    *
    * Each rebuilt [[Mux]] and [[ValidIf]] is given `tpe`, the accessor's result type. After
    * the rewrite they choose between already-accessed values rather than between whole
    * aggregates. This function recurses into the operands instead of rewriting only the
    * outermost level, so nested multiplexers never end up with an accessor wrapped around
    * them.
    *
    * @param target the expression the accessor was originally applied to
    * @param tpe    the result type of the original accessor
    * @param access rebuilds the original accessor around a new target
    * @return the equivalent expression with `access` applied at each leaf
    */
  private def pushThrough(target: Expression, tpe: Type)(access: Expression => Expression): Expression =
    target match {
      case m: Mux =>
        Mux(m.cond, pushThrough(m.tval, tpe)(access), pushThrough(m.fval, tpe)(access), tpe)
      case v: ValidIf =>
        ValidIf(v.cond, pushThrough(v.value, tpe)(access), tpe)
      case leaf =>
        access(leaf)
    }

  /** Rewrites `e` bottom-up, pulling multiplexers out from under every accessor.
    *
    * Sub-expressions are processed first. Any Mux/ValidIf left under an accessor is
    * therefore a direct child of that accessor, and [[pushThrough]] handles it.
    */
  private def pullMuxesE(e: Expression): Expression = e.map(pullMuxesE) match {
    case ex: WSubField if isMuxLike(ex.expr) =>
      pushThrough(ex.expr, ex.tpe)(WSubField(_, ex.name, ex.tpe, ex.flow))
    case ex: WSubIndex if isMuxLike(ex.expr) =>
      pushThrough(ex.expr, ex.tpe)(WSubIndex(_, ex.value, ex.tpe, ex.flow))
    case ex: WSubAccess if isMuxLike(ex.expr) =>
      pushThrough(ex.expr, ex.tpe)(WSubAccess(_, ex.index, ex.tpe, ex.flow))
    // Nothing to pull: return the (already child-processed) expression itself.
    // This includes an accessor whose target is not a multiplexer; return the
    // accessor, not its target.
    case ex => ex
  }

  /** Recursively applies [[pullMuxesE]] to every expression in `s` and its sub-statements. */
  private def pullMuxesS(s: Statement): Statement = s.map(pullMuxesS).map(pullMuxesE)

  /** Runs the rewrite over the body of every [[Module]].
    *
    * Other module kinds (e.g. [[ExtModule]]) have no body to rewrite and are returned
    * unchanged.
    */
  def run(c: Circuit): Circuit = {
    val modulesx = c.modules.map {
      case m: Module => m.copy(body = pullMuxesS(m.body))
      case other     => other
    }
    Circuit(c.info, modulesx, c.main)
  }
}