From e9b2946c962f91a04611e32b1a9d03f78e7edf2b Mon Sep 17 00:00:00 2001 From: Fabian Schuiki Date: Fri, 16 Apr 2021 16:18:32 +0200 Subject: Fix signedness of xor const prop with zero (#2179) Constant propagation of the Xor op folds `xor(a, SInt(0))` to `asUInt(a)`. For comparison, Or folds to `asUInt(pad(a, W))`. This can be a problem in the following case: circuit Foo : module Foo : input a: UInt<3> output b: UInt<4> b <= asUInt(xor(asSInt(a), SInt<4>(0))) This would emit the assignment as `b = a` instead of the sign-extended `b = {{1{a[2]}},a}`. This requires adjusting the `pad(e, t)` function use in const prop, which currently just inserts a `Pad` prim op with the requested output type. However, the function advertises that it pads *to the width* of the type `t`. Some of the folds rely on this and request the padding of a SInt to the width of a UInt. But the current implementation then then actually returns a `Pad` op with type UInt, instead of the SInt that was requested.--- src/main/scala/firrtl/transforms/ConstantPropagation.scala | 14 ++++++++++++-- src/test/scala/firrtlTests/ConstantPropagationTests.scala | 7 +++++++ 2 files changed, 19 insertions(+), 2 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 5610c7e7..bc1fc9af 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -29,7 +29,17 @@ object ConstantPropagation { /** Pads e to the width of t */ def pad(e: Expression, t: Type) = (bitWidth(e.tpe), bitWidth(t)) match { - case (we, wt) if we < wt => DoPrim(Pad, Seq(e), Seq(wt), t) + case (we, wt) if we < wt => + DoPrim( + Pad, + Seq(e), + Seq(wt), + e.tpe match { + case UIntType(_) => UIntType(IntWidth(wt)) + case SIntType(_) => SIntType(IntWidth(wt)) + case _ => e.tpe + } + ) case (we, wt) if we == wt => e } @@ -252,7 +262,7 @@ class ConstantPropagation extends Transform with RegisteredTransform with Depend } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs - case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) + case SIntLiteral(v, _) if v == BigInt(0) => asUInt(pad(rhs, e.tpe), e.tpe) case _ => e } def matchingArgsValue(e: DoPrim, arg: Expression) = UIntLiteral(0, getWidth(arg.tpe)) diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index bc7f92e6..ababb95b 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -1531,22 +1531,29 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { val input = s"""|circuit Foo: | module Foo: + | input in1: SInt<3> | output out1: UInt<2> | output out2: UInt<2> | output out3: UInt<2> + | output out4: UInt<4> | out1 <= xor(SInt<2>(-1), SInt<2>(1)) | out2 <= or(SInt<2>(-1), SInt<2>(1)) | out3 <= and(SInt<2>(-1), SInt<2>(-2)) + | out4 <= xor(in1, SInt<4>(0)) |""".stripMargin val check = s"""|circuit Foo: | module Foo: + | input in1: SInt<3> | output out1: UInt<2> | output out2: UInt<2> | output out3: UInt<2> + | output out4: UInt<4> | out1 <= UInt<2>(2) | out2 <= UInt<2>(3) | out3 <= UInt<2>(2) + | node _GEN_0 = pad(in1, 4) + | out4 <= asUInt(_GEN_0) |""".stripMargin execute(input, check, Seq.empty) } -- cgit v1.2.3