diff options
| author | Jiuyang Liu | 2021-12-15 15:53:52 +0800 |
|---|---|---|
| committer | GitHub | 2021-12-15 07:53:52 +0000 |
| commit | 36506c527ff0f51636beee4160f0ce1f6ad2f90a (patch) | |
| tree | ef6f708959d6f115154d76ddd6216a7ba288a01f | |
| parent | 7e8ec50376f852d5ab35d7609d986c7e4128abb1 (diff) | |
Refactor TruthTable to use Seq (#2217)
This makes the resulting Verilog from decoding a TruthTable deterministic.
8 files changed, 65 insertions, 40 deletions
diff --git a/integration-tests/src/test/scala/chiselTests/util/experimental/DecoderSpec.scala b/integration-tests/src/test/scala/chiselTests/util/experimental/DecoderSpec.scala index c31fdee0..2d50555e 100644 --- a/integration-tests/src/test/scala/chiselTests/util/experimental/DecoderSpec.scala +++ b/integration-tests/src/test/scala/chiselTests/util/experimental/DecoderSpec.scala @@ -10,7 +10,7 @@ import chiseltest._ import chiseltest.formal._ class DecoderSpec extends AnyFlatSpec with ChiselScalatestTester with Formal { - val xor = TruthTable( + val xor = TruthTable.fromString( """10->1 |01->1 | 0""".stripMargin) @@ -42,7 +42,7 @@ class DecoderSpec extends AnyFlatSpec with ChiselScalatestTester with Formal { """10->1 |01->1 | 0""".stripMargin, - QMCMinimizer.minimize(TruthTable( + QMCMinimizer.minimize(TruthTable.fromString( """10->1 |01->1 | 0""".stripMargin)).toString diff --git a/integration-tests/src/test/scala/chiselTests/util/experimental/minimizer/MinimizerSpec.scala b/integration-tests/src/test/scala/chiselTests/util/experimental/minimizer/MinimizerSpec.scala index a932eb77..07afd074 100644 --- a/integration-tests/src/test/scala/chiselTests/util/experimental/minimizer/MinimizerSpec.scala +++ b/integration-tests/src/test/scala/chiselTests/util/experimental/minimizer/MinimizerSpec.scala @@ -19,7 +19,7 @@ class DecodeTestModule(minimizer: Minimizer, table: TruthTable) extends Module { chisel3.experimental.verification.assert( // for each instruction, if input matches, output should match, not no matched, fallback to default (table.table.map { case (key, value) => (i === key) && (minimizedO === value) } ++ - Seq(table.table.keys.map(i =/= _).reduce(_ && _) && minimizedO === table.default)).reduce(_ || _) + Seq(table.table.map(_._1).map(i =/= _).reduce(_ && _) && minimizedO === table.default)).reduce(_ || _) ) } diff --git a/src/main/scala/chisel3/util/BitPat.scala b/src/main/scala/chisel3/util/BitPat.scala index 4b94879f..4f8ae504 100644 --- a/src/main/scala/chisel3/util/BitPat.scala +++ b/src/main/scala/chisel3/util/BitPat.scala @@ -8,6 +8,12 @@ import chisel3.internal.sourceinfo.{SourceInfo, SourceInfoTransform} object BitPat { + + private[chisel3] implicit val bitPatOrder = new Ordering[BitPat] { + import scala.math.Ordered.orderingToOrdered + def compare(x: BitPat, y: BitPat): Int = (x.getWidth, x.value, x.mask) compare (y.getWidth, y.value, y.mask) + } + /** Parses a bit pattern string into (bits, mask, width). * * @return bits the literal value, with don't cares being 0 diff --git a/src/main/scala/chisel3/util/experimental/decode/EspressoMinimizer.scala b/src/main/scala/chisel3/util/experimental/decode/EspressoMinimizer.scala index 6adf544c..4dcea99e 100644 --- a/src/main/scala/chisel3/util/experimental/decode/EspressoMinimizer.scala +++ b/src/main/scala/chisel3/util/experimental/decode/EspressoMinimizer.scala @@ -55,7 +55,7 @@ object EspressoMinimizer extends Minimizer with LazyLogging { |""".stripMargin ++ (if (defaultType == '1') invertRawTable else rawTable) } - def readTable(espressoTable: String): Map[BitPat, BitPat] = { + def readTable(espressoTable: String) = { def bitPat(espresso: String): BitPat = BitPat("b" + espresso.replace('-', '?')) espressoTable @@ -63,7 +63,6 @@ object EspressoMinimizer extends Minimizer with LazyLogging { .filterNot(_.startsWith(".")) .map(_.split(' ')) .map(row => bitPat(row(0)) -> bitPat(row(1))) - .toMap } val input = writeTable(table) diff --git a/src/main/scala/chisel3/util/experimental/decode/QMCMinimizer.scala b/src/main/scala/chisel3/util/experimental/decode/QMCMinimizer.scala index 8bd8a03e..59120221 100644 --- a/src/main/scala/chisel3/util/experimental/decode/QMCMinimizer.scala +++ b/src/main/scala/chisel3/util/experimental/decode/QMCMinimizer.scala @@ -234,11 +234,11 @@ object QMCMinimizer extends Minimizer { val outputBp = BitPat("b" + "?" * (m - i - 1) + "1" + "?" * i) // Minterms, implicants that makes the output to be 1 - val mint: Seq[Implicant] = table.table.filter { case (_, t) => t.mask.testBit(i) && t.value.testBit(i) }.keys.map(toImplicant).toSeq + val mint: Seq[Implicant] = table.table.filter { case (_, t) => t.mask.testBit(i) && t.value.testBit(i) }.map(_._1).map(toImplicant) // Maxterms, implicants that makes the output to be 0 - val maxt: Seq[Implicant] = table.table.filter { case (_, t) => t.mask.testBit(i) && !t.value.testBit(i) }.keys.map(toImplicant).toSeq + val maxt: Seq[Implicant] = table.table.filter { case (_, t) => t.mask.testBit(i) && !t.value.testBit(i) }.map(_._1).map(toImplicant) // Don't cares, implicants that can produce either 0 or 1 as output - val dc: Seq[Implicant] = table.table.filter { case (_, t) => !t.mask.testBit(i) }.keys.map(toImplicant).toSeq + val dc: Seq[Implicant] = table.table.filter { case (_, t) => !t.mask.testBit(i) }.map(_._1).map(toImplicant) val (implicants, defaultToDc) = table.default match { case x if x.mask.testBit(i) && !x.value.testBit(i) => // default to 0 @@ -291,7 +291,7 @@ object QMCMinimizer extends Minimizer { (essentialPrimeImplicants ++ getCover(nonessentialPrimeImplicants, uncoveredImplicants)).map(a => (a.bp, outputBp)) }) - minimized.tail.foldLeft(table.copy(table = Map(minimized.head))) { case (tb, t) => + minimized.tail.foldLeft(table.copy(table = Seq(minimized.head))) { case (tb, t) => if (tb.table.exists(x => x._1 == t._1)) { tb.copy(table = tb.table.map { case (k, v) => if (k == t._1) { @@ -300,7 +300,7 @@ object QMCMinimizer extends Minimizer { } else (k, v) }) } else { - tb.copy(table = tb.table + t) + tb.copy(table = tb.table :+ t) } } } diff --git a/src/main/scala/chisel3/util/experimental/decode/TruthTable.scala b/src/main/scala/chisel3/util/experimental/decode/TruthTable.scala index 683de16b..322466f9 100644 --- a/src/main/scala/chisel3/util/experimental/decode/TruthTable.scala +++ b/src/main/scala/chisel3/util/experimental/decode/TruthTable.scala @@ -4,8 +4,7 @@ package chisel3.util.experimental.decode import chisel3.util.BitPat -final class TruthTable(val table: Map[BitPat, BitPat], val default: BitPat) { - +sealed class TruthTable private (val table: Seq[(BitPat, BitPat)], val default: BitPat, val sort: Boolean) { def inputWidth = table.head._1.getWidth def outputWidth = table.head._2.getWidth @@ -14,10 +13,10 @@ final class TruthTable(val table: Map[BitPat, BitPat], val default: BitPat) { def writeRow(map: (BitPat, BitPat)): String = s"${map._1.rawString}->${map._2.rawString}" - (table.map(writeRow) ++ Seq(s"${" "*(inputWidth + 2)}${default.rawString}")).toSeq.sorted.mkString("\n") + (table.map(writeRow) ++ Seq(s"${" "*(inputWidth + 2)}${default.rawString}")).mkString("\n") } - def copy(table: Map[BitPat, BitPat] = this.table, default: BitPat = this.default) = new TruthTable(table, default) + def copy(table: Seq[(BitPat, BitPat)] = this.table, default: BitPat = this.default, sort: Boolean = this.sort) = TruthTable(table, default, sort) override def equals(y: Any): Boolean = { y match { @@ -28,25 +27,12 @@ final class TruthTable(val table: Map[BitPat, BitPat], val default: BitPat) { } object TruthTable { - /** Parse TruthTable from its string representation. */ - def apply(tableString: String): TruthTable = { - TruthTable( - tableString - .split("\n") - .filter(_.contains("->")) - .map(_.split("->").map(str => BitPat(s"b$str"))) - .map(bps => bps(0) -> bps(1)) - .toSeq, - BitPat(s"b${tableString.split("\n").filterNot(_.contains("->")).head.replace(" ", "")}") - ) - } - /** Convert a table and default output into a [[TruthTable]]. */ - def apply(table: Iterable[(BitPat, BitPat)], default: BitPat): TruthTable = { + def apply(table: Iterable[(BitPat, BitPat)], default: BitPat, sort: Boolean = true): TruthTable = { require(table.map(_._1.getWidth).toSet.size == 1, "input width not equal.") require(table.map(_._2.getWidth).toSet.size == 1, "output width not equal.") val outputWidth = table.map(_._2.getWidth).head - new TruthTable(table.toSeq.groupBy(_._1.toString).map { case (key, values) => + val mergedTable = table.groupBy(_._1.toString).map { case (key, values) => // merge same input inputs. values.head._1 -> BitPat(s"b${ Seq.tabulate(outputWidth) { i => @@ -59,9 +45,23 @@ object TruthTable { outputSet.headOption.getOrElse('?') }.mkString }") - }, default) + }.toSeq + import BitPat.bitPatOrder + new TruthTable(if(sort) mergedTable.sorted else mergedTable, default, sort) } + /** Parse TruthTable from its string representation. */ + def fromString(tableString: String): TruthTable = { + TruthTable( + tableString + .split("\n") + .filter(_.contains("->")) + .map(_.split("->").map(str => BitPat(s"b$str"))) + .map(bps => bps(0) -> bps(1)) + .toSeq, + BitPat(s"b${tableString.split("\n").filterNot(_.contains("->")).head.replace(" ", "")}") + ) + } /** consume 1 table, split it into up to 3 tables with the same default bits. * @@ -99,18 +99,17 @@ object TruthTable { tables: Seq[(TruthTable, Seq[Int])] ): TruthTable = { def reIndex(bitPat: BitPat, table: TruthTable, indexes: Seq[Int]): Seq[(Char, Int)] = - (table.table.map(a => a._1.toString -> a._2).getOrElse(bitPat.toString, BitPat.dontCare(indexes.size))).rawString.zip(indexes) + table.table.map(a => a._1.toString -> a._2).collectFirst{ case (k, v) if k == bitPat.toString => v}.getOrElse(BitPat.dontCare(indexes.size)).rawString.zip(indexes) def bitPat(indexedChar: Seq[(Char, Int)]) = BitPat(s"b${indexedChar .sortBy(_._2) .map(_._1) .mkString}") TruthTable( tables - .flatMap(_._1.table.keys) + .flatMap(_._1.table.map(_._1)) .map { key => key -> bitPat(tables.flatMap { case (table, indexes) => reIndex(key, table, indexes) }) - } - .toMap, + }, bitPat(tables.flatMap { case (table, indexes) => table.default.rawString.zip(indexes) }) ) } diff --git a/src/main/scala/chisel3/util/experimental/decode/decoder.scala b/src/main/scala/chisel3/util/experimental/decode/decoder.scala index 42e374d1..ee2ece48 100644 --- a/src/main/scala/chisel3/util/experimental/decode/decoder.scala +++ b/src/main/scala/chisel3/util/experimental/decode/decoder.scala @@ -19,7 +19,7 @@ object decoder extends LazyLogging { */ def apply(minimizer: Minimizer, input: UInt, truthTable: TruthTable): UInt = { val minimizedTable = getAnnotations().collect { - case DecodeTableAnnotation(_, in, out) => TruthTable(in) -> TruthTable(out) + case DecodeTableAnnotation(_, in, out) => TruthTable.fromString(in) -> TruthTable.fromString(out) }.toMap.getOrElse(truthTable, minimizer.minimize(truthTable)) if (minimizedTable.table.isEmpty) { val outputs = Wire(UInt(minimizedTable.default.getWidth.W)) diff --git a/src/test/scala/chiselTests/util/experimental/TruthTableSpec.scala b/src/test/scala/chiselTests/util/experimental/TruthTableSpec.scala index 743a3cd8..255effaf 100644 --- a/src/test/scala/chiselTests/util/experimental/TruthTableSpec.scala +++ b/src/test/scala/chiselTests/util/experimental/TruthTableSpec.scala @@ -2,8 +2,9 @@ package chiselTests.util.experimental +import chisel3._ import chisel3.util.BitPat -import chisel3.util.experimental.decode.TruthTable +import chisel3.util.experimental.decode.{TruthTable, decoder} import org.scalatest.flatspec.AnyFlatSpec class TruthTableSpec extends AnyFlatSpec { @@ -34,16 +35,16 @@ class TruthTableSpec extends AnyFlatSpec { assert(table.toString contains " 0") } "TruthTable" should "deserialize" in { - assert(TruthTable(str) == table) + assert(TruthTable.fromString(str) == table) } "TruthTable" should "merge same key" in { assert( - TruthTable( + TruthTable.fromString( """001100->??1 |001100->1?? |??? |""".stripMargin - ) == TruthTable( + ) == TruthTable.fromString( """001100->1?1 |??? |""".stripMargin @@ -52,7 +53,7 @@ class TruthTableSpec extends AnyFlatSpec { } "TruthTable" should "crash when merging 0 and 1" in { intercept[IllegalArgumentException] { - TruthTable( + TruthTable.fromString( """0->0 |0->1 |??? @@ -60,4 +61,24 @@ class TruthTableSpec extends AnyFlatSpec { ) } } + "TruthTable" should "be reproducible" in { + class Foo extends Module { + + val io = IO(new Bundle{ + val in = Input(UInt(4.W)) + val out = Output(UInt(16.W)) + }) + + + val table = TruthTable( + (0 until 16).map{ + i => BitPat(i.U(4.W)) -> BitPat((1<<i).U(16.W)) + }, + BitPat.dontCare(16) + ) + + io.out := decoder.qmc(io.in, table) + } + assert(chisel3.stage.ChiselStage.emitChirrtl(new Foo) == chisel3.stage.ChiselStage.emitChirrtl(new Foo)) + } } |
