summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/test/scala/chiselTests/ChiselSpec.scala17
-rw-r--r--src/test/scala/chiselTests/StrongEnum.scala97
2 files changed, 113 insertions, 1 deletions
diff --git a/src/test/scala/chiselTests/ChiselSpec.scala b/src/test/scala/chiselTests/ChiselSpec.scala
index a4192c5e..e513189e 100644
--- a/src/test/scala/chiselTests/ChiselSpec.scala
+++ b/src/test/scala/chiselTests/ChiselSpec.scala
@@ -9,6 +9,7 @@ import chisel3.testers._
import firrtl.annotations.Annotation
import firrtl.util.BackendCompilationUtilities
import firrtl.{AnnotationSeq, EmittedVerilogCircuitAnnotation}
+import _root_.logger.Logger
import org.scalacheck._
import org.scalatest._
import org.scalatest.flatspec.AnyFlatSpec
@@ -17,7 +18,7 @@ import org.scalatest.propspec.AnyPropSpec
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks
-import java.io.ByteArrayOutputStream
+import java.io.{ByteArrayOutputStream, PrintStream}
import java.security.Permission
import scala.reflect.ClassTag
@@ -172,6 +173,20 @@ trait Utils {
(stdout.toString, stderr.toString, ret)
}
+ /** Run some Scala thunk and return all logged messages as Strings
+ * @param thunk some Scala code
+ * @return a tuple containing LOGGED, and what the thunk returns
+ */
+ def grabLog[T](thunk: => T): (String, T) = {
+ val baos = new ByteArrayOutputStream()
+ val stream = new PrintStream(baos, true, "utf-8")
+ val ret = Logger.makeScope(Nil) {
+ Logger.setOutput(stream)
+ thunk
+ }
+ (baos.toString, ret)
+ }
+
/** Encodes a System.exit exit code
* @param status the exit code
*/
diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala
index bf0eb2fe..e59a5398 100644
--- a/src/test/scala/chiselTests/StrongEnum.scala
+++ b/src/test/scala/chiselTests/StrongEnum.scala
@@ -74,6 +74,18 @@ class CastFromNonLit extends Module {
io.valid := io.out.isValid
}
+class SafeCastFromNonLit extends Module {
+ val io = IO(new Bundle {
+ val in = Input(UInt(EnumExample.getWidth.W))
+ val out = Output(EnumExample())
+ val valid = Output(Bool())
+ })
+
+ val (enum, valid) = EnumExample.safe(io.in)
+ io.out := enum
+ io.valid := valid
+}
+
class CastFromNonLitWidth(w: Option[Int] = None) extends Module {
val width = if (w.isDefined) w.get.W else UnknownWidth()
@@ -191,6 +203,28 @@ class CastFromNonLitTester extends BasicTester {
stop()
}
+class SafeCastFromNonLitTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new SafeCastFromNonLit)
+ mod.io.in := lit
+ assert(mod.io.out === enum)
+ assert(mod.io.valid === true.B)
+ }
+
+ val invalid_values = (1 until (1 << EnumExample.getWidth)).
+ filter(!EnumExample.litValues.map(_.litValue).contains(_)).
+ map(_.U)
+
+ for (invalid_val <- invalid_values) {
+ val mod = Module(new SafeCastFromNonLit)
+ mod.io.in := invalid_val
+
+ assert(mod.io.valid === false.B)
+ }
+
+ stop()
+}
+
class CastToInvalidEnumTester extends BasicTester {
val invalid_value: UInt = EnumExample.litValues.last + 1.U
Module(new CastFromLit(invalid_value))
@@ -320,6 +354,10 @@ class StrongEnumSpec extends ChiselFlatSpec with Utils {
assertTesterPasses(new CastFromNonLitTester)
}
+ it should "safely cast non-literal UInts to enums correctly and detect illegal casts" in {
+ assertTesterPasses(new SafeCastFromNonLitTester)
+ }
+
it should "prevent illegal literal casts to enums" in {
a [ChiselException] should be thrownBy extractCause[ChiselException] {
ChiselStage.elaborate(new CastToInvalidEnumTester)
@@ -377,6 +415,65 @@ class StrongEnumSpec extends ChiselFlatSpec with Utils {
"StrongEnum FSM" should "work" in {
assertTesterPasses(new StrongEnumFSMTester)
}
+
+ "Casting a UInt to an Enum" should "warn if the UInt can express illegal states" in {
+ object MyEnum extends ChiselEnum {
+ val e0, e1, e2 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(MyEnum()))
+ out := MyEnum(in)
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should include ("warn")
+ log should include ("Casting non-literal UInt")
+ }
+
+ it should "NOT warn if the Enum is total" in {
+ object TotalEnum extends ChiselEnum {
+ val e0, e1, e2, e3 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(TotalEnum()))
+ out := TotalEnum(in)
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should not include ("warn")
+ }
+
+ "Casting a UInt to an Enum with .safe" should "NOT warn" in {
+ object MyEnum extends ChiselEnum {
+ val e0, e1, e2 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(MyEnum()))
+ out := MyEnum.safe(in)._1
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should not include ("warn")
+ }
+
+ it should "NOT generate any validity logic if the Enum is total" in {
+ object TotalEnum extends ChiselEnum {
+ val e0, e1, e2, e3 = Value
+ }
+
+ class MyModule extends Module {
+ val in = IO(Input(UInt(2.W)))
+ val out = IO(Output(TotalEnum()))
+ val (res, valid) = TotalEnum.safe(in)
+ assert(valid.litToBoolean, "It should be true.B")
+ out := res
+ }
+ val (log, _) = grabLog(ChiselStage.elaborate(new MyModule))
+ log should not include ("warn")
+ }
}
class StrongEnumAnnotator extends Module {