summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/chisel3/internal/firrtl/Converter.scala3
-rw-r--r--src/main/scala/chisel3/internal/firrtl/Emitter.scala3
-rw-r--r--src/main/scala/chisel3/package.scala3
-rw-r--r--src/main/scala/chisel3/util/Conditional.scala10
-rw-r--r--src/test/scala/chiselTests/StrongEnum.scala430
-rw-r--r--src/test/scala/cookbook/FSM.scala27
6 files changed, 458 insertions, 18 deletions
diff --git a/src/main/scala/chisel3/internal/firrtl/Converter.scala b/src/main/scala/chisel3/internal/firrtl/Converter.scala
index 97504aba..181bdfe8 100644
--- a/src/main/scala/chisel3/internal/firrtl/Converter.scala
+++ b/src/main/scala/chisel3/internal/firrtl/Converter.scala
@@ -2,7 +2,7 @@
package chisel3.internal.firrtl
import chisel3._
-import chisel3.core.SpecifiedDirection
+import chisel3.core.{SpecifiedDirection, EnumType}
import chisel3.experimental._
import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo}
import firrtl.{ir => fir}
@@ -211,6 +211,7 @@ private[chisel3] object Converter {
def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match {
case _: Clock => fir.ClockType
+ case d: EnumType => fir.UIntType(convert(d.width))
case d: UInt => fir.UIntType(convert(d.width))
case d: SInt => fir.SIntType(convert(d.width))
case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint))
diff --git a/src/main/scala/chisel3/internal/firrtl/Emitter.scala b/src/main/scala/chisel3/internal/firrtl/Emitter.scala
index 26ccc09d..ac4bf8e7 100644
--- a/src/main/scala/chisel3/internal/firrtl/Emitter.scala
+++ b/src/main/scala/chisel3/internal/firrtl/Emitter.scala
@@ -2,7 +2,7 @@
package chisel3.internal.firrtl
import chisel3._
-import chisel3.core.SpecifiedDirection
+import chisel3.core.{SpecifiedDirection, EnumType}
import chisel3.experimental._
import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine}
@@ -28,6 +28,7 @@ private class Emitter(circuit: Circuit) {
private def emitType(d: Data, clearDir: Boolean = false): String = d match {
case d: Clock => "Clock"
+ case d: chisel3.core.EnumType => s"UInt${d.width}"
case d: UInt => s"UInt${d.width}"
case d: SInt => s"SInt${d.width}"
case d: FixedPoint => s"Fixed${d.width}${d.binaryPoint}"
diff --git a/src/main/scala/chisel3/package.scala b/src/main/scala/chisel3/package.scala
index b7c39bad..e79a1186 100644
--- a/src/main/scala/chisel3/package.scala
+++ b/src/main/scala/chisel3/package.scala
@@ -420,6 +420,9 @@ package object chisel3 { // scalastyle:ignore package.object.name
val Analog = chisel3.core.Analog
val attach = chisel3.core.attach
+ type ChiselEnum = chisel3.core.EnumFactory
+ val EnumAnnotations = chisel3.core.EnumAnnotations
+
val withClockAndReset = chisel3.core.withClockAndReset
val withClock = chisel3.core.withClock
val withReset = chisel3.core.withReset
diff --git a/src/main/scala/chisel3/util/Conditional.scala b/src/main/scala/chisel3/util/Conditional.scala
index bf2d4268..3630f8ad 100644
--- a/src/main/scala/chisel3/util/Conditional.scala
+++ b/src/main/scala/chisel3/util/Conditional.scala
@@ -24,7 +24,7 @@ object unless { // scalastyle:ignore object.name
* user-facing API.
* @note DO NOT USE. This API is subject to change without warning.
*/
-class SwitchContext[T <: Bits](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) {
+class SwitchContext[T <: Element](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) {
def is(v: Iterable[T])(block: => Unit): SwitchContext[T] = {
if (!v.isEmpty) {
val newLits = v.map { w =>
@@ -60,19 +60,19 @@ object is { // scalastyle:ignore object.name
// TODO: Begin deprecation of non-type-parameterized is statements.
/** Executes `block` if the switch condition is equal to any of the values in `v`.
*/
- def apply(v: Iterable[Bits])(block: => Unit) {
+ def apply(v: Iterable[Element])(block: => Unit) {
require(false, "The 'is' keyword may not be used outside of a switch.")
}
/** Executes `block` if the switch condition is equal to `v`.
*/
- def apply(v: Bits)(block: => Unit) {
+ def apply(v: Element)(block: => Unit) {
require(false, "The 'is' keyword may not be used outside of a switch.")
}
/** Executes `block` if the switch condition is equal to any of the values in the argument list.
*/
- def apply(v: Bits, vr: Bits*)(block: => Unit) {
+ def apply(v: Element, vr: Element*)(block: => Unit) {
require(false, "The 'is' keyword may not be used outside of a switch.")
}
}
@@ -91,7 +91,7 @@ object is { // scalastyle:ignore object.name
* }}}
*/
object switch { // scalastyle:ignore object.name
- def apply[T <: Bits](cond: T)(x: => Unit): Unit = macro impl
+ def apply[T <: Element](cond: T)(x: => Unit): Unit = macro impl
def impl(c: Context)(cond: c.Tree)(x: c.Tree): c.Tree = { import c.universe._
val q"..$body" = x
val res = body.foldLeft(q"""new SwitchContext($cond, None, Set.empty)""") {
diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala
new file mode 100644
index 00000000..98286624
--- /dev/null
+++ b/src/test/scala/chiselTests/StrongEnum.scala
@@ -0,0 +1,430 @@
+// See LICENSE for license details.
+
+package chiselTests
+
+import chisel3._
+import chisel3.experimental.ChiselEnum
+import chisel3.internal.firrtl.UnknownWidth
+import chisel3.util._
+import chisel3.testers.BasicTester
+import org.scalatest.{FreeSpec, Matchers}
+
+object EnumExample extends ChiselEnum {
+ val e0, e1, e2 = Value
+
+ val e100 = Value(100.U)
+ val e101 = Value(101.U)
+
+ val litValues = List(0.U, 1.U, 2.U, 100.U, 101.U)
+}
+
+object OtherEnum extends ChiselEnum {
+ val otherEnum = Value
+}
+
+object NonLiteralEnumType extends ChiselEnum {
+ val nonLit = Value(UInt())
+}
+
+object NonIncreasingEnum extends ChiselEnum {
+ val x = Value(2.U)
+ val y = Value(2.U)
+}
+
+class SimpleConnector(inType: Data, outType: Data) extends Module {
+ val io = IO(new Bundle {
+ val in = Input(inType)
+ val out = Output(outType)
+ })
+
+ io.out := io.in
+}
+
+class CastToUInt extends Module {
+ val io = IO(new Bundle {
+ val in = Input(EnumExample())
+ val out = Output(UInt())
+ })
+
+ io.out := io.in.asUInt()
+}
+
+class CastFromLit(in: UInt) extends Module {
+ val io = IO(new Bundle {
+ val out = Output(EnumExample())
+ val valid = Output(Bool())
+ })
+
+ io.out := EnumExample(in)
+ io.valid := io.out.isValid
+}
+
+class CastFromNonLit extends Module {
+ val io = IO(new Bundle {
+ val in = Input(UInt(EnumExample.getWidth.W))
+ val out = Output(EnumExample())
+ val valid = Output(Bool())
+ })
+
+ io.out := EnumExample(io.in)
+ io.valid := io.out.isValid
+}
+
+class CastFromNonLitWidth(w: Option[Int] = None) extends Module {
+ val width = if (w.isDefined) w.get.W else UnknownWidth()
+
+ override val io = IO(new Bundle {
+ val in = Input(UInt(width))
+ val out = Output(EnumExample())
+ })
+
+ io.out := EnumExample(io.in)
+}
+
+class EnumOps(val xType: ChiselEnum, val yType: ChiselEnum) extends Module {
+ val io = IO(new Bundle {
+ val x = Input(xType())
+ val y = Input(yType())
+
+ val lt = Output(Bool())
+ val le = Output(Bool())
+ val gt = Output(Bool())
+ val ge = Output(Bool())
+ val eq = Output(Bool())
+ val ne = Output(Bool())
+ })
+
+ io.lt := io.x < io.y
+ io.le := io.x <= io.y
+ io.gt := io.x > io.y
+ io.ge := io.x >= io.y
+ io.eq := io.x === io.y
+ io.ne := io.x =/= io.y
+}
+
+object StrongEnumFSM {
+ object State extends ChiselEnum {
+ val sNone, sOne1, sTwo1s = Value
+
+ val correct_annotation_map = Map[String, BigInt]("sNone" -> 0, "sOne1" -> 1, "sTwo1s" -> 2)
+ }
+}
+
+class StrongEnumFSM extends Module {
+ import StrongEnumFSM.State
+ import StrongEnumFSM.State._
+
+ // This FSM detects two 1's one after the other
+ val io = IO(new Bundle {
+ val in = Input(Bool())
+ val out = Output(Bool())
+ val state = Output(State())
+ })
+
+ val state = RegInit(sNone)
+
+ io.out := (state === sTwo1s)
+ io.state := state
+
+ switch (state) {
+ is (sNone) {
+ when (io.in) {
+ state := sOne1
+ }
+ }
+ is (sOne1) {
+ when (io.in) {
+ state := sTwo1s
+ } .otherwise {
+ state := sNone
+ }
+ }
+ is (sTwo1s) {
+ when (!io.in) {
+ state := sNone
+ }
+ }
+ }
+}
+
+class CastToUIntTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new CastToUInt)
+ mod.io.in := enum
+ assert(mod.io.out === lit)
+ }
+ stop()
+}
+
+class CastFromLitTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new CastFromLit(lit))
+ assert(mod.io.out === enum)
+ assert(mod.io.valid === true.B)
+ }
+ stop()
+}
+
+class CastFromNonLitTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new CastFromNonLit)
+ mod.io.in := lit
+ assert(mod.io.out === enum)
+ assert(mod.io.valid === true.B)
+ }
+
+ val invalid_values = (1 until (1 << EnumExample.getWidth)).
+ filter(!EnumExample.litValues.map(_.litValue).contains(_)).
+ map(_.U)
+
+ for (invalid_val <- invalid_values) {
+ val mod = Module(new CastFromNonLit)
+ mod.io.in := invalid_val
+
+ assert(mod.io.valid === false.B)
+ }
+
+ stop()
+}
+
+class CastToInvalidEnumTester extends BasicTester {
+ val invalid_value: UInt = EnumExample.litValues.last + 1.U
+ Module(new CastFromLit(invalid_value))
+}
+
+class EnumOpsTester extends BasicTester {
+ for (x <- EnumExample.all;
+ y <- EnumExample.all) {
+ val mod = Module(new EnumOps(EnumExample, EnumExample))
+ mod.io.x := x
+ mod.io.y := y
+
+ assert(mod.io.lt === (x.asUInt() < y.asUInt()))
+ assert(mod.io.le === (x.asUInt() <= y.asUInt()))
+ assert(mod.io.gt === (x.asUInt() > y.asUInt()))
+ assert(mod.io.ge === (x.asUInt() >= y.asUInt()))
+ assert(mod.io.eq === (x.asUInt() === y.asUInt()))
+ assert(mod.io.ne === (x.asUInt() =/= y.asUInt()))
+ }
+ stop()
+}
+
+class InvalidEnumOpsTester extends BasicTester {
+ val mod = Module(new EnumOps(EnumExample, OtherEnum))
+ mod.io.x := EnumExample.e0
+ mod.io.y := OtherEnum.otherEnum
+}
+
+class IsLitTester extends BasicTester {
+ for (e <- EnumExample.all) {
+ val wire = WireInit(e)
+
+ assert(e.isLit())
+ assert(!wire.isLit())
+ }
+ stop()
+}
+
+class NextTester extends BasicTester {
+ for ((e,n) <- EnumExample.all.zip(EnumExample.litValues.tail :+ EnumExample.litValues.head)) {
+ assert(e.next.litValue == n.litValue)
+ val w = WireInit(e)
+ assert(w.next === EnumExample(n))
+ }
+ stop()
+}
+
+class WidthTester extends BasicTester {
+ assert(EnumExample.getWidth == EnumExample.litValues.last.getWidth)
+ assert(EnumExample.all.forall(_.getWidth == EnumExample.litValues.last.getWidth))
+ assert(EnumExample.all.forall{e =>
+ val w = WireInit(e)
+ w.getWidth == EnumExample.litValues.last.getWidth
+ })
+ stop()
+}
+
+class StrongEnumFSMTester extends BasicTester {
+ import StrongEnumFSM.State
+ import StrongEnumFSM.State._
+
+ val dut = Module(new StrongEnumFSM)
+
+ // Inputs and expected results
+ val inputs: Vec[Bool] = VecInit(false.B, true.B, false.B, true.B, true.B, true.B, false.B, true.B, true.B, false.B)
+ val expected: Vec[Bool] = VecInit(false.B, false.B, false.B, false.B, false.B, true.B, true.B, false.B, false.B, true.B)
+ val expected_state = VecInit(sNone, sNone, sOne1, sNone, sOne1, sTwo1s, sTwo1s, sNone, sOne1, sTwo1s)
+
+ val cntr = Counter(inputs.length)
+ val cycle = cntr.value
+
+ dut.io.in := inputs(cycle)
+ assert(dut.io.out === expected(cycle))
+ assert(dut.io.state === expected_state(cycle))
+
+ when(cntr.inc()) {
+ stop()
+ }
+}
+
+class StrongEnumSpec extends ChiselFlatSpec {
+ import chisel3.internal.ChiselException
+
+ behavior of "Strong enum tester"
+
+ it should "fail to instantiate non-literal enums with the Value function" in {
+ an [ExceptionInInitializerError] should be thrownBy {
+ elaborate(new SimpleConnector(NonLiteralEnumType(), NonLiteralEnumType()))
+ }
+ }
+
+ it should "fail to instantiate non-increasing enums with the Value function" in {
+ an [ExceptionInInitializerError] should be thrownBy {
+ elaborate(new SimpleConnector(NonIncreasingEnum(), NonIncreasingEnum()))
+ }
+ }
+
+ it should "connect enums of the same type" in {
+ elaborate(new SimpleConnector(EnumExample(), EnumExample()))
+ elaborate(new SimpleConnector(EnumExample(), EnumExample.Type()))
+ }
+
+ it should "fail to connect a strong enum to a UInt" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new SimpleConnector(EnumExample(), UInt()))
+ }
+ }
+
+ it should "fail to connect enums of different types" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new SimpleConnector(EnumExample(), OtherEnum()))
+ }
+
+ a [ChiselException] should be thrownBy {
+ elaborate(new SimpleConnector(EnumExample.Type(), OtherEnum.Type()))
+ }
+ }
+
+ it should "cast enums to UInts correctly" in {
+ assertTesterPasses(new CastToUIntTester)
+ }
+
+ it should "cast literal UInts to enums correctly" in {
+ assertTesterPasses(new CastFromLitTester)
+ }
+
+ it should "cast non-literal UInts to enums correctly and detect illegal casts" in {
+ assertTesterPasses(new CastFromNonLitTester)
+ }
+
+ it should "prevent illegal literal casts to enums" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new CastToInvalidEnumTester)
+ }
+ }
+
+ it should "only allow non-literal casts to enums if the width is smaller than or equal to the enum width" in {
+ for (w <- 0 to EnumExample.getWidth)
+ elaborate(new CastFromNonLitWidth(Some(w)))
+
+ a [ChiselException] should be thrownBy {
+ elaborate(new CastFromNonLitWidth)
+ }
+
+ for (w <- (EnumExample.getWidth+1) to (EnumExample.getWidth+100)) {
+ a [ChiselException] should be thrownBy {
+ elaborate(new CastFromNonLitWidth(Some(w)))
+ }
+ }
+ }
+
+ it should "execute enum comparison operations correctly" in {
+ assertTesterPasses(new EnumOpsTester)
+ }
+
+ it should "fail to compare enums of different types" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new InvalidEnumOpsTester)
+ }
+ }
+
+ it should "correctly check whether or not enums are literal" in {
+ assertTesterPasses(new IsLitTester)
+ }
+
+ it should "return the correct next values for enums" in {
+ assertTesterPasses(new NextTester)
+ }
+
+ it should "return the correct widths for enums" in {
+ assertTesterPasses(new WidthTester)
+ }
+
+ it should "maintain Scala-level type-safety" in {
+ def foo(e: EnumExample.Type) = {}
+
+ "foo(EnumExample.e1); foo(EnumExample.e1.next)" should compile
+ "foo(OtherEnum.otherEnum)" shouldNot compile
+ }
+
+ "StrongEnum FSM" should "work" in {
+ assertTesterPasses(new StrongEnumFSMTester)
+ }
+}
+
+class StrongEnumAnnotationSpec extends FreeSpec with Matchers {
+ import chisel3.experimental.EnumAnnotations._
+ import firrtl.annotations.ComponentName
+
+ "Test that strong enums annotate themselves appropriately" in {
+
+ def test() = {
+ Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match {
+ case ChiselExecutionSuccess(Some(circuit), emitted, _) =>
+ val annos = circuit.annotations.map(_.toFirrtl)
+
+ val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a }
+ val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a }
+
+ // Print the annotations out onto the screen
+ println("Enum definitions:")
+ enumDefAnnos.foreach {
+ case EnumDefAnnotation(enumTypeName, definition) => println(s"\t$enumTypeName: $definition")
+ }
+ println("Enum components:")
+ enumCompAnnos.foreach{
+ case EnumComponentAnnotation(target, enumTypeName) => println(s"\t$target => $enumTypeName")
+ }
+
+ // Check that the global annotation is correct
+ enumDefAnnos.exists {
+ case EnumDefAnnotation(name, map) =>
+ name.endsWith("State") &&
+ map.size == StrongEnumFSM.State.correct_annotation_map.size &&
+ map.forall {
+ case (k, v) =>
+ val correctValue = StrongEnumFSM.State.correct_annotation_map(k)
+ correctValue == v
+ }
+ case _ => false
+ } should be(true)
+
+ // Check that the component annotations are correct
+ enumCompAnnos.count {
+ case EnumComponentAnnotation(target, enumName) =>
+ val ComponentName(targetName, _) = target
+ (targetName == "state" && enumName.endsWith("State")) ||
+ (targetName == "io.state" && enumName.endsWith("State"))
+ case _ => false
+ } should be(2)
+
+ case _ =>
+ assert(false)
+ }
+ }
+
+ // We run this test twice, to test for an older bug where only the first circuit would be annotated
+ test()
+ test()
+ }
+}
diff --git a/src/test/scala/cookbook/FSM.scala b/src/test/scala/cookbook/FSM.scala
index 22cf8059..170d110f 100644
--- a/src/test/scala/cookbook/FSM.scala
+++ b/src/test/scala/cookbook/FSM.scala
@@ -4,39 +4,44 @@ package cookbook
import chisel3._
import chisel3.util._
+import chisel3.experimental.ChiselEnum
/* ### How do I create a finite state machine?
*
- * Use Chisel Enum to construct the states and switch & is to construct the FSM
+ * Use Chisel StrongEnum to construct the states and switch & is to construct the FSM
* control logic
*/
+
class DetectTwoOnes extends Module {
val io = IO(new Bundle {
val in = Input(Bool())
val out = Output(Bool())
})
- val sNone :: sOne1 :: sTwo1s :: Nil = Enum(3)
- val state = RegInit(sNone)
+ object State extends ChiselEnum {
+ val sNone, sOne1, sTwo1s = Value
+ }
+
+ val state = RegInit(State.sNone)
- io.out := (state === sTwo1s)
+ io.out := (state === State.sTwo1s)
switch (state) {
- is (sNone) {
+ is (State.sNone) {
when (io.in) {
- state := sOne1
+ state := State.sOne1
}
}
- is (sOne1) {
+ is (State.sOne1) {
when (io.in) {
- state := sTwo1s
+ state := State.sTwo1s
} .otherwise {
- state := sNone
+ state := State.sNone
}
}
- is (sTwo1s) {
+ is (State.sTwo1s) {
when (!io.in) {
- state := sNone
+ state := State.sNone
}
}
}