diff options
| author | Jiuyang Liu | 2021-12-16 09:47:05 +0800 |
|---|---|---|
| committer | GitHub | 2021-12-16 01:47:05 +0000 |
| commit | 214115a4cdbf0714d3d1716035f5eb0dd98cba45 (patch) | |
| tree | 3faf1eef78c35af066a7aba687caab80a2e4d7e6 /src | |
| parent | 4ff431bb5c7978c9915bcd6080a4f27ef12ae607 (diff) | |
BitSet API (#2211)
BitSet is a new experimental parent type for BitPat.
It enables more complex operations on BitPats.
Co-authored-by: Ocean Shen <shenao6626@gmail.com>
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/chisel3/util/BitPat.scala | 210 | ||||
| -rw-r--r-- | src/main/scala/chisel3/util/experimental/decode/decoder.scala | 33 | ||||
| -rw-r--r-- | src/test/scala/chiselTests/util/BitPatSpec.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/chiselTests/util/BitSetSpec.scala | 119 |
4 files changed, 348 insertions, 16 deletions
diff --git a/src/main/scala/chisel3/util/BitPat.scala b/src/main/scala/chisel3/util/BitPat.scala index 4f8ae504..808245de 100644 --- a/src/main/scala/chisel3/util/BitPat.scala +++ b/src/main/scala/chisel3/util/BitPat.scala @@ -117,6 +117,119 @@ object BitPat { } } +package experimental { + object BitSet { + + /** Construct a [[BitSet]] from a sequence of [[BitPat]]. + * All [[BitPat]] must have the same width. + */ + def apply(bitpats: BitPat*): BitSet = { + val bs = new BitSet { def terms = bitpats.flatMap(_.terms).toSet } + // check width + bs.getWidth + bs + } + + /** Empty [[BitSet]]. */ + val empty: BitSet = new BitSet { + def terms = Set() + } + + /** Construct a [[BitSet]] from String. + * each line should be a valid [[BitPat]] string with the same width. + */ + def fromString(str: String): BitSet = { + val bs = new BitSet { def terms = str.split('\n').map(str => BitPat(str)).toSet } + // check width + bs.getWidth + bs + } + } + + /** A Set of [[BitPat]] represents a set of bit vector with mask. */ + sealed trait BitSet { outer => + /** all [[BitPat]] elements in [[terms]] make up this [[BitSet]]. + * all [[terms]] should be have the same width. + */ + def terms: Set[BitPat] + + /** Get specified width of said BitSet */ + def getWidth: Int = { + require(terms.map(_.width).size <= 1, s"All BitPats must be the same size! Got $this") + // set width = 0 if terms is empty. + terms.headOption.map(_.width).getOrElse(0) + } + + import BitPat.bitPatOrder + override def toString: String = terms.toSeq.sorted.mkString("\n") + + /** whether this [[BitSet]] is empty (i.e. no value matches) */ + def isEmpty: Boolean = terms.forall(_.isEmpty) + + /** Check whether this [[BitSet]] overlap with that [[BitSet]], i.e. !(intersect.isEmpty) + * + * @param that [[BitSet]] to be checked. + * @return true if this and that [[BitSet]] have overlap. + */ + def overlap(that: BitSet): Boolean = + !terms.flatMap(a => that.terms.map(b => (a, b))).forall { case (a, b) => !a.overlap(b) } + + /** Check whether this [[BitSet]] covers that (i.e. forall b matches that, b also matches this) + * + * @param that [[BitSet]] to be covered + * @return true if this [[BitSet]] can cover that [[BitSet]] + */ + def cover(that: BitSet): Boolean = + that.subtract(this).isEmpty + + /** Intersect `this` and `that` [[BitSet]]. + * + * @param that [[BitSet]] to be intersected. + * @return a [[BitSet]] containing all elements of `this` that also belong to `that`. + */ + def intersect(that: BitSet): BitSet = + terms + .flatMap(a => that.terms.map(b => a.intersect(b))) + .filterNot(_.isEmpty) + .fold(BitSet.empty)(_.union(_)) + + /** Subtract that from this [[BitSet]]. + * + * @param that subtrahend [[BitSet]]. + * @return a [[BitSet]] containing elements of `this` which are not the elements of `that`. + */ + def subtract(that: BitSet): BitSet = + terms.map { a => + that.terms.map(b => a.subtract(b)).fold(a)(_.intersect(_)) + }.filterNot(_.isEmpty).fold(BitSet.empty)(_.union(_)) + + /** Union this and that [[BitSet]] + * + * @param that [[BitSet]] to union. + * @return a [[BitSet]] containing all elements of `this` and `that`. + */ + def union(that: BitSet): BitSet = new BitSet { + def terms = outer.terms ++ that.terms + } + + /** Test whether two [[BitSet]] matches the same set of value + * + * @note + * This method can be very expensive compared to ordinary == operator between two Objects + * + * @return true if two [[BitSet]] is same. + */ + override def equals(obj: Any): Boolean = { + obj match { + case that: BitSet => this.getWidth == that.getWidth && this.cover(that) && that.cover(this) + case _ => false + } + } + } + +} + + /** Bit patterns are literals with masks, used to represent values with don't * care bits. Equality comparisons will ignore don't care bits. * @@ -126,19 +239,19 @@ object BitPat { * "b10001".U === BitPat("b101??") // evaluates to false.B * }}} */ -sealed class BitPat(val value: BigInt, val mask: BigInt, width: Int) extends SourceInfoDoc { - def getWidth: Int = width +sealed class BitPat(val value: BigInt, val mask: BigInt, val width: Int) extends util.experimental.BitSet with SourceInfoDoc { + import chisel3.util.experimental.BitSet + def terms = Set(this) + + /** + * Get specified width of said BitPat + */ + override def getWidth: Int = width def apply(x: Int): BitPat = macro SourceInfoTransform.xArg def apply(x: Int, y: Int): BitPat = macro SourceInfoTransform.xyArg def === (that: UInt): Bool = macro SourceInfoTransform.thatArg def =/= (that: UInt): Bool = macro SourceInfoTransform.thatArg def ## (that: BitPat): BitPat = macro SourceInfoTransform.thatArg - override def equals(obj: Any): Boolean = { - obj match { - case y: BitPat => value == y.value && mask == y.mask && getWidth == y.getWidth - case _ => false - } - } /** @group SourceInfoTransformMacro */ def do_apply(x: Int)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): BitPat = { @@ -167,14 +280,83 @@ sealed class BitPat(val value: BigInt, val mask: BigInt, width: Int) extends Sou new BitPat((value << that.getWidth) + that.value, (mask << that.getWidth) + that.mask, this.width + that.getWidth) } - /** Generate raw string of a BitPat. */ - def rawString: String = Seq.tabulate(width) { i => + /** Check whether this [[BitPat]] overlap with that [[BitPat]], i.e. !(intersect.isEmpty) + * + * @param that [[BitPat]] to be checked. + * @return true if this and that [[BitPat]] have overlap. + */ + def overlap(that: BitPat): Boolean = ((mask & that.mask) & (value ^ that.value)) == 0 + + /** Check whether this [[BitSet]] covers that (i.e. forall b matches that, b also matches this) + * + * @param that [[BitPat]] to be covered + * @return true if this [[BitSet]] can cover that [[BitSet]] + */ + def cover(that: BitPat): Boolean = (mask & (~that.mask | (value ^ that.value))) == 0 + + /** Intersect `this` and `that` [[BitPat]]. + * + * @param that [[BitPat]] to be intersected. + * @return a [[BitSet]] containing all elements of `this` that also belong to `that`. + */ + def intersect(that: BitPat): BitSet = { + if (!overlap(that)) { + BitSet.empty + } else { + new BitPat(this.value | that.value, this.mask | that.mask, this.width.max(that.width)) + } + } + + /** Subtract a [[BitPat]] from this. + * + * @param that subtrahend [[BitPat]]. + * @return a [[BitSet]] containing elements of `this` which are not the elements of `that`. + */ + def subtract(that: BitPat): BitSet = { + require(width == that.width) + def enumerateBits(mask: BigInt): Seq[BigInt] = { + if (mask == 0) { + Nil + } else { + // bits comes after the first '1' in a number are inverted in its two's complement. + // therefore bit is always the first '1' in x (counting from least significant bit). + val bit = mask & (-mask) + bit +: enumerateBits(mask & ~bit) + } + } + + val intersection = intersect(that) + val omask = this.mask + if (intersection.isEmpty) { + this + } else { + new BitSet { + val terms = + intersection.terms.flatMap { remove => + enumerateBits(~omask & remove.mask).map { bit => + // Only care about higher than current bit in remove + val nmask = (omask | ~(bit - 1)) & remove.mask + val nvalue = (remove.value ^ bit) & nmask + val nwidth = remove.width + new BitPat(nvalue, nmask, nwidth) + } + } + } + } + } + + override def isEmpty: Boolean = false + + /** Generate raw string of a [[BitPat]]. */ + def rawString: String = Seq + .tabulate(width) { i => (value.testBit(width - i - 1), mask.testBit(width - i - 1)) match { - case (true, true) => "1" - case (false, true) => "0" - case (_, false) => "?" + case (true, true) => "1" + case (false, true) => "0" + case (_, false) => "?" + } } - }.mkString + .mkString override def toString = s"BitPat($rawString)" } diff --git a/src/main/scala/chisel3/util/experimental/decode/decoder.scala b/src/main/scala/chisel3/util/experimental/decode/decoder.scala index ee2ece48..e0bf83b2 100644 --- a/src/main/scala/chisel3/util/experimental/decode/decoder.scala +++ b/src/main/scala/chisel3/util/experimental/decode/decoder.scala @@ -5,7 +5,7 @@ package chisel3.util.experimental.decode import chisel3._ import chisel3.experimental.{ChiselAnnotation, annotate} import chisel3.util.{BitPat, pla} -import chisel3.util.experimental.getAnnotations +import chisel3.util.experimental.{BitSet, getAnnotations} import firrtl.annotations.Annotation import logger.LazyLogging @@ -80,4 +80,35 @@ object decoder extends LazyLogging { qmcFallBack(input, truthTable) } } + + + /** Generate a decoder circuit that matches the input to each bitSet. + * + * The resulting circuit functions like the following but is optimized with a logic minifier. + * {{{ + * when(input === bitSets(0)) { output := b000001 } + * .elsewhen (input === bitSets(1)) { output := b000010 } + * .... + * .otherwise { if (errorBit) output := b100000 else output := DontCare } + * }}} + * + * @param input input to the decoder circuit, width should be equal to bitSets.width + * @param bitSets set of ports to be matched, all width should be the equal + * @param errorBit whether generate an additional decode error bit at MSB of output. + * @return decoded wire + */ + def bitset(input: chisel3.UInt, bitSets: Seq[BitSet], errorBit: Boolean = false): chisel3.UInt = + chisel3.util.experimental.decode.decoder( + input, + chisel3.util.experimental.decode.TruthTable.fromString( + { + bitSets.zipWithIndex.flatMap { + case (bs, i) => + bs.terms.map(bp => + s"${bp.rawString}->${if (errorBit) "0"}${"0" * (bitSets.size - i - 1)}1${"0" * i}" + ) + } ++ Seq(s"${if (errorBit) "1"}${"?" * bitSets.size}") + }.mkString("\n") + ) + ) } diff --git a/src/test/scala/chiselTests/util/BitPatSpec.scala b/src/test/scala/chiselTests/util/BitPatSpec.scala index e14b4496..549e8bca 100644 --- a/src/test/scala/chiselTests/util/BitPatSpec.scala +++ b/src/test/scala/chiselTests/util/BitPatSpec.scala @@ -24,7 +24,7 @@ class BitPatSpec extends AnyFlatSpec with Matchers { intercept[IllegalArgumentException]{BitPat("b")} } - it should "contact BitPat via ##" in { + it should "concat BitPat via ##" in { (BitPat.Y(4) ## BitPat.dontCare(3) ## BitPat.N(2)).toString should be (s"BitPat(1111???00)") } diff --git a/src/test/scala/chiselTests/util/BitSetSpec.scala b/src/test/scala/chiselTests/util/BitSetSpec.scala new file mode 100644 index 00000000..8120cc97 --- /dev/null +++ b/src/test/scala/chiselTests/util/BitSetSpec.scala @@ -0,0 +1,119 @@ +package chiselTests.util + +import chisel3.util.experimental.BitSet +import chisel3.util.BitPat +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class BitSetSpec extends AnyFlatSpec with Matchers { + behavior of classOf[BitSet].toString + + it should "reject unequal width when constructing a BitSet" in { + intercept[IllegalArgumentException] { + BitSet.fromString( + """b0010 + |b00010 + |""".stripMargin) + } + } + + it should "return empty subtraction result correctly" in { + val aBitPat = BitPat("b10?") + val bBitPat = BitPat("b1??") + + aBitPat.subtract(bBitPat).isEmpty should be (true) + } + + it should "return nonempty subtraction result correctly" in { + val aBitPat = BitPat("b10?") + val bBitPat = BitPat("b1??") + val cBitPat = BitPat("b11?") + val dBitPat = BitPat("b100") + + val diffBitPat = bBitPat.subtract(aBitPat) + bBitPat.cover(diffBitPat) should be (true) + diffBitPat.equals(cBitPat) should be (true) + + val largerdiffBitPat = bBitPat.subtract(dBitPat) + aBitPat.cover(dBitPat) should be (true) + largerdiffBitPat.cover(diffBitPat) should be (true) + } + + it should "be able to handle complex subtract between BitSet" in { + val aBitSet = BitSet.fromString( + """b?01?0 + |b11111 + |b00000 + |""".stripMargin) + val bBitSet = BitSet.fromString( + """b?1111 + |b?0000 + |""".stripMargin + ) + val expected = BitPat("b?01?0") + + expected.equals(aBitSet.subtract(bBitSet)) should be (true) + } + + it should "be generated from BitPat union" in { + val aBitSet = BitSet.fromString( + """b001?0 + |b000??""".stripMargin) + val aBitPat = BitPat("b000??") + val bBitPat = BitPat("b001?0") + val cBitPat = BitPat("b00000") + aBitPat.cover(cBitPat) should be (true) + aBitSet.cover(bBitPat) should be (true) + + aBitSet.equals(aBitPat.union(bBitPat)) should be (true) + } + + it should "be generated from BitPat subtraction" in { + val aBitSet = BitSet.fromString( + """b001?0 + |b000??""".stripMargin) + val aBitPat = BitPat("b00???") + val bBitPat = BitPat("b001?1") + + aBitSet.equals(aBitPat.subtract(bBitPat)) should be (true) + } + + it should "union two BitSet together" in { + val aBitSet = BitSet.fromString( + """b001?0 + |b001?1 + |""".stripMargin) + val bBitSet = BitSet.fromString( + """b000?? + |b01??? + |""".stripMargin + ) + val cBitPat = BitPat("b0????") + cBitPat.equals(aBitSet.union(bBitSet)) should be (true) + } + + it should "be decoded" in { + import chisel3._ + import chisel3.util.experimental.decode.decoder + // [0 - 256] part into: [0 - 31], [32 - 47, 64 - 127], [192 - 255] + // "0011????" "10??????" is empty to error + chisel3.stage.ChiselStage.emitSystemVerilog(new Module { + val in = IO(Input(UInt(8.W))) + val out = IO(Output(UInt(4.W))) + out := decoder.bitset(in, Seq( + BitSet.fromString( + "b000?????" + ), + BitSet.fromString( + """b0010???? + |b01?????? + |""".stripMargin + ), + BitSet.fromString( + "b11??????" + ) + ), true) + }) + } + +} |
