summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJiuyang Liu2021-12-16 09:47:05 +0800
committerGitHub2021-12-16 01:47:05 +0000
commit214115a4cdbf0714d3d1716035f5eb0dd98cba45 (patch)
tree3faf1eef78c35af066a7aba687caab80a2e4d7e6
parent4ff431bb5c7978c9915bcd6080a4f27ef12ae607 (diff)
BitSet API (#2211)
BitSet is a new experimental parent type for BitPat. It enables more complex operations on BitPats. Co-authored-by: Ocean Shen <shenao6626@gmail.com>
-rw-r--r--src/main/scala/chisel3/util/BitPat.scala210
-rw-r--r--src/main/scala/chisel3/util/experimental/decode/decoder.scala33
-rw-r--r--src/test/scala/chiselTests/util/BitPatSpec.scala2
-rw-r--r--src/test/scala/chiselTests/util/BitSetSpec.scala119
4 files changed, 348 insertions, 16 deletions
diff --git a/src/main/scala/chisel3/util/BitPat.scala b/src/main/scala/chisel3/util/BitPat.scala
index 4f8ae504..808245de 100644
--- a/src/main/scala/chisel3/util/BitPat.scala
+++ b/src/main/scala/chisel3/util/BitPat.scala
@@ -117,6 +117,119 @@ object BitPat {
}
}
+package experimental {
+ object BitSet {
+
+ /** Construct a [[BitSet]] from a sequence of [[BitPat]].
+ * All [[BitPat]] must have the same width.
+ */
+ def apply(bitpats: BitPat*): BitSet = {
+ val bs = new BitSet { def terms = bitpats.flatMap(_.terms).toSet }
+ // check width
+ bs.getWidth
+ bs
+ }
+
+ /** Empty [[BitSet]]. */
+ val empty: BitSet = new BitSet {
+ def terms = Set()
+ }
+
+ /** Construct a [[BitSet]] from String.
+ * each line should be a valid [[BitPat]] string with the same width.
+ */
+ def fromString(str: String): BitSet = {
+ val bs = new BitSet { def terms = str.split('\n').map(str => BitPat(str)).toSet }
+ // check width
+ bs.getWidth
+ bs
+ }
+ }
+
+ /** A Set of [[BitPat]] represents a set of bit vector with mask. */
+ sealed trait BitSet { outer =>
+ /** all [[BitPat]] elements in [[terms]] make up this [[BitSet]].
+ * all [[terms]] should be have the same width.
+ */
+ def terms: Set[BitPat]
+
+ /** Get specified width of said BitSet */
+ def getWidth: Int = {
+ require(terms.map(_.width).size <= 1, s"All BitPats must be the same size! Got $this")
+ // set width = 0 if terms is empty.
+ terms.headOption.map(_.width).getOrElse(0)
+ }
+
+ import BitPat.bitPatOrder
+ override def toString: String = terms.toSeq.sorted.mkString("\n")
+
+ /** whether this [[BitSet]] is empty (i.e. no value matches) */
+ def isEmpty: Boolean = terms.forall(_.isEmpty)
+
+ /** Check whether this [[BitSet]] overlap with that [[BitSet]], i.e. !(intersect.isEmpty)
+ *
+ * @param that [[BitSet]] to be checked.
+ * @return true if this and that [[BitSet]] have overlap.
+ */
+ def overlap(that: BitSet): Boolean =
+ !terms.flatMap(a => that.terms.map(b => (a, b))).forall { case (a, b) => !a.overlap(b) }
+
+ /** Check whether this [[BitSet]] covers that (i.e. forall b matches that, b also matches this)
+ *
+ * @param that [[BitSet]] to be covered
+ * @return true if this [[BitSet]] can cover that [[BitSet]]
+ */
+ def cover(that: BitSet): Boolean =
+ that.subtract(this).isEmpty
+
+ /** Intersect `this` and `that` [[BitSet]].
+ *
+ * @param that [[BitSet]] to be intersected.
+ * @return a [[BitSet]] containing all elements of `this` that also belong to `that`.
+ */
+ def intersect(that: BitSet): BitSet =
+ terms
+ .flatMap(a => that.terms.map(b => a.intersect(b)))
+ .filterNot(_.isEmpty)
+ .fold(BitSet.empty)(_.union(_))
+
+ /** Subtract that from this [[BitSet]].
+ *
+ * @param that subtrahend [[BitSet]].
+ * @return a [[BitSet]] containing elements of `this` which are not the elements of `that`.
+ */
+ def subtract(that: BitSet): BitSet =
+ terms.map { a =>
+ that.terms.map(b => a.subtract(b)).fold(a)(_.intersect(_))
+ }.filterNot(_.isEmpty).fold(BitSet.empty)(_.union(_))
+
+ /** Union this and that [[BitSet]]
+ *
+ * @param that [[BitSet]] to union.
+ * @return a [[BitSet]] containing all elements of `this` and `that`.
+ */
+ def union(that: BitSet): BitSet = new BitSet {
+ def terms = outer.terms ++ that.terms
+ }
+
+ /** Test whether two [[BitSet]] matches the same set of value
+ *
+ * @note
+ * This method can be very expensive compared to ordinary == operator between two Objects
+ *
+ * @return true if two [[BitSet]] is same.
+ */
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case that: BitSet => this.getWidth == that.getWidth && this.cover(that) && that.cover(this)
+ case _ => false
+ }
+ }
+ }
+
+}
+
+
/** Bit patterns are literals with masks, used to represent values with don't
* care bits. Equality comparisons will ignore don't care bits.
*
@@ -126,19 +239,19 @@ object BitPat {
* "b10001".U === BitPat("b101??") // evaluates to false.B
* }}}
*/
-sealed class BitPat(val value: BigInt, val mask: BigInt, width: Int) extends SourceInfoDoc {
- def getWidth: Int = width
+sealed class BitPat(val value: BigInt, val mask: BigInt, val width: Int) extends util.experimental.BitSet with SourceInfoDoc {
+ import chisel3.util.experimental.BitSet
+ def terms = Set(this)
+
+ /**
+ * Get specified width of said BitPat
+ */
+ override def getWidth: Int = width
def apply(x: Int): BitPat = macro SourceInfoTransform.xArg
def apply(x: Int, y: Int): BitPat = macro SourceInfoTransform.xyArg
def === (that: UInt): Bool = macro SourceInfoTransform.thatArg
def =/= (that: UInt): Bool = macro SourceInfoTransform.thatArg
def ## (that: BitPat): BitPat = macro SourceInfoTransform.thatArg
- override def equals(obj: Any): Boolean = {
- obj match {
- case y: BitPat => value == y.value && mask == y.mask && getWidth == y.getWidth
- case _ => false
- }
- }
/** @group SourceInfoTransformMacro */
def do_apply(x: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): BitPat = {
@@ -167,14 +280,83 @@ sealed class BitPat(val value: BigInt, val mask: BigInt, width: Int) extends Sou
new BitPat((value << that.getWidth) + that.value, (mask << that.getWidth) + that.mask, this.width + that.getWidth)
}
- /** Generate raw string of a BitPat. */
- def rawString: String = Seq.tabulate(width) { i =>
+ /** Check whether this [[BitPat]] overlap with that [[BitPat]], i.e. !(intersect.isEmpty)
+ *
+ * @param that [[BitPat]] to be checked.
+ * @return true if this and that [[BitPat]] have overlap.
+ */
+ def overlap(that: BitPat): Boolean = ((mask & that.mask) & (value ^ that.value)) == 0
+
+ /** Check whether this [[BitSet]] covers that (i.e. forall b matches that, b also matches this)
+ *
+ * @param that [[BitPat]] to be covered
+ * @return true if this [[BitSet]] can cover that [[BitSet]]
+ */
+ def cover(that: BitPat): Boolean = (mask & (~that.mask | (value ^ that.value))) == 0
+
+ /** Intersect `this` and `that` [[BitPat]].
+ *
+ * @param that [[BitPat]] to be intersected.
+ * @return a [[BitSet]] containing all elements of `this` that also belong to `that`.
+ */
+ def intersect(that: BitPat): BitSet = {
+ if (!overlap(that)) {
+ BitSet.empty
+ } else {
+ new BitPat(this.value | that.value, this.mask | that.mask, this.width.max(that.width))
+ }
+ }
+
+ /** Subtract a [[BitPat]] from this.
+ *
+ * @param that subtrahend [[BitPat]].
+ * @return a [[BitSet]] containing elements of `this` which are not the elements of `that`.
+ */
+ def subtract(that: BitPat): BitSet = {
+ require(width == that.width)
+ def enumerateBits(mask: BigInt): Seq[BigInt] = {
+ if (mask == 0) {
+ Nil
+ } else {
+ // bits comes after the first '1' in a number are inverted in its two's complement.
+ // therefore bit is always the first '1' in x (counting from least significant bit).
+ val bit = mask & (-mask)
+ bit +: enumerateBits(mask & ~bit)
+ }
+ }
+
+ val intersection = intersect(that)
+ val omask = this.mask
+ if (intersection.isEmpty) {
+ this
+ } else {
+ new BitSet {
+ val terms =
+ intersection.terms.flatMap { remove =>
+ enumerateBits(~omask & remove.mask).map { bit =>
+ // Only care about higher than current bit in remove
+ val nmask = (omask | ~(bit - 1)) & remove.mask
+ val nvalue = (remove.value ^ bit) & nmask
+ val nwidth = remove.width
+ new BitPat(nvalue, nmask, nwidth)
+ }
+ }
+ }
+ }
+ }
+
+ override def isEmpty: Boolean = false
+
+ /** Generate raw string of a [[BitPat]]. */
+ def rawString: String = Seq
+ .tabulate(width) { i =>
(value.testBit(width - i - 1), mask.testBit(width - i - 1)) match {
- case (true, true) => "1"
- case (false, true) => "0"
- case (_, false) => "?"
+ case (true, true) => "1"
+ case (false, true) => "0"
+ case (_, false) => "?"
+ }
}
- }.mkString
+ .mkString
override def toString = s"BitPat($rawString)"
}
diff --git a/src/main/scala/chisel3/util/experimental/decode/decoder.scala b/src/main/scala/chisel3/util/experimental/decode/decoder.scala
index ee2ece48..e0bf83b2 100644
--- a/src/main/scala/chisel3/util/experimental/decode/decoder.scala
+++ b/src/main/scala/chisel3/util/experimental/decode/decoder.scala
@@ -5,7 +5,7 @@ package chisel3.util.experimental.decode
import chisel3._
import chisel3.experimental.{ChiselAnnotation, annotate}
import chisel3.util.{BitPat, pla}
-import chisel3.util.experimental.getAnnotations
+import chisel3.util.experimental.{BitSet, getAnnotations}
import firrtl.annotations.Annotation
import logger.LazyLogging
@@ -80,4 +80,35 @@ object decoder extends LazyLogging {
qmcFallBack(input, truthTable)
}
}
+
+
+ /** Generate a decoder circuit that matches the input to each bitSet.
+ *
+ * The resulting circuit functions like the following but is optimized with a logic minifier.
+ * {{{
+ * when(input === bitSets(0)) { output := b000001 }
+ * .elsewhen (input === bitSets(1)) { output := b000010 }
+ * ....
+ * .otherwise { if (errorBit) output := b100000 else output := DontCare }
+ * }}}
+ *
+ * @param input input to the decoder circuit, width should be equal to bitSets.width
+ * @param bitSets set of ports to be matched, all width should be the equal
+ * @param errorBit whether generate an additional decode error bit at MSB of output.
+ * @return decoded wire
+ */
+ def bitset(input: chisel3.UInt, bitSets: Seq[BitSet], errorBit: Boolean = false): chisel3.UInt =
+ chisel3.util.experimental.decode.decoder(
+ input,
+ chisel3.util.experimental.decode.TruthTable.fromString(
+ {
+ bitSets.zipWithIndex.flatMap {
+ case (bs, i) =>
+ bs.terms.map(bp =>
+ s"${bp.rawString}->${if (errorBit) "0"}${"0" * (bitSets.size - i - 1)}1${"0" * i}"
+ )
+ } ++ Seq(s"${if (errorBit) "1"}${"?" * bitSets.size}")
+ }.mkString("\n")
+ )
+ )
}
diff --git a/src/test/scala/chiselTests/util/BitPatSpec.scala b/src/test/scala/chiselTests/util/BitPatSpec.scala
index e14b4496..549e8bca 100644
--- a/src/test/scala/chiselTests/util/BitPatSpec.scala
+++ b/src/test/scala/chiselTests/util/BitPatSpec.scala
@@ -24,7 +24,7 @@ class BitPatSpec extends AnyFlatSpec with Matchers {
intercept[IllegalArgumentException]{BitPat("b")}
}
- it should "contact BitPat via ##" in {
+ it should "concat BitPat via ##" in {
(BitPat.Y(4) ## BitPat.dontCare(3) ## BitPat.N(2)).toString should be (s"BitPat(1111???00)")
}
diff --git a/src/test/scala/chiselTests/util/BitSetSpec.scala b/src/test/scala/chiselTests/util/BitSetSpec.scala
new file mode 100644
index 00000000..8120cc97
--- /dev/null
+++ b/src/test/scala/chiselTests/util/BitSetSpec.scala
@@ -0,0 +1,119 @@
+package chiselTests.util
+
+import chisel3.util.experimental.BitSet
+import chisel3.util.BitPat
+import org.scalatest.flatspec.AnyFlatSpec
+import org.scalatest.matchers.should.Matchers
+
+class BitSetSpec extends AnyFlatSpec with Matchers {
+ behavior of classOf[BitSet].toString
+
+ it should "reject unequal width when constructing a BitSet" in {
+ intercept[IllegalArgumentException] {
+ BitSet.fromString(
+ """b0010
+ |b00010
+ |""".stripMargin)
+ }
+ }
+
+ it should "return empty subtraction result correctly" in {
+ val aBitPat = BitPat("b10?")
+ val bBitPat = BitPat("b1??")
+
+ aBitPat.subtract(bBitPat).isEmpty should be (true)
+ }
+
+ it should "return nonempty subtraction result correctly" in {
+ val aBitPat = BitPat("b10?")
+ val bBitPat = BitPat("b1??")
+ val cBitPat = BitPat("b11?")
+ val dBitPat = BitPat("b100")
+
+ val diffBitPat = bBitPat.subtract(aBitPat)
+ bBitPat.cover(diffBitPat) should be (true)
+ diffBitPat.equals(cBitPat) should be (true)
+
+ val largerdiffBitPat = bBitPat.subtract(dBitPat)
+ aBitPat.cover(dBitPat) should be (true)
+ largerdiffBitPat.cover(diffBitPat) should be (true)
+ }
+
+ it should "be able to handle complex subtract between BitSet" in {
+ val aBitSet = BitSet.fromString(
+ """b?01?0
+ |b11111
+ |b00000
+ |""".stripMargin)
+ val bBitSet = BitSet.fromString(
+ """b?1111
+ |b?0000
+ |""".stripMargin
+ )
+ val expected = BitPat("b?01?0")
+
+ expected.equals(aBitSet.subtract(bBitSet)) should be (true)
+ }
+
+ it should "be generated from BitPat union" in {
+ val aBitSet = BitSet.fromString(
+ """b001?0
+ |b000??""".stripMargin)
+ val aBitPat = BitPat("b000??")
+ val bBitPat = BitPat("b001?0")
+ val cBitPat = BitPat("b00000")
+ aBitPat.cover(cBitPat) should be (true)
+ aBitSet.cover(bBitPat) should be (true)
+
+ aBitSet.equals(aBitPat.union(bBitPat)) should be (true)
+ }
+
+ it should "be generated from BitPat subtraction" in {
+ val aBitSet = BitSet.fromString(
+ """b001?0
+ |b000??""".stripMargin)
+ val aBitPat = BitPat("b00???")
+ val bBitPat = BitPat("b001?1")
+
+ aBitSet.equals(aBitPat.subtract(bBitPat)) should be (true)
+ }
+
+ it should "union two BitSet together" in {
+ val aBitSet = BitSet.fromString(
+ """b001?0
+ |b001?1
+ |""".stripMargin)
+ val bBitSet = BitSet.fromString(
+ """b000??
+ |b01???
+ |""".stripMargin
+ )
+ val cBitPat = BitPat("b0????")
+ cBitPat.equals(aBitSet.union(bBitSet)) should be (true)
+ }
+
+ it should "be decoded" in {
+ import chisel3._
+ import chisel3.util.experimental.decode.decoder
+ // [0 - 256] part into: [0 - 31], [32 - 47, 64 - 127], [192 - 255]
+ // "0011????" "10??????" is empty to error
+ chisel3.stage.ChiselStage.emitSystemVerilog(new Module {
+ val in = IO(Input(UInt(8.W)))
+ val out = IO(Output(UInt(4.W)))
+ out := decoder.bitset(in, Seq(
+ BitSet.fromString(
+ "b000?????"
+ ),
+ BitSet.fromString(
+ """b0010????
+ |b01??????
+ |""".stripMargin
+ ),
+ BitSet.fromString(
+ "b11??????"
+ )
+ ), true)
+ })
+ }
+
+}