summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--chiselFrontend/src/main/scala/chisel3/core/Bits.scala66
-rw-r--r--chiselFrontend/src/main/scala/chisel3/core/Clock.scala2
-rw-r--r--chiselFrontend/src/main/scala/chisel3/core/Data.scala4
-rw-r--r--chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala6
-rw-r--r--chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala248
-rw-r--r--chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala8
-rw-r--r--src/main/scala/chisel3/internal/firrtl/Converter.scala3
-rw-r--r--src/main/scala/chisel3/internal/firrtl/Emitter.scala3
-rw-r--r--src/main/scala/chisel3/package.scala3
-rw-r--r--src/main/scala/chisel3/util/Conditional.scala10
-rw-r--r--src/test/scala/chiselTests/StrongEnum.scala430
-rw-r--r--src/test/scala/cookbook/FSM.scala27
12 files changed, 753 insertions, 57 deletions
diff --git a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala
index e9458446..9356a91c 100644
--- a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala
+++ b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala
@@ -19,7 +19,11 @@ import chisel3.internal.firrtl.PrimOp._
*
* @define coll element
*/
-abstract class Element(private[chisel3] val width: Width) extends Data {
+abstract class Element extends Data {
+ private[chisel3] final def allElements: Seq[Element] = Seq(this)
+ def widthKnown: Boolean = width.known
+ def name: String = getRef.name
+
private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) {
binding = target
val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection)
@@ -30,9 +34,32 @@ abstract class Element(private[chisel3] val width: Width) extends Data {
}
}
- private[chisel3] final def allElements: Seq[Element] = Seq(this)
- def widthKnown: Boolean = width.known
- def name: String = getRef.name
+ private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match {
+ // Translate Bundle lit bindings to Element lit bindings
+ case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
+ case Some(litArg) => Some(ElementLitBinding(litArg))
+ case _ => Some(DontCareBinding())
+ }
+ case topBindingOpt => topBindingOpt
+ }
+
+ private[core] def litArgOption: Option[LitArg] = topBindingOpt match {
+ case Some(ElementLitBinding(litArg)) => Some(litArg)
+ case _ => None
+ }
+
+ override def litOption: Option[BigInt] = litArgOption.map(_.num)
+ private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth)
+
+ // provide bits-specific literal handling functionality here
+ override private[chisel3] def ref: Arg = topBindingOpt match {
+ case Some(ElementLitBinding(litArg)) => litArg
+ case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
+ case Some(litArg) => litArg
+ case _ => throwException(s"internal error: DontCare should be caught before getting ref")
+ }
+ case _ => super.ref
+ }
private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit = {
// If the source is a DontCare, generate a DefInvalid for the sink,
@@ -69,7 +96,7 @@ private[chisel3] sealed trait ToBoolable extends Element {
* @define sumWidth @note The width of the returned $coll is `width of this` + `width of that`.
* @define unchangedWidth @note The width of the returned $coll is unchanged, i.e., the `width of this`.
*/
-sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable { //scalastyle:off number.of.methods
+sealed abstract class Bits(private[chisel3] val width: Width) extends Element with ToBoolable { //scalastyle:off number.of.methods
// TODO: perhaps make this concrete?
// Arguments for: self-checking code (can't do arithmetic on bits)
// Arguments against: generates down to a FIRRTL UInt anyways
@@ -79,33 +106,6 @@ sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable
def cloneType: this.type = cloneTypeWidth(width)
- private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match {
- // Translate Bundle lit bindings to Element lit bindings
- case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
- case Some(litArg) => Some(ElementLitBinding(litArg))
- case _ => Some(DontCareBinding())
- }
- case topBindingOpt => topBindingOpt
- }
-
- private[core] def litArgOption: Option[LitArg] = topBindingOpt match {
- case Some(ElementLitBinding(litArg)) => Some(litArg)
- case _ => None
- }
-
- override def litOption: Option[BigInt] = litArgOption.map(_.num)
- private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth)
-
- // provide bits-specific literal handling functionality here
- override private[chisel3] def ref: Arg = topBindingOpt match {
- case Some(ElementLitBinding(litArg)) => litArg
- case Some(BundleLitBinding(litMap)) => litMap.get(this) match {
- case Some(litArg) => litArg
- case _ => throwException(s"internal error: DontCare should be caught before getting ref")
- }
- case _ => super.ref
- }
-
/** Tail operator
*
* @param n the number of bits to remove
@@ -1693,7 +1693,7 @@ object FixedPoint {
*
* @note This API is experimental and subject to change
*/
-final class Analog private (width: Width) extends Element(width) {
+final class Analog private (private[chisel3] val width: Width) extends Element {
require(width.known, "Since Analog is only for use in BlackBoxes, width must be known")
private[core] override def typeEquivalent(that: Data): Boolean =
diff --git a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala
index b728075b..88208d9a 100644
--- a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala
+++ b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala
@@ -12,7 +12,7 @@ object Clock {
}
// TODO: Document this.
-sealed class Clock extends Element(Width(1)) {
+sealed class Clock(private[chisel3] val width: Width = Width(1)) extends Element {
def cloneType: this.type = Clock().asInstanceOf[this.type]
private[core] def typeEquivalent(that: Data): Boolean =
diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala
index 869e22fb..f292d3c6 100644
--- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala
+++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala
@@ -533,10 +533,12 @@ object WireInit {
/** RHS (source) for Invalidate API.
* Causes connection logic to emit a DefInvalid when connected to an output port (or wire).
*/
-object DontCare extends Element(width = UnknownWidth()) {
+object DontCare extends Element {
// This object should be initialized before we execute any user code that refers to it,
// otherwise this "Chisel" object will end up on the UserModule's id list.
+ private[chisel3] override val width: Width = UnknownWidth()
+
bind(DontCareBinding(), SpecifiedDirection.Output)
override def cloneType = DontCare
diff --git a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala
index eba24870..c9420ba7 100644
--- a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala
+++ b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala
@@ -79,6 +79,12 @@ object MonoConnect {
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
case (sink_e: Clock, source_e: Clock) =>
elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
+ case (sink_e: EnumType, source_e: UnsafeEnum) =>
+ elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
+ case (sink_e: EnumType, source_e: EnumType) if sink_e.typeEquivalent(source_e) =>
+ elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
+ case (sink_e: UnsafeEnum, source_e: UInt) =>
+ elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod)
// Handle Vec case
case (sink_v: Vec[Data @unchecked], source_v: Vec[Data @unchecked]) =>
diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala
new file mode 100644
index 00000000..a9f51387
--- /dev/null
+++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala
@@ -0,0 +1,248 @@
+// See LICENSE for license details.
+
+package chisel3.core
+
+import scala.language.experimental.macros
+import scala.reflect.macros.blackbox.Context
+import scala.collection.mutable
+
+import chisel3.internal.Builder.pushOp
+import chisel3.internal.firrtl.PrimOp._
+import chisel3.internal.firrtl._
+import chisel3.internal.sourceinfo._
+import chisel3.internal.{Builder, InstanceId, throwException}
+import firrtl.annotations._
+
+
+object EnumAnnotations {
+ case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] {
+ def duplicate(n: Named) = this.copy(target = n)
+ }
+
+ case class EnumComponentChiselAnnotation(target: InstanceId, enumTypeName: String) extends ChiselAnnotation {
+ def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName)
+ }
+
+ case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation
+
+ case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation {
+ override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition)
+ }
+}
+import EnumAnnotations._
+
+
+abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element {
+ override def cloneType: this.type = factory().asInstanceOf[this.type]
+
+ private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = {
+ requireIsHardware(this, "bits operated on")
+ requireIsHardware(other, "bits operated on")
+
+ if(!this.typeEquivalent(other))
+ throwException(s"Enum types are not equivalent: ${this.enumTypeName}, ${other.enumTypeName}")
+
+ pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref))
+ }
+
+ private[core] override def typeEquivalent(that: Data): Boolean = {
+ this.getClass == that.getClass &&
+ this.factory == that.asInstanceOf[EnumType].factory
+ }
+
+ // This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we
+ // inherit it from Data.
+ private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo,
+ compileOptions: CompileOptions): Unit = ???
+
+ final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg
+ final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
+ final def < (that: EnumType): Bool = macro SourceInfoTransform.thatArg
+ final def <= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
+ final def > (that: EnumType): Bool = macro SourceInfoTransform.thatArg
+ final def >= (that: EnumType): Bool = macro SourceInfoTransform.thatArg
+
+ def do_=== (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that)
+ def do_=/= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that)
+ def do_< (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that)
+ def do_> (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that)
+ def do_<= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that)
+ def do_>= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that)
+
+ override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt =
+ pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref))
+
+ protected[chisel3] override def width: Width = factory.width
+
+ def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = {
+ if (litOption.isDefined) {
+ true.B
+ } else {
+ factory.all.map(this === _).reduce(_ || _)
+ }
+ }
+
+ def next(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = {
+ if (litOption.isDefined) {
+ val index = factory.all.indexOf(this)
+
+ if (index < factory.all.length-1)
+ factory.all(index+1).asInstanceOf[this.type]
+ else
+ factory.all.head.asInstanceOf[this.type]
+ } else {
+ val enums_with_nexts = factory.all zip (factory.all.tail :+ factory.all.head)
+ val next_enum = SeqUtils.priorityMux(enums_with_nexts.map { case (e,n) => (this === e, n) } )
+ next_enum.asInstanceOf[this.type]
+ }
+ }
+
+ private[core] def bindToLiteral(num: BigInt, w: Width): Unit = {
+ val lit = ULit(num, w)
+ lit.bindLitArg(this)
+ }
+
+ override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = {
+ super.bind(target, parentDirection)
+
+ // If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception.
+ // To workaround that, we only annotate enums that are not bound to literals.
+ if (selfAnnotating && litOption.isEmpty) {
+ annotateEnum()
+ }
+ }
+
+ private def annotateEnum(): Unit = {
+ annotate(EnumComponentChiselAnnotation(this, enumTypeName))
+
+ if (!Builder.annotations.contains(factory.globalAnnotation)) {
+ annotate(factory.globalAnnotation)
+ }
+ }
+
+ protected def enumTypeName: String = factory.enumTypeName
+
+ def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer
+}
+
+
+abstract class EnumFactory {
+ class Type extends EnumType(this)
+ object Type {
+ def apply(): Type = EnumFactory.this.apply()
+ }
+
+ private var id: BigInt = 0
+ private[core] var width: Width = 0.W
+
+ private case class EnumRecord(inst: Type, name: String)
+ private val enum_records = mutable.ArrayBuffer.empty[EnumRecord]
+
+ private def enumNames = enum_records.map(_.name).toSeq
+ private def enumValues = enum_records.map(_.inst.litValue()).toSeq
+ private def enumInstances = enum_records.map(_.inst).toSeq
+
+ private[core] val enumTypeName = getClass.getName.init
+
+ private[core] def globalAnnotation: EnumDefChiselAnnotation =
+ EnumDefChiselAnnotation(enumTypeName, (enumNames, enumValues).zipped.toMap)
+
+ def getWidth: Int = width.get
+
+ def all: Seq[Type] = enumInstances
+
+ protected def Value: Type = macro EnumMacros.ValImpl
+ protected def Value(id: UInt): Type = macro EnumMacros.ValCustomImpl
+
+ protected def do_Value(names: Seq[String]): Type = {
+ val result = new Type
+
+ // We have to use UnknownWidth here, because we don't actually know what the final width will be
+ result.bindToLiteral(id, UnknownWidth())
+
+ val result_name = names.find(!enumNames.contains(_)).get
+ enum_records.append(EnumRecord(result, result_name))
+
+ width = (1 max id.bitLength).W
+ id += 1
+
+ result
+ }
+
+ protected def do_Value(names: Seq[String], id: UInt): Type = {
+ // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just
+ // throw an exception
+ if (id.litOption.isEmpty)
+ throwException(s"$enumTypeName defined with a non-literal type")
+ if (id.litValue() < this.id)
+ throwException(s"Enums must be strictly increasing: $enumTypeName")
+
+ this.id = id.litValue()
+ do_Value(names)
+ }
+
+ def apply(): Type = new Type
+
+ def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = {
+ if (n.litOption.isDefined) {
+ val result = enumInstances.find(_.litValue == n.litValue)
+
+ if (result.isEmpty) {
+ throwException(s"${n.litValue}.U is not a valid value for $enumTypeName")
+ } else {
+ result.get
+ }
+ } else if (!n.isWidthKnown) {
+ throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width")
+ } else if (n.getWidth > this.getWidth) {
+ throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)")
+ } else {
+ Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that its value is legal by calling isValid")
+
+ val glue = Wire(new UnsafeEnum(width))
+ glue := n
+ val result = Wire(new Type)
+ result := glue
+ result
+ }
+ }
+}
+
+
+private[core] object EnumMacros {
+ def ValImpl(c: Context) : c.Tree = {
+ import c.universe._
+ val names = getNames(c)
+ q"""this.do_Value(Seq(..$names))"""
+ }
+
+ def ValCustomImpl(c: Context)(id: c.Expr[UInt]) = {
+ import c.universe._
+ val names = getNames(c)
+ q"""this.do_Value(Seq(..$names), $id)"""
+ }
+
+ // Much thanks to Travis Brown for this solution:
+ // stackoverflow.com/questions/18450203/retrieve-the-name-of-the-value-a-scala-macro-invocation-will-be-assigned-to
+ def getNames(c: Context): Seq[String] = {
+ import c.universe._
+
+ val names = c.enclosingClass.collect {
+ case ValDef(_, name, _, rhs)
+ if rhs.pos == c.macroApplication.pos => name.decoded
+ }
+
+ if (names.isEmpty)
+ c.abort(c.enclosingPosition, "Value cannot be called without assigning to an enum")
+
+ names
+ }
+}
+
+
+// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts
+// to enums.
+private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(UnsafeEnum, selfAnnotating = false) {
+ override def cloneType: this.type = new UnsafeEnum(width).asInstanceOf[this.type]
+}
+private object UnsafeEnum extends EnumFactory
diff --git a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala
index b6630f7f..ae8b248a 100644
--- a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala
+++ b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala
@@ -68,10 +68,10 @@ abstract class LitArg(val num: BigInt, widthArg: Width) extends Arg {
private[chisel3] def width: Width = if (forcedWidth) widthArg else Width(minWidth)
override def fullName(ctx: Component): String = name
// Ensure the node representing this LitArg has a ref to it and a literal binding.
- def bindLitArg[T <: Bits](bits: T): T = {
- bits.bind(ElementLitBinding(this))
- bits.setRef(this)
- bits
+ def bindLitArg[T <: Element](elem: T): T = {
+ elem.bind(ElementLitBinding(this))
+ elem.setRef(this)
+ elem
}
protected def minWidth: Int
diff --git a/src/main/scala/chisel3/internal/firrtl/Converter.scala b/src/main/scala/chisel3/internal/firrtl/Converter.scala
index 97504aba..181bdfe8 100644
--- a/src/main/scala/chisel3/internal/firrtl/Converter.scala
+++ b/src/main/scala/chisel3/internal/firrtl/Converter.scala
@@ -2,7 +2,7 @@
package chisel3.internal.firrtl
import chisel3._
-import chisel3.core.SpecifiedDirection
+import chisel3.core.{SpecifiedDirection, EnumType}
import chisel3.experimental._
import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo}
import firrtl.{ir => fir}
@@ -211,6 +211,7 @@ private[chisel3] object Converter {
def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match {
case _: Clock => fir.ClockType
+ case d: EnumType => fir.UIntType(convert(d.width))
case d: UInt => fir.UIntType(convert(d.width))
case d: SInt => fir.SIntType(convert(d.width))
case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint))
diff --git a/src/main/scala/chisel3/internal/firrtl/Emitter.scala b/src/main/scala/chisel3/internal/firrtl/Emitter.scala
index 26ccc09d..ac4bf8e7 100644
--- a/src/main/scala/chisel3/internal/firrtl/Emitter.scala
+++ b/src/main/scala/chisel3/internal/firrtl/Emitter.scala
@@ -2,7 +2,7 @@
package chisel3.internal.firrtl
import chisel3._
-import chisel3.core.SpecifiedDirection
+import chisel3.core.{SpecifiedDirection, EnumType}
import chisel3.experimental._
import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine}
@@ -28,6 +28,7 @@ private class Emitter(circuit: Circuit) {
private def emitType(d: Data, clearDir: Boolean = false): String = d match {
case d: Clock => "Clock"
+ case d: chisel3.core.EnumType => s"UInt${d.width}"
case d: UInt => s"UInt${d.width}"
case d: SInt => s"SInt${d.width}"
case d: FixedPoint => s"Fixed${d.width}${d.binaryPoint}"
diff --git a/src/main/scala/chisel3/package.scala b/src/main/scala/chisel3/package.scala
index b7c39bad..e79a1186 100644
--- a/src/main/scala/chisel3/package.scala
+++ b/src/main/scala/chisel3/package.scala
@@ -420,6 +420,9 @@ package object chisel3 { // scalastyle:ignore package.object.name
val Analog = chisel3.core.Analog
val attach = chisel3.core.attach
+ type ChiselEnum = chisel3.core.EnumFactory
+ val EnumAnnotations = chisel3.core.EnumAnnotations
+
val withClockAndReset = chisel3.core.withClockAndReset
val withClock = chisel3.core.withClock
val withReset = chisel3.core.withReset
diff --git a/src/main/scala/chisel3/util/Conditional.scala b/src/main/scala/chisel3/util/Conditional.scala
index bf2d4268..3630f8ad 100644
--- a/src/main/scala/chisel3/util/Conditional.scala
+++ b/src/main/scala/chisel3/util/Conditional.scala
@@ -24,7 +24,7 @@ object unless { // scalastyle:ignore object.name
* user-facing API.
* @note DO NOT USE. This API is subject to change without warning.
*/
-class SwitchContext[T <: Bits](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) {
+class SwitchContext[T <: Element](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) {
def is(v: Iterable[T])(block: => Unit): SwitchContext[T] = {
if (!v.isEmpty) {
val newLits = v.map { w =>
@@ -60,19 +60,19 @@ object is { // scalastyle:ignore object.name
// TODO: Begin deprecation of non-type-parameterized is statements.
/** Executes `block` if the switch condition is equal to any of the values in `v`.
*/
- def apply(v: Iterable[Bits])(block: => Unit) {
+ def apply(v: Iterable[Element])(block: => Unit) {
require(false, "The 'is' keyword may not be used outside of a switch.")
}
/** Executes `block` if the switch condition is equal to `v`.
*/
- def apply(v: Bits)(block: => Unit) {
+ def apply(v: Element)(block: => Unit) {
require(false, "The 'is' keyword may not be used outside of a switch.")
}
/** Executes `block` if the switch condition is equal to any of the values in the argument list.
*/
- def apply(v: Bits, vr: Bits*)(block: => Unit) {
+ def apply(v: Element, vr: Element*)(block: => Unit) {
require(false, "The 'is' keyword may not be used outside of a switch.")
}
}
@@ -91,7 +91,7 @@ object is { // scalastyle:ignore object.name
* }}}
*/
object switch { // scalastyle:ignore object.name
- def apply[T <: Bits](cond: T)(x: => Unit): Unit = macro impl
+ def apply[T <: Element](cond: T)(x: => Unit): Unit = macro impl
def impl(c: Context)(cond: c.Tree)(x: c.Tree): c.Tree = { import c.universe._
val q"..$body" = x
val res = body.foldLeft(q"""new SwitchContext($cond, None, Set.empty)""") {
diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala
new file mode 100644
index 00000000..98286624
--- /dev/null
+++ b/src/test/scala/chiselTests/StrongEnum.scala
@@ -0,0 +1,430 @@
+// See LICENSE for license details.
+
+package chiselTests
+
+import chisel3._
+import chisel3.experimental.ChiselEnum
+import chisel3.internal.firrtl.UnknownWidth
+import chisel3.util._
+import chisel3.testers.BasicTester
+import org.scalatest.{FreeSpec, Matchers}
+
+object EnumExample extends ChiselEnum {
+ val e0, e1, e2 = Value
+
+ val e100 = Value(100.U)
+ val e101 = Value(101.U)
+
+ val litValues = List(0.U, 1.U, 2.U, 100.U, 101.U)
+}
+
+object OtherEnum extends ChiselEnum {
+ val otherEnum = Value
+}
+
+object NonLiteralEnumType extends ChiselEnum {
+ val nonLit = Value(UInt())
+}
+
+object NonIncreasingEnum extends ChiselEnum {
+ val x = Value(2.U)
+ val y = Value(2.U)
+}
+
+class SimpleConnector(inType: Data, outType: Data) extends Module {
+ val io = IO(new Bundle {
+ val in = Input(inType)
+ val out = Output(outType)
+ })
+
+ io.out := io.in
+}
+
+class CastToUInt extends Module {
+ val io = IO(new Bundle {
+ val in = Input(EnumExample())
+ val out = Output(UInt())
+ })
+
+ io.out := io.in.asUInt()
+}
+
+class CastFromLit(in: UInt) extends Module {
+ val io = IO(new Bundle {
+ val out = Output(EnumExample())
+ val valid = Output(Bool())
+ })
+
+ io.out := EnumExample(in)
+ io.valid := io.out.isValid
+}
+
+class CastFromNonLit extends Module {
+ val io = IO(new Bundle {
+ val in = Input(UInt(EnumExample.getWidth.W))
+ val out = Output(EnumExample())
+ val valid = Output(Bool())
+ })
+
+ io.out := EnumExample(io.in)
+ io.valid := io.out.isValid
+}
+
+class CastFromNonLitWidth(w: Option[Int] = None) extends Module {
+ val width = if (w.isDefined) w.get.W else UnknownWidth()
+
+ override val io = IO(new Bundle {
+ val in = Input(UInt(width))
+ val out = Output(EnumExample())
+ })
+
+ io.out := EnumExample(io.in)
+}
+
+class EnumOps(val xType: ChiselEnum, val yType: ChiselEnum) extends Module {
+ val io = IO(new Bundle {
+ val x = Input(xType())
+ val y = Input(yType())
+
+ val lt = Output(Bool())
+ val le = Output(Bool())
+ val gt = Output(Bool())
+ val ge = Output(Bool())
+ val eq = Output(Bool())
+ val ne = Output(Bool())
+ })
+
+ io.lt := io.x < io.y
+ io.le := io.x <= io.y
+ io.gt := io.x > io.y
+ io.ge := io.x >= io.y
+ io.eq := io.x === io.y
+ io.ne := io.x =/= io.y
+}
+
+object StrongEnumFSM {
+ object State extends ChiselEnum {
+ val sNone, sOne1, sTwo1s = Value
+
+ val correct_annotation_map = Map[String, BigInt]("sNone" -> 0, "sOne1" -> 1, "sTwo1s" -> 2)
+ }
+}
+
+class StrongEnumFSM extends Module {
+ import StrongEnumFSM.State
+ import StrongEnumFSM.State._
+
+ // This FSM detects two 1's one after the other
+ val io = IO(new Bundle {
+ val in = Input(Bool())
+ val out = Output(Bool())
+ val state = Output(State())
+ })
+
+ val state = RegInit(sNone)
+
+ io.out := (state === sTwo1s)
+ io.state := state
+
+ switch (state) {
+ is (sNone) {
+ when (io.in) {
+ state := sOne1
+ }
+ }
+ is (sOne1) {
+ when (io.in) {
+ state := sTwo1s
+ } .otherwise {
+ state := sNone
+ }
+ }
+ is (sTwo1s) {
+ when (!io.in) {
+ state := sNone
+ }
+ }
+ }
+}
+
+class CastToUIntTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new CastToUInt)
+ mod.io.in := enum
+ assert(mod.io.out === lit)
+ }
+ stop()
+}
+
+class CastFromLitTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new CastFromLit(lit))
+ assert(mod.io.out === enum)
+ assert(mod.io.valid === true.B)
+ }
+ stop()
+}
+
+class CastFromNonLitTester extends BasicTester {
+ for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) {
+ val mod = Module(new CastFromNonLit)
+ mod.io.in := lit
+ assert(mod.io.out === enum)
+ assert(mod.io.valid === true.B)
+ }
+
+ val invalid_values = (1 until (1 << EnumExample.getWidth)).
+ filter(!EnumExample.litValues.map(_.litValue).contains(_)).
+ map(_.U)
+
+ for (invalid_val <- invalid_values) {
+ val mod = Module(new CastFromNonLit)
+ mod.io.in := invalid_val
+
+ assert(mod.io.valid === false.B)
+ }
+
+ stop()
+}
+
+class CastToInvalidEnumTester extends BasicTester {
+ val invalid_value: UInt = EnumExample.litValues.last + 1.U
+ Module(new CastFromLit(invalid_value))
+}
+
+class EnumOpsTester extends BasicTester {
+ for (x <- EnumExample.all;
+ y <- EnumExample.all) {
+ val mod = Module(new EnumOps(EnumExample, EnumExample))
+ mod.io.x := x
+ mod.io.y := y
+
+ assert(mod.io.lt === (x.asUInt() < y.asUInt()))
+ assert(mod.io.le === (x.asUInt() <= y.asUInt()))
+ assert(mod.io.gt === (x.asUInt() > y.asUInt()))
+ assert(mod.io.ge === (x.asUInt() >= y.asUInt()))
+ assert(mod.io.eq === (x.asUInt() === y.asUInt()))
+ assert(mod.io.ne === (x.asUInt() =/= y.asUInt()))
+ }
+ stop()
+}
+
+class InvalidEnumOpsTester extends BasicTester {
+ val mod = Module(new EnumOps(EnumExample, OtherEnum))
+ mod.io.x := EnumExample.e0
+ mod.io.y := OtherEnum.otherEnum
+}
+
+class IsLitTester extends BasicTester {
+ for (e <- EnumExample.all) {
+ val wire = WireInit(e)
+
+ assert(e.isLit())
+ assert(!wire.isLit())
+ }
+ stop()
+}
+
+class NextTester extends BasicTester {
+ for ((e,n) <- EnumExample.all.zip(EnumExample.litValues.tail :+ EnumExample.litValues.head)) {
+ assert(e.next.litValue == n.litValue)
+ val w = WireInit(e)
+ assert(w.next === EnumExample(n))
+ }
+ stop()
+}
+
+class WidthTester extends BasicTester {
+ assert(EnumExample.getWidth == EnumExample.litValues.last.getWidth)
+ assert(EnumExample.all.forall(_.getWidth == EnumExample.litValues.last.getWidth))
+ assert(EnumExample.all.forall{e =>
+ val w = WireInit(e)
+ w.getWidth == EnumExample.litValues.last.getWidth
+ })
+ stop()
+}
+
+class StrongEnumFSMTester extends BasicTester {
+ import StrongEnumFSM.State
+ import StrongEnumFSM.State._
+
+ val dut = Module(new StrongEnumFSM)
+
+ // Inputs and expected results
+ val inputs: Vec[Bool] = VecInit(false.B, true.B, false.B, true.B, true.B, true.B, false.B, true.B, true.B, false.B)
+ val expected: Vec[Bool] = VecInit(false.B, false.B, false.B, false.B, false.B, true.B, true.B, false.B, false.B, true.B)
+ val expected_state = VecInit(sNone, sNone, sOne1, sNone, sOne1, sTwo1s, sTwo1s, sNone, sOne1, sTwo1s)
+
+ val cntr = Counter(inputs.length)
+ val cycle = cntr.value
+
+ dut.io.in := inputs(cycle)
+ assert(dut.io.out === expected(cycle))
+ assert(dut.io.state === expected_state(cycle))
+
+ when(cntr.inc()) {
+ stop()
+ }
+}
+
+class StrongEnumSpec extends ChiselFlatSpec {
+ import chisel3.internal.ChiselException
+
+ behavior of "Strong enum tester"
+
+ it should "fail to instantiate non-literal enums with the Value function" in {
+ an [ExceptionInInitializerError] should be thrownBy {
+ elaborate(new SimpleConnector(NonLiteralEnumType(), NonLiteralEnumType()))
+ }
+ }
+
+ it should "fail to instantiate non-increasing enums with the Value function" in {
+ an [ExceptionInInitializerError] should be thrownBy {
+ elaborate(new SimpleConnector(NonIncreasingEnum(), NonIncreasingEnum()))
+ }
+ }
+
+ it should "connect enums of the same type" in {
+ elaborate(new SimpleConnector(EnumExample(), EnumExample()))
+ elaborate(new SimpleConnector(EnumExample(), EnumExample.Type()))
+ }
+
+ it should "fail to connect a strong enum to a UInt" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new SimpleConnector(EnumExample(), UInt()))
+ }
+ }
+
+ it should "fail to connect enums of different types" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new SimpleConnector(EnumExample(), OtherEnum()))
+ }
+
+ a [ChiselException] should be thrownBy {
+ elaborate(new SimpleConnector(EnumExample.Type(), OtherEnum.Type()))
+ }
+ }
+
+ it should "cast enums to UInts correctly" in {
+ assertTesterPasses(new CastToUIntTester)
+ }
+
+ it should "cast literal UInts to enums correctly" in {
+ assertTesterPasses(new CastFromLitTester)
+ }
+
+ it should "cast non-literal UInts to enums correctly and detect illegal casts" in {
+ assertTesterPasses(new CastFromNonLitTester)
+ }
+
+ it should "prevent illegal literal casts to enums" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new CastToInvalidEnumTester)
+ }
+ }
+
+ it should "only allow non-literal casts to enums if the width is smaller than or equal to the enum width" in {
+ for (w <- 0 to EnumExample.getWidth)
+ elaborate(new CastFromNonLitWidth(Some(w)))
+
+ a [ChiselException] should be thrownBy {
+ elaborate(new CastFromNonLitWidth)
+ }
+
+ for (w <- (EnumExample.getWidth+1) to (EnumExample.getWidth+100)) {
+ a [ChiselException] should be thrownBy {
+ elaborate(new CastFromNonLitWidth(Some(w)))
+ }
+ }
+ }
+
+ it should "execute enum comparison operations correctly" in {
+ assertTesterPasses(new EnumOpsTester)
+ }
+
+ it should "fail to compare enums of different types" in {
+ a [ChiselException] should be thrownBy {
+ elaborate(new InvalidEnumOpsTester)
+ }
+ }
+
+ it should "correctly check whether or not enums are literal" in {
+ assertTesterPasses(new IsLitTester)
+ }
+
+ it should "return the correct next values for enums" in {
+ assertTesterPasses(new NextTester)
+ }
+
+ it should "return the correct widths for enums" in {
+ assertTesterPasses(new WidthTester)
+ }
+
+ it should "maintain Scala-level type-safety" in {
+ def foo(e: EnumExample.Type) = {}
+
+ "foo(EnumExample.e1); foo(EnumExample.e1.next)" should compile
+ "foo(OtherEnum.otherEnum)" shouldNot compile
+ }
+
+ "StrongEnum FSM" should "work" in {
+ assertTesterPasses(new StrongEnumFSMTester)
+ }
+}
+
+class StrongEnumAnnotationSpec extends FreeSpec with Matchers {
+ import chisel3.experimental.EnumAnnotations._
+ import firrtl.annotations.ComponentName
+
+ "Test that strong enums annotate themselves appropriately" in {
+
+ def test() = {
+ Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match {
+ case ChiselExecutionSuccess(Some(circuit), emitted, _) =>
+ val annos = circuit.annotations.map(_.toFirrtl)
+
+ val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a }
+ val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a }
+
+ // Print the annotations out onto the screen
+ println("Enum definitions:")
+ enumDefAnnos.foreach {
+ case EnumDefAnnotation(enumTypeName, definition) => println(s"\t$enumTypeName: $definition")
+ }
+ println("Enum components:")
+ enumCompAnnos.foreach{
+ case EnumComponentAnnotation(target, enumTypeName) => println(s"\t$target => $enumTypeName")
+ }
+
+ // Check that the global annotation is correct
+ enumDefAnnos.exists {
+ case EnumDefAnnotation(name, map) =>
+ name.endsWith("State") &&
+ map.size == StrongEnumFSM.State.correct_annotation_map.size &&
+ map.forall {
+ case (k, v) =>
+ val correctValue = StrongEnumFSM.State.correct_annotation_map(k)
+ correctValue == v
+ }
+ case _ => false
+ } should be(true)
+
+ // Check that the component annotations are correct
+ enumCompAnnos.count {
+ case EnumComponentAnnotation(target, enumName) =>
+ val ComponentName(targetName, _) = target
+ (targetName == "state" && enumName.endsWith("State")) ||
+ (targetName == "io.state" && enumName.endsWith("State"))
+ case _ => false
+ } should be(2)
+
+ case _ =>
+ assert(false)
+ }
+ }
+
+ // We run this test twice, to test for an older bug where only the first circuit would be annotated
+ test()
+ test()
+ }
+}
diff --git a/src/test/scala/cookbook/FSM.scala b/src/test/scala/cookbook/FSM.scala
index 22cf8059..170d110f 100644
--- a/src/test/scala/cookbook/FSM.scala
+++ b/src/test/scala/cookbook/FSM.scala
@@ -4,39 +4,44 @@ package cookbook
import chisel3._
import chisel3.util._
+import chisel3.experimental.ChiselEnum
/* ### How do I create a finite state machine?
*
- * Use Chisel Enum to construct the states and switch & is to construct the FSM
+ * Use Chisel StrongEnum to construct the states and switch & is to construct the FSM
* control logic
*/
+
class DetectTwoOnes extends Module {
val io = IO(new Bundle {
val in = Input(Bool())
val out = Output(Bool())
})
- val sNone :: sOne1 :: sTwo1s :: Nil = Enum(3)
- val state = RegInit(sNone)
+ object State extends ChiselEnum {
+ val sNone, sOne1, sTwo1s = Value
+ }
+
+ val state = RegInit(State.sNone)
- io.out := (state === sTwo1s)
+ io.out := (state === State.sTwo1s)
switch (state) {
- is (sNone) {
+ is (State.sNone) {
when (io.in) {
- state := sOne1
+ state := State.sOne1
}
}
- is (sOne1) {
+ is (State.sOne1) {
when (io.in) {
- state := sTwo1s
+ state := State.sTwo1s
} .otherwise {
- state := sNone
+ state := State.sNone
}
}
- is (sTwo1s) {
+ is (State.sTwo1s) {
when (!io.in) {
- state := sNone
+ state := State.sNone
}
}
}