diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/Passes.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 55 |
1 files changed, 47 insertions, 8 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 7490c479..1e8ceae2 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -1128,24 +1128,63 @@ object ExpandWhens extends Pass { } } +// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts +// TODO replace UInt with zero-width wire instead +object Legalize extends Pass { + def name = "Legalize" + def legalizeShiftRight (e: DoPrim): Expression = e.op match { + case SHIFT_RIGHT_OP => { + val amount = e.consts(0).toInt + val width = long_BANG(tpe(e.args(0))) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + case t: UIntType => UIntValue(0, IntWidth(1)) + case t: SIntType => + DoPrim(BITS_SELECT_OP, e.args, Seq(msb, msb), SIntType(IntWidth(1))) + case t => error(s"Unsupported type ${t} for Primop Shift Right") + } + } else { + e + } + } + case _ => e + } + def run (c: Circuit): Circuit = { + def legalizeE (e: Expression): Expression = { + e map (legalizeE) match { + case e: DoPrim => legalizeShiftRight(e) + case e => e + } + } + def legalizeS (s: Stmt): Stmt = s map (legalizeS) map (legalizeE) + def legalizeM (m: Module): Module = m map (legalizeS) + Circuit(c.info, c.modules.map(legalizeM), c.main) + } +} + object ConstProp extends Pass { def name = "Constant Propogation" var mname = "" + def const_prop_e (e:Expression) : Expression = { e map (const_prop_e) match { case (e:DoPrim) => { e.op match { case SHIFT_RIGHT_OP => { - (e.args(0)) match { - case (x:UIntValue) => { - val b = x.value >> e.consts(0).toInt - UIntValue(b,tpe(e).as[UIntType].get.width) + val amount = e.consts(0).toInt + e.args(0) match { + case x: UIntValue => { + val v = x.value >> amount + val w = (x.width - IntWidth(amount)) max IntWidth(1) + UIntValue(v, w) } - case (x:SIntValue) => { - val b = x.value >> e.consts(0).toInt - SIntValue(b,tpe(e).as[SIntType].get.width) + case x: SIntValue => { // take sign bit if shift amount is larger than arg width + val v = x.value >> amount + val w = (x.width - IntWidth(amount)) max IntWidth(1) + SIntValue(v, w) } - case (x) => e + case _ => e } } case BITS_SELECT_OP => { |
