aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKevin Laeufer2021-11-10 06:33:39 -0800
committerGitHub2021-11-10 09:33:39 -0500
commit18b9a987552492928aa1199d8cf498fe561c2f03 (patch)
treebffeadc7cad0aa3e552f945a6aa4f6e405e9a02d /src
parent7ef3e1ba9d1a748bd39f8d4f279e8d4e34bb4cc7 (diff)
smt: fix handling of div primitive in formal backend (#2409)
We never tested the case where the width of the numerator was less than the denominator. This should fix any issue with this combination.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala19
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala2
2 files changed, 14 insertions, 7 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index c7524e21..6f454a22 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -58,17 +58,22 @@ private object FirrtlExpressionSemantics {
val width = args.map(getWidth).sum
BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width))
case (PrimOps.Div, Seq(num, den), _) =>
- val (width, op) = if (isSigned(num)) {
- (getWidth(num) + 1, Op.SignedDiv)
- } else { (getWidth(num), Op.UnsignedDiv) }
- BVOp(op, toSMT(num, width), forceWidth(toSMT(den), isSigned(den), width))
+ val signed = isSigned(num)
+ val resWidth = if (signed) { getWidth(num) + 1 }
+ else { getWidth(num) }
+ val op = if (signed) { Op.SignedDiv }
+ else { Op.UnsignedDiv }
+ // we do the calculation on the widened values and then narrow the result if needed
+ val width = args.map(getWidth).max + (if (signed) 1 else 0)
+ val res = BVOp(op, toSMT(num, width), toSMT(den, width))
+ forceWidth(res, signed, resWidth, allowNarrow = true)
case (PrimOps.Rem, Seq(num, den), _) =>
- val op = if (isSigned(num)) Op.SignedRem else Op.UnsignedRem
+ val signed = isSigned(num)
+ val op = if (signed) Op.SignedRem else Op.UnsignedRem
val width = args.map(getWidth).max
val resWidth = args.map(getWidth).min
val res = BVOp(op, toSMT(num, width), toSMT(den, width))
- if (res.width > resWidth) { BVSlice(res, resWidth - 1, 0) }
- else { res }
+ forceWidth(res, signed, resWidth, allowNarrow = true)
case (PrimOps.Lt, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)))
diff --git a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
index c9c1c943..7476b20c 100644
--- a/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
+++ b/src/test/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemanticsSpec.scala
@@ -113,6 +113,8 @@ class FirrtlExpressionSemanticsSpec extends AnyFlatSpec {
assert(primop(false, "div", 8, List(8, 4), modelUndef = false) == "udiv(i0, zext(i1, 4))")
assert(primop(true, "div", 8, List(7, 7), modelUndef = false) == "sdiv(sext(i0, 1), sext(i1, 1))")
assert(primop(true, "div", 8, List(7, 4), modelUndef = false) == "sdiv(sext(i0, 1), sext(i1, 4))")
+ // result width is always the width of the numerator, even if the denominator is larger
+ assert(primop(false, "div", 1, List(1, 2), modelUndef = false) == "udiv(zext(i0, 1), i1)[0]")
}
it should "correctly translate the `rem` primitive operation" in {