summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2021-07-01 16:34:48 -0700
committerJack Koenig2021-07-01 18:03:42 -0700
commit5fe539c707c88eedbb112f5c6bcea1dfe1d52169 (patch)
tree8d9bf0d80eec9e003907056fea8b12b6642252dc
parent04caf395c737450c26f59d373d76b567a2b80f0f (diff)
Add ChiselEnum.safe factory method and avoid warning
Previously, ChiselEnum would warn any time a UInt is converted to an Enum. There was no way to suppress this warning. Now there is a factory method (`.safe`) that does not warn and returns (Enum, Bool) where the Bool is the result of calling .isValid on an Enum object. The regular UInt cast is also now smarter and will not warn if all bitvectors of the width of the Enum are legal states.
-rw-r--r--core/src/main/scala/chisel3/StrongEnum.scala34
-rw-r--r--src/test/scala/chiselTests/ChiselSpec.scala17
-rw-r--r--src/test/scala/chiselTests/StrongEnum.scala97
3 files changed, 143 insertions, 5 deletions
diff --git a/core/src/main/scala/chisel3/StrongEnum.scala b/core/src/main/scala/chisel3/StrongEnum.scala
index 7d328eb7..1d0e04d3 100644
--- a/core/src/main/scala/chisel3/StrongEnum.scala
+++ b/core/src/main/scala/chisel3/StrongEnum.scala
@@ -131,7 +131,7 @@ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolea
if (litOption.isDefined) {
true.B
} else {
- factory.all.map(this === _).reduce(_ || _)
+ if (factory.isTotal) true.B else factory.all.map(this === _).reduce(_ || _)
}
}
@@ -233,6 +233,12 @@ abstract class EnumFactory {
private[chisel3] val enumTypeName = getClass.getName.init
+ // Do all bitvectors of this Enum's width represent legal states?
+ private[chisel3] def isTotal: Boolean = {
+ (this.getWidth < 31) && // guard against Integer overflow
+ (enumRecords.size == (1 << this.getWidth))
+ }
+
private[chisel3] def globalAnnotation: EnumDefChiselAnnotation =
EnumDefChiselAnnotation(enumTypeName, (enumNames, enumValues).zipped.toMap)
@@ -277,7 +283,7 @@ abstract class EnumFactory {
def apply(): Type = new Type
- def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = {
+ private def castImpl(n: UInt, warn: Boolean)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = {
if (n.litOption.isDefined) {
enumInstances.find(_.litValue == n.litValue) match {
case Some(result) => result
@@ -288,8 +294,9 @@ abstract class EnumFactory {
} else if (n.getWidth > this.getWidth) {
throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)")
} else {
- Builder.warning(s"Casting non-literal UInt to $enumTypeName. You can check that its value is legal by calling isValid")
-
+ if (warn && !this.isTotal) {
+ Builder.warning(s"Casting non-literal UInt to $enumTypeName. You can use $enumTypeName.safe to cast without this warning.")
+ }
val glue = Wire(new UnsafeEnum(width))
glue := n
val result = Wire(new Type)
@@ -297,6 +304,25 @@ abstract class EnumFactory {
result
}
}
+
+ /** Cast an [[UInt]] to the type of this Enum
+ *
+ * @note will give a Chisel elaboration time warning if the argument could hit invalid states
+ * @param n the UInt to cast
+ * @return the equivalent Enum to the value of the cast UInt
+ */
+ def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = castImpl(n, warn = true)
+
+ /** Safely cast an [[UInt]] to the type of this Enum
+ *
+ * @param n the UInt to cast
+ * @return the equivalent Enum to the value of the cast UInt and a Bool indicating if the
+ * Enum is valid
+ */
+ def safe(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): (Type, Bool) = {
+ val t = castImpl(n, warn = false)
+ (t, t.isValid)
+ }
}
diff --git a/src/test/scala/chiselTests/ChiselSpec.scala b/src/test/scala/chiselTests/ChiselSpec.scala
index a4192c5e..e513189e 100644
--- a/src/test/scala/chiselTests/ChiselSpec.scala
+++ b/src/test/scala/chiselTests/ChiselSpec.scala
@@ -9,6 +9,7 @@ import chisel3.testers._
import firrtl.annotations.Annotation
import firrtl.util.BackendCompilationUtilities
import firrtl.{AnnotationSeq, EmittedVerilogCircuitAnnotation}
+import _root_.logger.Logger
import org.scalacheck._
import org.scalatest._
import org.scalatest.flatspec.AnyFlatSpec
@@ -17,7 +18,7 @@ import org.scalatest.propspec.AnyPropSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks
-import java.io.ByteArrayOutputStream
+import java.io.{ByteArrayOutputStream, PrintStream}
import java.security.Permission
import scala.reflect.ClassTag
@@ -172,6 +173,20 @@ trait Utils {
(stdout.toString, stderr.toString, ret)
}
+ /** Run some Scala thunk and return all logged messages as Strings
+ * @param thunk some Scala code
+ * @return a tuple containing LOGGED, and what the thunk returns
+ */
+ def grabLog[T](thunk: => T): (String, T) = {
+ val baos = new ByteArrayOutputStream()
+ val stream = new PrintStream(baos, true, "utf-8")
+ val ret = Logger.makeScope(Nil) {
+ Logger.setOutput(stream)
+ thunk
+ }
+ (baos.toString, ret)
+ }
+
/** Encodes a System.exit exit code
* @param status the exit code
*/
diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala
index bf0eb2fe..e59a5398 100644
--- a/src/test/scala/chiselTests/StrongEnum.scala
+++ b/src/test/scala/chiselTests/StrongEnum.scala
@@ -74,6 +74,18 @@ class CastFromNonLit extends Module {
io.valid := io.out.isValid
}
+class SafeCastFromNonLit extends Module {
+ val io = IO(new Bundle {
+ val in = Input(UInt(EnumExample.getWidth.W))
+ val out = Output(EnumExample())
+ val valid = Output(Bool())
+ })
+
+ val (enum, valid) = EnumExample.safe(io.in)
+ io.out := enum
+ io.valid := valid
+}
+
class CastFromNonLitWidth(w: Option[Int] = None) extends Module {
val width = if (w.isDefined) w.get.W else UnknownWidth()
@@ -191,6 +203,28 @@ class CastFromNonLitTester extends BasicTester {
stop()
}
+class SafeCastFromNonLitTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new SafeCastFromNonLit)
+ mod.io.in := lit
+ assert(mod.io.out === enum)
+ assert(mod.io.valid === true.B)
+ }
+
+ val invalid_values = (1 until (1 << EnumExample.getWidth)).
+ filter(!EnumExample.litValues.map(_.litValue).contains(_)).
+ map(_.U)
+
+ for (invalid_val <- invalid_values) {
+ val mod = Module(new SafeCastFromNonLit)
+ mod.io.in := invalid_val
+
+ assert(mod.io.valid === false.B)
+ }
+
+ stop()
+}
+
class CastToInvalidEnumTester extends BasicTester {
val invalid_value: UInt = EnumExample.litValues.last + 1.U
Module(new CastFromLit(invalid_value))
@@ -320,6 +354,10 @@ class StrongEnumSpec extends ChiselFlatSpec with Utils {
assertTesterPasses(new CastFromNonLitTester)
}
+ it should "safely cast non-literal UInts to enums correctly and detect illegal casts" in {
+ assertTesterPasses(new SafeCastFromNonLitTester)
+ }
+
it should "prevent illegal literal casts to enums" in {
a [ChiselException] should be thrownBy extractCause[ChiselException] {
ChiselStage.elaborate(new CastToInvalidEnumTester)
@@ -377,6 +415,65 @@ class StrongEnumSpec extends ChiselFlatSpec with Utils {
"StrongEnum FSM" should "work" in {
assertTesterPasses(new StrongEnumFSMTester)
}
+
+ "Casting a UInt to an Enum" should "warn if the UInt can express illegal states" in {
+ object MyEnum extends ChiselEnum {
+ val e0, e1, e2 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(MyEnum()))
+ out := MyEnum(in)
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should include ("warn")
+ log should include ("Casting non-literal UInt")
+ }
+
+ it should "NOT warn if the Enum is total" in {
+ object TotalEnum extends ChiselEnum {
+ val e0, e1, e2, e3 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(TotalEnum()))
+ out := TotalEnum(in)
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should not include ("warn")
+ }
+
+ "Casting a UInt to an Enum with .safe" should "NOT warn" in {
+ object MyEnum extends ChiselEnum {
+ val e0, e1, e2 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(MyEnum()))
+ out := MyEnum.safe(in)._1
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should not include ("warn")
+ }
+
+ it should "NOT generate any validity logic if the Enum is total" in {
+ object TotalEnum extends ChiselEnum {
+ val e0, e1, e2, e3 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(TotalEnum()))
+ val (res, valid) = TotalEnum.safe(in)
+ assert(valid.litToBoolean, "It should be true.B")
+ out := res
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should not include ("warn")
+ }
}
class StrongEnumAnnotator extends Module {