From 1927dc6574b9eee315c8f24441df390f2ce793c7 Mon Sep 17 00:00:00 2001 From: Albert Chen Date: Thu, 23 Jul 2020 11:14:35 -0700 Subject: mask bits when propagating bitwise ops (#1745) * ConstProp: test bitwise op of signed literals * ConstProp: use bit mask for FoldOr/FoldXor * handle and also * add UIntLiteral.masked helper Co-authored-by: Jack Koenig --- src/main/scala/firrtl/ir/IR.scala | 9 ++++++++ .../firrtl/transforms/ConstantPropagation.scala | 15 +++++++++++--- .../firrtlTests/ConstantPropagationTests.scala | 24 ++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 275cbe51..023f53fd 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -297,6 +297,15 @@ case class UIntLiteral(value: BigInt, width: Width) extends Literal { object UIntLiteral { def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1)) def apply(value: BigInt): UIntLiteral = new UIntLiteral(value, minWidth(value)) + + /** Utility to construct UIntLiterals masked by the width + * + * This supports truncating negative values as well as values that are too wide for the width + */ + def masked(value: BigInt, width: IntWidth): UIntLiteral = { + val mask = (BigInt(1) << width.width.toInt) - 1 + UIntLiteral(value & mask, width) + } } case class SIntLiteral(value: BigInt, width: Width) extends Literal { def tpe = SIntType(width) diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 000adc15..8ad3489f 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -207,7 +207,10 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } object FoldAND extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) + def fold(c1: Literal, c2: Literal) = { + val width = (c1.width max c2.width).asInstanceOf[IntWidth] + UIntLiteral.masked(c1.value & c2.value, width) + } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w) @@ -218,7 +221,10 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } object FoldOR extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) + def fold(c1: Literal, c2: Literal) = { + val width = (c1.width max c2.width).asInstanceOf[IntWidth] + UIntLiteral.masked((c1.value | c2.value), width) + } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) @@ -229,7 +235,10 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res } object FoldXOR extends FoldCommutativeOp { - def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) + def fold(c1: Literal, c2: Literal) = { + val width = (c1.width max c2.width).asInstanceOf[IntWidth] + UIntLiteral.masked((c1.value ^ c2.value), width) + } def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match { case UIntLiteral(v, _) if v == BigInt(0) => rhs case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe) diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index 131f9466..d81f8687 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -1508,6 +1508,30 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec { execute(input, check, Seq.empty) } + it should "optimize bitwise operations of signed literals" in { + val input = + s"""|circuit Foo: + | module Foo: + | output out1: UInt<2> + | output out2: UInt<2> + | output out3: UInt<2> + | out1 <= xor(SInt<2>(-1), SInt<2>(1)) + | out2 <= or(SInt<2>(-1), SInt<2>(1)) + | out3 <= and(SInt<2>(-1), SInt<2>(-2)) + |""".stripMargin + val check = + s"""|circuit Foo: + | module Foo: + | output out1: UInt<2> + | output out2: UInt<2> + | output out3: UInt<2> + | out1 <= UInt<2>(2) + | out2 <= UInt<2>(3) + | out3 <= UInt<2>(2) + |""".stripMargin + execute(input, check, Seq.empty) + } + } -- cgit v1.2.3