diff options
| author | mergify[bot] | 2022-06-01 20:32:31 +0000 |
|---|---|---|
| committer | GitHub | 2022-06-01 20:32:31 +0000 |
| commit | 97fde23f666a560d4eba9333e4230f901d7f5361 (patch) | |
| tree | b8434cba3666491dc59aa323dce399e77cb7a576 | |
| parent | 0c811b490f47f20f2e81c58706924e56611b6ba2 (diff) | |
Add formatted Printable interpolator `cf` (#2528) (#2553)
This is a formatted version of the p"..." interpolator analogous to
Scala's f"..." interpolator. The primary difference is that it supports
formatting interpolated variables by following the variable with
"%<specifier>". For example:
printf(cf"myWire = $myWire%x\n")
This will format the hardware value "myWire" as a hexidecimal value in
the emitted Verilog. Note that literal "%" must be escaped as "%%".
Scala types and format specifiers are supported and are handled in the
same manner as in standard Scala f"..." interpolators.
(cherry picked from commit 037f7b2ff3a46184d1b82e1b590a7572bfa6a76b)
Co-authored-by: Girish Pai <girish.pai@sifive.com>
| -rw-r--r-- | core/src/main/scala/chisel3/Printable.scala | 77 | ||||
| -rw-r--r-- | core/src/main/scala/chisel3/package.scala | 153 | ||||
| -rw-r--r-- | src/test/scala/chiselTests/PrintableSpec.scala | 207 |
3 files changed, 350 insertions, 87 deletions
diff --git a/core/src/main/scala/chisel3/Printable.scala b/core/src/main/scala/chisel3/Printable.scala index a616f2b0..78655517 100644 --- a/core/src/main/scala/chisel3/Printable.scala +++ b/core/src/main/scala/chisel3/Printable.scala @@ -63,57 +63,76 @@ object Printable { */ def pack(fmt: String, data: Data*): Printable = { val args = data.toIterator - // Error handling def carrotAt(index: Int) = (" " * index) + "^" def errorMsg(index: Int) = s"""| fmt = "$fmt" | ${carrotAt(index)} | data = ${data.mkString(", ")}""".stripMargin - def getArg(i: Int): Data = { + + def checkArg(i: Int): Unit = { if (!args.hasNext) { val msg = "has no matching argument!\n" + errorMsg(i) // Exception wraps msg in s"Format Specifier '$msg'" throw new MissingFormatArgumentException(msg) } - args.next() + val _ = args.next() } + var iter = 0 + var curr_start = 0 + val buf = mutable.ListBuffer.empty[String] + while (iter < fmt.size) { + // Encountered % which is either + // 1. Describing a format specifier. + // 2. Literal Percent + // 3. Dangling percent - most likely due to a typo - intended literal percent or forgot the specifier. + // Try to give meaningful error reports + if (fmt(iter) == '%') { + if (iter != fmt.size - 1 && (fmt(iter + 1) != '%' && !fmt(iter + 1).isWhitespace)) { + checkArg(iter) + buf += fmt.substring(curr_start, iter) + curr_start = iter + iter += 1 + } - val pables = mutable.ListBuffer.empty[Printable] - var str = "" - var percent = false - for ((c, i) <- fmt.zipWithIndex) { - if (percent) { - val arg = c match { - case FirrtlFormat(x) => FirrtlFormat(x.toString, getArg(i)) - case 'n' => Name(getArg(i)) - case 'N' => FullName(getArg(i)) - case '%' => Percent - case x => - val msg = s"Illegal format specifier '$x'!\n" + errorMsg(i) - throw new UnknownFormatConversionException(msg) + // Last character is %. + else if (iter == fmt.size - 1) { + val msg = s"Trailing %\n" + errorMsg(fmt.size - 1) + throw new UnknownFormatConversionException(msg) + } + + // A lone % + else if (fmt(iter + 1).isWhitespace) { + val msg = s"Unescaped % - add % if literal or add proper specifier if not\n" + errorMsg(iter + 1) + throw new UnknownFormatConversionException(msg) + } + + // A literal percent - hence increment by 2. + else { + iter += 2 } - pables += PString(str.dropRight(1)) // remove format % - pables += arg - str = "" - percent = false - } else { - str += c - percent = c == '%' } - } - if (percent) { - val msg = s"Trailing %\n" + errorMsg(fmt.size - 1) - throw new UnknownFormatConversionException(msg) + + // Normal progression + else { + iter += 1 + } } require( !args.hasNext, s"Too many arguments! More format specifier(s) expected!\n" + errorMsg(fmt.size) ) + buf += fmt.substring(curr_start, iter) + + // The string received as an input to pack is already + // treated i.e. escape sequences are processed. + // Since StringContext API assumes the parts are un-treated + // treatEscapes is called within the implemented custom interpolators. + // The literal \ needs to be escaped before sending to the custom cf interpolator. - pables += PString(str) - Printables(pables) + val bufEscapeBackSlash = buf.map(_.replace("\\", "\\\\")) + StringContext(bufEscapeBackSlash.toSeq: _*).cf(data: _*) } } diff --git a/core/src/main/scala/chisel3/package.scala b/core/src/main/scala/chisel3/package.scala index bd088e21..5521c51e 100644 --- a/core/src/main/scala/chisel3/package.scala +++ b/core/src/main/scala/chisel3/package.scala @@ -1,6 +1,8 @@ // SPDX-License-Identifier: Apache-2.0 import chisel3.internal.firrtl.BinaryPoint +import java.util.{MissingFormatArgumentException, UnknownFormatConversionException} +import scala.collection.mutable /** This package contains the main chisel3 API. */ @@ -210,29 +212,142 @@ package object chisel3 { implicit class PrintableHelper(val sc: StringContext) extends AnyVal { /** Custom string interpolator for generating Printables: p"..." - * Will call .toString on any non-Printable arguments (mimicking s"...") + * mimicks s"..." for non-Printable data) */ def p(args: Any*): Printable = { - sc.checkLengths(args) // Enforce sc.parts.size == pargs.size + 1 - val pargs: Seq[Option[Printable]] = args.map { - case p: Printable => Some(p) - case d: Data => Some(d.toPrintable) - case any => - for { - v <- Option(any) // Handle null inputs - str = v.toString - if !str.isEmpty // Handle empty Strings - } yield PString(str) + // P interpolator does not treat % differently - hence need to add % before sending to cf. + val t = sc.parts.map(_.replaceAll("%", "%%")) + StringContext(t: _*).cf(args: _*) + } + + /** Custom string interpolator for generating formatted Printables : cf"..." + * + * Enhanced version of scala's `f` interpolator. + * Each expression (argument) referenced within the string is + * converted to a particular Printable depending + * on the format specifier and type. + * + * ==== For Chisel types referenced within the String ==== + * + * - <code>%n</code> - Returns [[Name]] Printable. + * - <code>%N</code> - Returns [[FullName]] Printable. + * - <code>%b,%d,%x,%c</code> - Only applicable for types of [[Bits]] or dreived from it. - returns ([[Binary]],[[Decimal]], + * [[Hexadecimal]],[[Character]]) Printable respectively. + * - Default - If no specifier given call [[Data.toPrintable]] on the Chisel Type. + * + * ==== For [[Printable]] type: ==== + * No explicit format specifier supported - just return the Printable. + * + * ==== For regular scala types ==== + * Call String.format with the argument and specifier. + * Default is %s if no specifier is given. + * Wrap the result in [[PString]] Printable. + * + * ==== For the parts of the StringContext ==== + * Remove format specifiers and if literal percents (need to be escaped with %) + * are present convert them into [[Percent]] Printable. + * Rest of the string will be wrapped in [[PString]] Printable. + * + * @example + * {{{ + * + * val w1 = 20.U // Chisel UInt type (which extends Bits) + * val f1 = 30.2 // Scala float type. + * val pable = cf"w1 = $w1%x f1 = $f1%2.2f. This is 100%% clear" + * + * // pable is as follows + * // Printables(List(PString(w1 = ), Hexadecimal(UInt<5>(20)), PString( f1 = ), PString(30.20), PString(. This is 100), Percent, PString( clear))) + * }}} + * + * @throws UnknownFormatConversionException + * if literal percent not escaped with % or if the format specifier is not supported + * for the specific type + * + * @throws StringContext.InvalidEscapeException + * if a `parts` string contains a backslash (`\`) character + * that does not start a valid escape sequence. + * + * @throws IllegalArgumentException + * if the number of `parts` in the enclosing `StringContext` does not exceed + * the number of arguments `arg` by exactly 1. + */ + def cf(args: Any*): Printable = { + + // Handle literal % + // Takes the part string - + // - this is assumed to not have any format specifiers - already handled / removed before calling this function. + // Only thing present is literal % if any which should ideally be with %%. + // If not - then flag an error. + // Return seq of Printables (either PString or Percent or both - nothing else + def percentSplitter(s: String): Seq[Printable] = { + if (s.isEmpty) Seq(PString("")) + else { + val pieces = s.split("%%").toList.flatMap { p => + if (p.contains('%')) throw new UnknownFormatConversionException("Un-escaped % found") + // Wrap in PString and intersperse the escaped percentages + Seq(Percent, PString(p)) + } + if (pieces.isEmpty) Seq(Percent) + else pieces.tail // Don't forget to drop the extra percent we put at the beginning + } } + + def extractFormatSpecifier(part: String): (Option[String], String) = { + // Check if part starts with a format specifier (with % - disambiguate with literal % checking the next character if needed to be %) + // In the case of %f specifier there is a chance that we need more information - so capture till the 1st letter (a-zA-Z). + // Example cf"This is $val%2.2f here" - parts - Seq("This is ","%2.2f here") - the format specifier here is %2.2f. + val endFmtIdx = + if (part.length > 1 && part(0) == '%' && part(1) != '%') part.indexWhere(_.isLetter) + else -1 + val (fmt, rest) = part.splitAt(endFmtIdx + 1) + + val fmtOpt = if (fmt.nonEmpty) Some(fmt) else None + (fmtOpt, rest) + + } + + sc.checkLengths(args) // Enforce sc.parts.size == pargs.size + 1 val parts = sc.parts.map(StringContext.treatEscapes) - // Zip sc.parts and pargs together ito flat Seq - // eg. Seq(sc.parts(0), pargs(0), sc.parts(1), pargs(1), ...) - val seq = for { // append None because sc.parts.size == pargs.size + 1 - (literal, arg) <- parts.zip(pargs :+ None) - optPable <- Seq(Some(PString(literal)), arg) - pable <- optPable // Remove Option[_] - } yield pable - Printables(seq) + // The 1st part is assumed never to contain a format specifier. + // If the 1st part of a string is an argument - then the 1st part will be an empty String. + // So we need to parse parts following the 1st one to get the format specifiers if any + val partsAfterFirst = parts.tail + + // Align parts to their potential specifiers + val pables = partsAfterFirst.zip(args).flatMap { + case (part, arg) => { + val (fmt, modP) = extractFormatSpecifier(part) + val fmtArg: Printable = arg match { + case d: Data => { + fmt match { + case Some("%n") => Name(d) + case Some("%N") => FullName(d) + case Some(fForm) if d.isInstanceOf[Bits] => FirrtlFormat(fForm.substring(1, 2), d) + case Some(x) => { + val msg = s"Illegal format specifier '$x' for Chisel Data type!\n" + throw new UnknownFormatConversionException(msg) + } + case None => d.toPrintable + } + } + case p: Printable => { + fmt match { + case Some(x) => { + val msg = s"Illegal format specifier '$x' for Chisel Printable type!\n" + throw new UnknownFormatConversionException(msg) + } + case None => p + } + } + + // Generic case - use String.format (for example %d,%2.2f etc on regular Scala types) + case t => PString(fmt.getOrElse("%s").format(t)) + + } + Seq(fmtArg) ++ percentSplitter(modP) + } + } + Printables(percentSplitter(parts.head) ++ pables) } } diff --git a/src/test/scala/chiselTests/PrintableSpec.scala b/src/test/scala/chiselTests/PrintableSpec.scala index 7d584cea..8039918d 100644 --- a/src/test/scala/chiselTests/PrintableSpec.scala +++ b/src/test/scala/chiselTests/PrintableSpec.scala @@ -9,7 +9,8 @@ import chisel3.testers.BasicTester import firrtl.annotations.{ReferenceTarget, SingleTargetAnnotation} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers - +import chisel3.util._ +import org.scalactic.source.Position import java.io.File /** Dummy [[printf]] annotation. @@ -32,7 +33,7 @@ object PrintfAnnotation { } /* Printable Tests */ -class PrintableSpec extends AnyFlatSpec with Matchers { +class PrintableSpec extends AnyFlatSpec with Matchers with Utils { // This regex is brittle, it specifically finds the clock and enable signals followed by commas private val PrintfRegex = """\s*printf\(\w+, [^,]+,(.*)\).*""".r private val StringRegex = """([^"]*)"(.*?)"(.*)""".r @@ -47,7 +48,6 @@ class PrintableSpec extends AnyFlatSpec with Matchers { case _ => fail(s"Regex to process Printf should work on $str!") } } - firrtl.split("\n").collect { case PrintfRegex(matched) => val (str, args) = processBody(matched) @@ -55,26 +55,34 @@ class PrintableSpec extends AnyFlatSpec with Matchers { } } + // Generates firrtl, gets Printfs + // Calls fail() if failed match; else calls the partial function which could have its own check + private def generateAndCheck(gen: => RawModule)(check: PartialFunction[Seq[Printf], Unit])(implicit pos: Position) = { + val firrtl = ChiselStage.emitChirrtl(gen) + val printfs = getPrintfs(firrtl) + if (!check.isDefinedAt(printfs)) { + fail() + } else { + check(printfs) + } + } + behavior.of("Printable & Custom Interpolator") it should "pass exact strings through" in { class MyModule extends BasicTester { printf(p"An exact string") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("An exact string", Seq())) => - case e => fail() } } it should "handle Printable and String concatination" in { class MyModule extends BasicTester { printf(p"First " + PString("Second ") + "Third") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("First Second Third", Seq())) => - case e => fail() } } it should "call toString on non-Printable objects" in { @@ -82,10 +90,8 @@ class PrintableSpec extends AnyFlatSpec with Matchers { val myInt = 1234 printf(p"myInt = $myInt") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("myInt = 1234", Seq())) => - case e => fail() } } it should "generate proper printf for simple Decimal printing" in { @@ -93,41 +99,33 @@ class PrintableSpec extends AnyFlatSpec with Matchers { val myWire = WireDefault(1234.U) printf(p"myWire = ${Decimal(myWire)}") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("myWire = %d", Seq("myWire"))) => - case e => fail() } } it should "handle printing literals" in { class MyModule extends BasicTester { printf(Decimal(10.U(32.W))) } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("%d", Seq(lit))) => assert(lit contains "UInt<32>") - case e => fail() } } it should "correctly escape percent" in { class MyModule extends BasicTester { printf(p"%") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("%%", Seq())) => - case e => fail() } } it should "correctly emit tab" in { class MyModule extends BasicTester { printf(p"\t") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("\\t", Seq())) => - case e => fail() } } it should "support names of circuit elements including submodule IO" in { @@ -149,10 +147,8 @@ class PrintableSpec extends AnyFlatSpec with Matchers { printf(p"${FullName(myWire.foo)}") printf(p"${FullName(myInst.io.fizz)}") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("foo", Seq()), Printf("myWire.foo", Seq()), Printf("myInst.io.fizz", Seq())) => - case e => fail() } } it should "handle printing ports of submodules" in { @@ -165,10 +161,8 @@ class PrintableSpec extends AnyFlatSpec with Matchers { val myInst = Module(new MySubModule) printf(p"${myInst.io.fizz}") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("%d", Seq("myInst.io.fizz"))) => - case e => fail() } } it should "print UInts and SInts as Decimal by default" in { @@ -177,10 +171,8 @@ class PrintableSpec extends AnyFlatSpec with Matchers { val mySInt = WireDefault(-1.S) printf(p"$myUInt & $mySInt") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("%d & %d", Seq("myUInt", "mySInt"))) => - case e => fail() } } it should "print Vecs like Scala Seqs by default" in { @@ -189,10 +181,8 @@ class PrintableSpec extends AnyFlatSpec with Matchers { myVec.foreach(_ := 0.U) printf(p"$myVec") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("Vec(%d, %d, %d, %d)", Seq("myVec[0]", "myVec[1]", "myVec[2]", "myVec[3]"))) => - case e => fail() } } it should "print Bundles like Scala Maps by default" in { @@ -205,10 +195,8 @@ class PrintableSpec extends AnyFlatSpec with Matchers { myBun.bar := 0.U printf(p"$myBun") } - val firrtl = ChiselStage.emitChirrtl(new MyModule) - getPrintfs(firrtl) match { + generateAndCheck(new MyModule) { case Seq(Printf("AnonymousBundle(foo -> %d, bar -> %d)", Seq("myBun.foo", "myBun.bar"))) => - case e => fail() } } it should "get emitted with a name and annotated" in { @@ -261,4 +249,145 @@ class PrintableSpec extends AnyFlatSpec with Matchers { """printf(clock, UInt<1>("h1"), "adieu AnonymousBundle(foo -> %d, bar -> %d)", myBun.foo, myBun.bar) : farewell""" ) } + + // Unit tests for cf + it should "print regular scala variables with cf format specifier" in { + + class MyModule extends BasicTester { + val f1 = 20.4517 + val i1 = 10 + val str1 = "String!" + printf( + cf"F1 = $f1 D1 = $i1 F1 formatted = $f1%2.2f str1 = $str1%s i1_str = $i1%s i1_hex=$i1%x" + ) + + } + + generateAndCheck(new MyModule) { + case Seq(Printf("F1 = 20.4517 D1 = 10 F1 formatted = 20.45 str1 = String! i1_str = 10 i1_hex=a", Seq())) => + } + } + + it should "print chisel bits with cf format specifier" in { + + class MyBundle extends Bundle { + val foo = UInt(32.W) + val bar = UInt(32.W) + override def toPrintable: Printable = { + cf"Bundle : " + + cf"Foo : $foo%x Bar : $bar%x" + } + } + class MyModule extends BasicTester { + val b1 = 10.U + val w1 = Wire(new MyBundle) + w1.foo := 5.U + w1.bar := 10.U + printf(cf"w1 = $w1") + } + generateAndCheck(new MyModule) { + case Seq(Printf("w1 = Bundle : Foo : %x Bar : %x", Seq("w1.foo", "w1.bar"))) => + } + } + + it should "support names of circuit elements using format specifier including submodule IO with cf format specifier" in { + // Submodule IO is a subtle issue because the Chisel element has a different + // parent module + class MySubModule extends Module { + val io = IO(new Bundle { + val fizz = UInt(32.W) + }) + } + class MyBundle extends Bundle { + val foo = UInt(32.W) + } + class MyModule extends BasicTester { + override def desiredName: String = "MyModule" + val myWire = Wire(new MyBundle) + val myInst = Module(new MySubModule) + printf(cf"${myWire.foo}%n") + printf(cf"${myWire.foo}%N") + printf(cf"${myInst.io.fizz}%N") + } + generateAndCheck(new MyModule) { + case Seq(Printf("foo", Seq()), Printf("myWire.foo", Seq()), Printf("myInst.io.fizz", Seq())) => + } + } + + it should "correctly print strings after modifier" in { + class MyModule extends BasicTester { + val b1 = 10.U + printf(cf"This is here $b1%x!!!! And should print everything else") + } + generateAndCheck(new MyModule) { + case Seq(Printf("This is here %x!!!! And should print everything else", Seq("UInt<4>(\"ha\")"))) => + } + } + + it should "correctly print strings with a lot of literal %% and different format specifiers for Wires" in { + class MyModule extends BasicTester { + val b1 = 10.U + val b2 = 20.U + printf(cf"%% $b1%x%%$b2%b = ${b1 % b2}%d %%%% Tail String") + } + + generateAndCheck(new MyModule) { + case Seq(Printf("%% %x%%%b = %d %%%% Tail String", Seq(lita, litb, _))) => + assert(lita.contains("UInt<4>") && litb.contains("UInt<5>")) + } + } + + it should "not allow unescaped % in the message" in { + class MyModule extends BasicTester { + printf(cf"This should error out for sure because of % - it should be %%") + } + a[java.util.UnknownFormatConversionException] should be thrownBy { + extractCause[java.util.UnknownFormatConversionException] { + ChiselStage.elaborate { new MyModule } + } + } + } + + it should "allow Printables to be expanded and used" in { + class MyModule extends BasicTester { + val w1 = 20.U + val f1 = 30.2 + val i1 = 14 + val pable = cf"w1 = $w1%b f1 = $f1%2.2f" + printf(cf"Trying to expand printable $pable and mix with i1 = $i1%d") + } + generateAndCheck(new MyModule) { + case Seq(Printf("Trying to expand printable w1 = %b f1 = 30.20 and mix with i1 = 14", Seq(lit))) => + assert(lit.contains("UInt<5>")) + } + } + + it should "fail with a single % in the message" in { + class MyModule extends BasicTester { + printf(cf"%") + } + a[java.util.UnknownFormatConversionException] should be thrownBy { + extractCause[java.util.UnknownFormatConversionException] { + ChiselStage.elaborate { new MyModule } + } + } + } + + it should "fail when passing directly to StirngContext.cf a string with literal \\ correctly escaped " in { + a[StringContext.InvalidEscapeException] should be thrownBy { + extractCause[StringContext.InvalidEscapeException] { + val s_seq = Seq("Test with literal \\ correctly escaped") + StringContext(s_seq: _*).cf(Seq(): _*) + } + } + } + + it should "pass correctly escaped \\ when using Printable.pack" in { + class MyModule extends BasicTester { + printf(Printable.pack("\\ \\]")) + } + generateAndCheck(new MyModule) { + case Seq(Printf("\\\\ \\\\]", Seq())) => + } + } } |
