aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/Passes.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/Passes.scala')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala55
1 files changed, 47 insertions, 8 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 7490c479..1e8ceae2 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -1128,24 +1128,63 @@ object ExpandWhens extends Pass {
}
}
+// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
+// TODO replace UInt with zero-width wire instead
+object Legalize extends Pass {
+ def name = "Legalize"
+ def legalizeShiftRight (e: DoPrim): Expression = e.op match {
+ case SHIFT_RIGHT_OP => {
+ val amount = e.consts(0).toInt
+ val width = long_BANG(tpe(e.args(0)))
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ case t: UIntType => UIntValue(0, IntWidth(1))
+ case t: SIntType =>
+ DoPrim(BITS_SELECT_OP, e.args, Seq(msb, msb), SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type ${t} for Primop Shift Right")
+ }
+ } else {
+ e
+ }
+ }
+ case _ => e
+ }
+ def run (c: Circuit): Circuit = {
+ def legalizeE (e: Expression): Expression = {
+ e map (legalizeE) match {
+ case e: DoPrim => legalizeShiftRight(e)
+ case e => e
+ }
+ }
+ def legalizeS (s: Stmt): Stmt = s map (legalizeS) map (legalizeE)
+ def legalizeM (m: Module): Module = m map (legalizeS)
+ Circuit(c.info, c.modules.map(legalizeM), c.main)
+ }
+}
+
object ConstProp extends Pass {
def name = "Constant Propogation"
var mname = ""
+
def const_prop_e (e:Expression) : Expression = {
e map (const_prop_e) match {
case (e:DoPrim) => {
e.op match {
case SHIFT_RIGHT_OP => {
- (e.args(0)) match {
- case (x:UIntValue) => {
- val b = x.value >> e.consts(0).toInt
- UIntValue(b,tpe(e).as[UIntType].get.width)
+ val amount = e.consts(0).toInt
+ e.args(0) match {
+ case x: UIntValue => {
+ val v = x.value >> amount
+ val w = (x.width - IntWidth(amount)) max IntWidth(1)
+ UIntValue(v, w)
}
- case (x:SIntValue) => {
- val b = x.value >> e.consts(0).toInt
- SIntValue(b,tpe(e).as[SIntType].get.width)
+ case x: SIntValue => { // take sign bit if shift amount is larger than arg width
+ val v = x.value >> amount
+ val w = (x.width - IntWidth(amount)) max IntWidth(1)
+ SIntValue(v, w)
}
- case (x) => e
+ case _ => e
}
}
case BITS_SELECT_OP => {