aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Chen2020-07-23 11:14:35 -0700
committerGitHub2020-07-23 18:14:35 +0000
commit1927dc6574b9eee315c8f24441df390f2ce793c7 (patch)
tree8f942d8d4c7c46c59db476a4329d119d0ac77c79 /src
parentea558ad79ed0e65df73b5a01ceea690e5b0479ca (diff)
mask bits when propagating bitwise ops (#1745)
* ConstProp: test bitwise op of signed literals * ConstProp: use bit mask for FoldOr/FoldXor * handle and also * add UIntLiteral.masked helper Co-authored-by: Jack Koenig <koenig@sifive.com>
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/ir/IR.scala9
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala15
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala24
3 files changed, 45 insertions, 3 deletions
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index 275cbe51..023f53fd 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -297,6 +297,15 @@ case class UIntLiteral(value: BigInt, width: Width) extends Literal {
object UIntLiteral {
def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1))
def apply(value: BigInt): UIntLiteral = new UIntLiteral(value, minWidth(value))
+
+ /** Utility to construct UIntLiterals masked by the width
+ *
+ * This supports truncating negative values as well as values that are too wide for the width
+ */
+ def masked(value: BigInt, width: IntWidth): UIntLiteral = {
+ val mask = (BigInt(1) << width.width.toInt) - 1
+ UIntLiteral(value & mask, width)
+ }
}
case class SIntLiteral(value: BigInt, width: Width) extends Literal {
def tpe = SIntType(width)
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 000adc15..8ad3489f 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -207,7 +207,10 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
object FoldAND extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value & c2.value, c1.width max c2.width)
+ def fold(c1: Literal, c2: Literal) = {
+ val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ UIntLiteral.masked(c1.value & c2.value, width)
+ }
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
case SIntLiteral(v, w) if v == BigInt(0) => UIntLiteral(0, w)
@@ -218,7 +221,10 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
object FoldOR extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value | c2.value, c1.width max c2.width)
+ def fold(c1: Literal, c2: Literal) = {
+ val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ UIntLiteral.masked((c1.value | c2.value), width)
+ }
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
@@ -229,7 +235,10 @@ class ConstantPropagation extends Transform with DependencyAPIMigration with Res
}
object FoldXOR extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width)
+ def fold(c1: Literal, c2: Literal) = {
+ val width = (c1.width max c2.width).asInstanceOf[IntWidth]
+ UIntLiteral.masked((c1.value ^ c2.value), width)
+ }
def simplify(e: Expression, lhs: Literal, rhs: Expression) = lhs match {
case UIntLiteral(v, _) if v == BigInt(0) => rhs
case SIntLiteral(v, _) if v == BigInt(0) => asUInt(rhs, e.tpe)
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 131f9466..d81f8687 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -1508,6 +1508,30 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
execute(input, check, Seq.empty)
}
+ it should "optimize bitwise operations of signed literals" in {
+ val input =
+ s"""|circuit Foo:
+ | module Foo:
+ | output out1: UInt<2>
+ | output out2: UInt<2>
+ | output out3: UInt<2>
+ | out1 <= xor(SInt<2>(-1), SInt<2>(1))
+ | out2 <= or(SInt<2>(-1), SInt<2>(1))
+ | out3 <= and(SInt<2>(-1), SInt<2>(-2))
+ |""".stripMargin
+ val check =
+ s"""|circuit Foo:
+ | module Foo:
+ | output out1: UInt<2>
+ | output out2: UInt<2>
+ | output out3: UInt<2>
+ | out1 <= UInt<2>(2)
+ | out2 <= UInt<2>(3)
+ | out3 <= UInt<2>(2)
+ |""".stripMargin
+ execute(input, check, Seq.empty)
+ }
+
}