aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala14
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala7
2 files changed, 19 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 5610c7e7..bc1fc9af 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -29,7 +29,17 @@ object ConstantPropagation {
/** Pads e to the width of t */
def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match {
- case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t)
+ case (we, wt) if we < wt =>
+ DoPrim(
+ Pad,
+ Seq(e),
+ Seq(wt),
+ e.tpe match {
+ case UIntType(_) => UIntType(IntWidth(wt))
+ case SIntType(_) => SIntType(IntWidth(wt))
+ case _ => e.tpe
+ }
+ )
case (we, wt) if we == wt => e
}
@@ -252,7 +262,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend
}
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
- case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
+ case SIntLiteral(v, _) if v == BigInt(0) => asUInt(pad(rhs, e.tpe), e.tpe)
case _ => e
}
def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe))
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index bc7f92e6..ababb95b 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -1531,22 +1531,29 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
val input =
s"""|circuit Foo:
| module Foo:
+ | input in1: SInt<3>
| output out1: UInt<2>
| output out2: UInt<2>
| output out3: UInt<2>
+ | output out4: UInt<4>
| out1 <= xor(SInt<2>(-1), SInt<2>(1))
| out2 <= or(SInt<2>(-1), SInt<2>(1))
| out3 <= and(SInt<2>(-1), SInt<2>(-2))
+ | out4 <= xor(in1, SInt<4>(0))
|""".stripMargin
val check =
s"""|circuit Foo:
| module Foo:
+ | input in1: SInt<3>
| output out1: UInt<2>
| output out2: UInt<2>
| output out3: UInt<2>
+ | output out4: UInt<4>
| out1 <= UInt<2>(2)
| out2 <= UInt<2>(3)
| out3 <= UInt<2>(2)
+ | node _GEN_0 = pad(in1, 4)
+ | out4 <= asUInt(_GEN_0)
|""".stripMargin
execute(input, check, Seq.empty)
}