From 600405254c20c14fb3389aa4758ec27dffe992d0 Mon Sep 17 00:00:00 2001 From: Hasan Genc Date: Fri, 12 Oct 2018 12:52:48 -0700 Subject: Strong enums (#892) * Added new strongly-typed enum construct called "StrongEnum". "StrongEnum" will automatically generate annotations that HDL backends can use to mark components as enums Removed "override val width" constructor parameter from "Element" so that classes with variable widths, like the new strong enums, can inherit from it Changed the parameter types of certain functions, such as "switch", "is", and "LitArg.bindLitArg" from "Bits" to "Element", so that they can take the new strong enums as arguments * Added tests for the new strong enums * Changed StrongEnum exception names and made sure in StrongEnum tests that the correct types of exceptions are thrown * Fixed bug where an enum's global annotation would not be set if it was used in multiple circuits Made styling changes to StrongEnum.scala * Reverted accidental changes to the AnnotatingDiamond test * Changed the API for casting non-literal UInts to enums Added an isValid function that checks whether or not enums have valid values Calling getWidth on an enum's companion object now returns a BigInt instead of an Int * Casting a literal to an enum using the StrongEnum.castFromNonLit(n) function is now simply a wrapper for StrongEnum.apply(n) * Fixed compilation bug * * Added "next" method to EnumType * Renamed "castFromNonLit" to "fromBits" * The FSM example in the test/scala/cookbook now uses StrongEnums * * Changed strong enum API, so that users no longer have to declare both a class and a companion object for each strong enum * Strong enums do not have to be static any longer * * Added scope protections to ChiselEnum.Value so that users cannot call it outside of a ChiselEnum definition * Renamed ChiselEnum.Value type to ChiselEnum.Type so that we can give it a companion object just like UInt and Bool do * * Moved strong enums into experimental package * Non-literal UInts can now be cast to enums with apply() rather than fromBits() * Reduced code-duplication by moving some functions from EnumType and Bits to Element --- .../src/main/scala/chisel3/core/Bits.scala | 66 ++-- .../src/main/scala/chisel3/core/Clock.scala | 2 +- .../src/main/scala/chisel3/core/Data.scala | 4 +- .../src/main/scala/chisel3/core/MonoConnect.scala | 6 + .../src/main/scala/chisel3/core/StrongEnum.scala | 248 ++++++++++++ .../main/scala/chisel3/internal/firrtl/IR.scala | 8 +- .../scala/chisel3/internal/firrtl/Converter.scala | 3 +- .../scala/chisel3/internal/firrtl/Emitter.scala | 3 +- src/main/scala/chisel3/package.scala | 3 + src/main/scala/chisel3/util/Conditional.scala | 10 +- src/test/scala/chiselTests/StrongEnum.scala | 430 +++++++++++++++++++++ src/test/scala/cookbook/FSM.scala | 27 +- 12 files changed, 753 insertions(+), 57 deletions(-) create mode 100644 chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala create mode 100644 src/test/scala/chiselTests/StrongEnum.scala diff --git a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala index e9458446..9356a91c 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Bits.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Bits.scala @@ -19,7 +19,11 @@ import chisel3.internal.firrtl.PrimOp._ * * @define coll element */ -abstract class Element(private[chisel3] val width: Width) extends Data { +abstract class Element extends Data { + private[chisel3] final def allElements: Seq[Element] = Seq(this) + def widthKnown: Boolean = width.known + def name: String = getRef.name + private[chisel3] override def bind(target: Binding, parentDirection: SpecifiedDirection) { binding = target val resolvedDirection = SpecifiedDirection.fromParent(parentDirection, specifiedDirection) @@ -30,9 +34,32 @@ abstract class Element(private[chisel3] val width: Width) extends Data { } } - private[chisel3] final def allElements: Seq[Element] = Seq(this) - def widthKnown: Boolean = width.known - def name: String = getRef.name + private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { + // Translate Bundle lit bindings to Element lit bindings + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => Some(ElementLitBinding(litArg)) + case _ => Some(DontCareBinding()) + } + case topBindingOpt => topBindingOpt + } + + private[core] def litArgOption: Option[LitArg] = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => Some(litArg) + case _ => None + } + + override def litOption: Option[BigInt] = litArgOption.map(_.num) + private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) + + // provide bits-specific literal handling functionality here + override private[chisel3] def ref: Arg = topBindingOpt match { + case Some(ElementLitBinding(litArg)) => litArg + case Some(BundleLitBinding(litMap)) => litMap.get(this) match { + case Some(litArg) => litArg + case _ => throwException(s"internal error: DontCare should be caught before getting ref") + } + case _ => super.ref + } private[core] def legacyConnect(that: Data)(implicit sourceInfo: SourceInfo): Unit = { // If the source is a DontCare, generate a DefInvalid for the sink, @@ -69,7 +96,7 @@ private[chisel3] sealed trait ToBoolable extends Element { * @define sumWidth @note The width of the returned $coll is `width of this` + `width of that`. * @define unchangedWidth @note The width of the returned $coll is unchanged, i.e., the `width of this`. */ -sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable { //scalastyle:off number.of.methods +sealed abstract class Bits(private[chisel3] val width: Width) extends Element with ToBoolable { //scalastyle:off number.of.methods // TODO: perhaps make this concrete? // Arguments for: self-checking code (can't do arithmetic on bits) // Arguments against: generates down to a FIRRTL UInt anyways @@ -79,33 +106,6 @@ sealed abstract class Bits(width: Width) extends Element(width) with ToBoolable def cloneType: this.type = cloneTypeWidth(width) - private[core] override def topBindingOpt: Option[TopBinding] = super.topBindingOpt match { - // Translate Bundle lit bindings to Element lit bindings - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => Some(ElementLitBinding(litArg)) - case _ => Some(DontCareBinding()) - } - case topBindingOpt => topBindingOpt - } - - private[core] def litArgOption: Option[LitArg] = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => Some(litArg) - case _ => None - } - - override def litOption: Option[BigInt] = litArgOption.map(_.num) - private[core] def litIsForcedWidth: Option[Boolean] = litArgOption.map(_.forcedWidth) - - // provide bits-specific literal handling functionality here - override private[chisel3] def ref: Arg = topBindingOpt match { - case Some(ElementLitBinding(litArg)) => litArg - case Some(BundleLitBinding(litMap)) => litMap.get(this) match { - case Some(litArg) => litArg - case _ => throwException(s"internal error: DontCare should be caught before getting ref") - } - case _ => super.ref - } - /** Tail operator * * @param n the number of bits to remove @@ -1693,7 +1693,7 @@ object FixedPoint { * * @note This API is experimental and subject to change */ -final class Analog private (width: Width) extends Element(width) { +final class Analog private (private[chisel3] val width: Width) extends Element { require(width.known, "Since Analog is only for use in BlackBoxes, width must be known") private[core] override def typeEquivalent(that: Data): Boolean = diff --git a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala index b728075b..88208d9a 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Clock.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Clock.scala @@ -12,7 +12,7 @@ object Clock { } // TODO: Document this. -sealed class Clock extends Element(Width(1)) { +sealed class Clock(private[chisel3] val width: Width = Width(1)) extends Element { def cloneType: this.type = Clock().asInstanceOf[this.type] private[core] def typeEquivalent(that: Data): Boolean = diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala index 869e22fb..f292d3c6 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala @@ -533,10 +533,12 @@ object WireInit { /** RHS (source) for Invalidate API. * Causes connection logic to emit a DefInvalid when connected to an output port (or wire). */ -object DontCare extends Element(width = UnknownWidth()) { +object DontCare extends Element { // This object should be initialized before we execute any user code that refers to it, // otherwise this "Chisel" object will end up on the UserModule's id list. + private[chisel3] override val width: Width = UnknownWidth() + bind(DontCareBinding(), SpecifiedDirection.Output) override def cloneType = DontCare diff --git a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala index eba24870..c9420ba7 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/MonoConnect.scala @@ -79,6 +79,12 @@ object MonoConnect { elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) case (sink_e: Clock, source_e: Clock) => elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: EnumType, source_e: UnsafeEnum) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: EnumType, source_e: EnumType) if sink_e.typeEquivalent(source_e) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) + case (sink_e: UnsafeEnum, source_e: UInt) => + elemConnect(sourceInfo, connectCompileOptions, sink_e, source_e, context_mod) // Handle Vec case case (sink_v: Vec[Data @unchecked], source_v: Vec[Data @unchecked]) => diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala new file mode 100644 index 00000000..a9f51387 --- /dev/null +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -0,0 +1,248 @@ +// See LICENSE for license details. + +package chisel3.core + +import scala.language.experimental.macros +import scala.reflect.macros.blackbox.Context +import scala.collection.mutable + +import chisel3.internal.Builder.pushOp +import chisel3.internal.firrtl.PrimOp._ +import chisel3.internal.firrtl._ +import chisel3.internal.sourceinfo._ +import chisel3.internal.{Builder, InstanceId, throwException} +import firrtl.annotations._ + + +object EnumAnnotations { + case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] { + def duplicate(n: Named) = this.copy(target = n) + } + + case class EnumComponentChiselAnnotation(target: InstanceId, enumTypeName: String) extends ChiselAnnotation { + def toFirrtl = EnumComponentAnnotation(target.toNamed, enumTypeName) + } + + case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation + + case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation { + override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition) + } +} +import EnumAnnotations._ + + +abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element { + override def cloneType: this.type = factory().asInstanceOf[this.type] + + private[core] def compop(sourceInfo: SourceInfo, op: PrimOp, other: EnumType): Bool = { + requireIsHardware(this, "bits operated on") + requireIsHardware(other, "bits operated on") + + if(!this.typeEquivalent(other)) + throwException(s"Enum types are not equivalent: ${this.enumTypeName}, ${other.enumTypeName}") + + pushOp(DefPrim(sourceInfo, Bool(), op, this.ref, other.ref)) + } + + private[core] override def typeEquivalent(that: Data): Boolean = { + this.getClass == that.getClass && + this.factory == that.asInstanceOf[EnumType].factory + } + + // This isn't actually used anywhere (and it would throw an exception anyway). But it has to be defined since we + // inherit it from Data. + private[core] override def connectFromBits(that: Bits)(implicit sourceInfo: SourceInfo, + compileOptions: CompileOptions): Unit = ??? + + final def === (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def =/= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def < (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def <= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def > (that: EnumType): Bool = macro SourceInfoTransform.thatArg + final def >= (that: EnumType): Bool = macro SourceInfoTransform.thatArg + + def do_=== (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, EqualOp, that) + def do_=/= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, NotEqualOp, that) + def do_< (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessOp, that) + def do_> (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterOp, that) + def do_<= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, LessEqOp, that) + def do_>= (that: EnumType)(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = compop(sourceInfo, GreaterEqOp, that) + + override def do_asUInt(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = + pushOp(DefPrim(sourceInfo, UInt(width), AsUIntOp, ref)) + + protected[chisel3] override def width: Width = factory.width + + def isValid(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): Bool = { + if (litOption.isDefined) { + true.B + } else { + factory.all.map(this === _).reduce(_ || _) + } + } + + def next(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): this.type = { + if (litOption.isDefined) { + val index = factory.all.indexOf(this) + + if (index < factory.all.length-1) + factory.all(index+1).asInstanceOf[this.type] + else + factory.all.head.asInstanceOf[this.type] + } else { + val enums_with_nexts = factory.all zip (factory.all.tail :+ factory.all.head) + val next_enum = SeqUtils.priorityMux(enums_with_nexts.map { case (e,n) => (this === e, n) } ) + next_enum.asInstanceOf[this.type] + } + } + + private[core] def bindToLiteral(num: BigInt, w: Width): Unit = { + val lit = ULit(num, w) + lit.bindLitArg(this) + } + + override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = { + super.bind(target, parentDirection) + + // If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception. + // To workaround that, we only annotate enums that are not bound to literals. + if (selfAnnotating && litOption.isEmpty) { + annotateEnum() + } + } + + private def annotateEnum(): Unit = { + annotate(EnumComponentChiselAnnotation(this, enumTypeName)) + + if (!Builder.annotations.contains(factory.globalAnnotation)) { + annotate(factory.globalAnnotation) + } + } + + protected def enumTypeName: String = factory.enumTypeName + + def toPrintable: Printable = FullName(this) // TODO: Find a better pretty printer +} + + +abstract class EnumFactory { + class Type extends EnumType(this) + object Type { + def apply(): Type = EnumFactory.this.apply() + } + + private var id: BigInt = 0 + private[core] var width: Width = 0.W + + private case class EnumRecord(inst: Type, name: String) + private val enum_records = mutable.ArrayBuffer.empty[EnumRecord] + + private def enumNames = enum_records.map(_.name).toSeq + private def enumValues = enum_records.map(_.inst.litValue()).toSeq + private def enumInstances = enum_records.map(_.inst).toSeq + + private[core] val enumTypeName = getClass.getName.init + + private[core] def globalAnnotation: EnumDefChiselAnnotation = + EnumDefChiselAnnotation(enumTypeName, (enumNames, enumValues).zipped.toMap) + + def getWidth: Int = width.get + + def all: Seq[Type] = enumInstances + + protected def Value: Type = macro EnumMacros.ValImpl + protected def Value(id: UInt): Type = macro EnumMacros.ValCustomImpl + + protected def do_Value(names: Seq[String]): Type = { + val result = new Type + + // We have to use UnknownWidth here, because we don't actually know what the final width will be + result.bindToLiteral(id, UnknownWidth()) + + val result_name = names.find(!enumNames.contains(_)).get + enum_records.append(EnumRecord(result, result_name)) + + width = (1 max id.bitLength).W + id += 1 + + result + } + + protected def do_Value(names: Seq[String], id: UInt): Type = { + // TODO: These throw ExceptionInInitializerError which can be confusing to the user. Get rid of the error, and just + // throw an exception + if (id.litOption.isEmpty) + throwException(s"$enumTypeName defined with a non-literal type") + if (id.litValue() < this.id) + throwException(s"Enums must be strictly increasing: $enumTypeName") + + this.id = id.litValue() + do_Value(names) + } + + def apply(): Type = new Type + + def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = { + if (n.litOption.isDefined) { + val result = enumInstances.find(_.litValue == n.litValue) + + if (result.isEmpty) { + throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") + } else { + result.get + } + } else if (!n.isWidthKnown) { + throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width") + } else if (n.getWidth > this.getWidth) { + throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)") + } else { + Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that its value is legal by calling isValid") + + val glue = Wire(new UnsafeEnum(width)) + glue := n + val result = Wire(new Type) + result := glue + result + } + } +} + + +private[core] object EnumMacros { + def ValImpl(c: Context) : c.Tree = { + import c.universe._ + val names = getNames(c) + q"""this.do_Value(Seq(..$names))""" + } + + def ValCustomImpl(c: Context)(id: c.Expr[UInt]) = { + import c.universe._ + val names = getNames(c) + q"""this.do_Value(Seq(..$names), $id)""" + } + + // Much thanks to Travis Brown for this solution: + // stackoverflow.com/questions/18450203/retrieve-the-name-of-the-value-a-scala-macro-invocation-will-be-assigned-to + def getNames(c: Context): Seq[String] = { + import c.universe._ + + val names = c.enclosingClass.collect { + case ValDef(_, name, _, rhs) + if rhs.pos == c.macroApplication.pos => name.decoded + } + + if (names.isEmpty) + c.abort(c.enclosingPosition, "Value cannot be called without assigning to an enum") + + names + } +} + + +// This is an enum type that can be connected directly to UInts. It is used as a "glue" to cast non-literal UInts +// to enums. +private[chisel3] class UnsafeEnum(override val width: Width) extends EnumType(UnsafeEnum, selfAnnotating = false) { + override def cloneType: this.type = new UnsafeEnum(width).asInstanceOf[this.type] +} +private object UnsafeEnum extends EnumFactory diff --git a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala index b6630f7f..ae8b248a 100644 --- a/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala +++ b/chiselFrontend/src/main/scala/chisel3/internal/firrtl/IR.scala @@ -68,10 +68,10 @@ abstract class LitArg(val num: BigInt, widthArg: Width) extends Arg { private[chisel3] def width: Width = if (forcedWidth) widthArg else Width(minWidth) override def fullName(ctx: Component): String = name // Ensure the node representing this LitArg has a ref to it and a literal binding. - def bindLitArg[T <: Bits](bits: T): T = { - bits.bind(ElementLitBinding(this)) - bits.setRef(this) - bits + def bindLitArg[T <: Element](elem: T): T = { + elem.bind(ElementLitBinding(this)) + elem.setRef(this) + elem } protected def minWidth: Int diff --git a/src/main/scala/chisel3/internal/firrtl/Converter.scala b/src/main/scala/chisel3/internal/firrtl/Converter.scala index 97504aba..181bdfe8 100644 --- a/src/main/scala/chisel3/internal/firrtl/Converter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Converter.scala @@ -2,7 +2,7 @@ package chisel3.internal.firrtl import chisel3._ -import chisel3.core.SpecifiedDirection +import chisel3.core.{SpecifiedDirection, EnumType} import chisel3.experimental._ import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine, SourceInfo} import firrtl.{ir => fir} @@ -211,6 +211,7 @@ private[chisel3] object Converter { def extractType(data: Data, clearDir: Boolean = false): fir.Type = data match { case _: Clock => fir.ClockType + case d: EnumType => fir.UIntType(convert(d.width)) case d: UInt => fir.UIntType(convert(d.width)) case d: SInt => fir.SIntType(convert(d.width)) case d: FixedPoint => fir.FixedType(convert(d.width), convert(d.binaryPoint)) diff --git a/src/main/scala/chisel3/internal/firrtl/Emitter.scala b/src/main/scala/chisel3/internal/firrtl/Emitter.scala index 26ccc09d..ac4bf8e7 100644 --- a/src/main/scala/chisel3/internal/firrtl/Emitter.scala +++ b/src/main/scala/chisel3/internal/firrtl/Emitter.scala @@ -2,7 +2,7 @@ package chisel3.internal.firrtl import chisel3._ -import chisel3.core.SpecifiedDirection +import chisel3.core.{SpecifiedDirection, EnumType} import chisel3.experimental._ import chisel3.internal.sourceinfo.{NoSourceInfo, SourceLine} @@ -28,6 +28,7 @@ private class Emitter(circuit: Circuit) { private def emitType(d: Data, clearDir: Boolean = false): String = d match { case d: Clock => "Clock" + case d: chisel3.core.EnumType => s"UInt${d.width}" case d: UInt => s"UInt${d.width}" case d: SInt => s"SInt${d.width}" case d: FixedPoint => s"Fixed${d.width}${d.binaryPoint}" diff --git a/src/main/scala/chisel3/package.scala b/src/main/scala/chisel3/package.scala index b7c39bad..e79a1186 100644 --- a/src/main/scala/chisel3/package.scala +++ b/src/main/scala/chisel3/package.scala @@ -420,6 +420,9 @@ package object chisel3 { // scalastyle:ignore package.object.name val Analog = chisel3.core.Analog val attach = chisel3.core.attach + type ChiselEnum = chisel3.core.EnumFactory + val EnumAnnotations = chisel3.core.EnumAnnotations + val withClockAndReset = chisel3.core.withClockAndReset val withClock = chisel3.core.withClock val withReset = chisel3.core.withReset diff --git a/src/main/scala/chisel3/util/Conditional.scala b/src/main/scala/chisel3/util/Conditional.scala index bf2d4268..3630f8ad 100644 --- a/src/main/scala/chisel3/util/Conditional.scala +++ b/src/main/scala/chisel3/util/Conditional.scala @@ -24,7 +24,7 @@ object unless { // scalastyle:ignore object.name * user-facing API. * @note DO NOT USE. This API is subject to change without warning. */ -class SwitchContext[T <: Bits](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) { +class SwitchContext[T <: Element](cond: T, whenContext: Option[WhenContext], lits: Set[BigInt]) { def is(v: Iterable[T])(block: => Unit): SwitchContext[T] = { if (!v.isEmpty) { val newLits = v.map { w => @@ -60,19 +60,19 @@ object is { // scalastyle:ignore object.name // TODO: Begin deprecation of non-type-parameterized is statements. /** Executes `block` if the switch condition is equal to any of the values in `v`. */ - def apply(v: Iterable[Bits])(block: => Unit) { + def apply(v: Iterable[Element])(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } /** Executes `block` if the switch condition is equal to `v`. */ - def apply(v: Bits)(block: => Unit) { + def apply(v: Element)(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } /** Executes `block` if the switch condition is equal to any of the values in the argument list. */ - def apply(v: Bits, vr: Bits*)(block: => Unit) { + def apply(v: Element, vr: Element*)(block: => Unit) { require(false, "The 'is' keyword may not be used outside of a switch.") } } @@ -91,7 +91,7 @@ object is { // scalastyle:ignore object.name * }}} */ object switch { // scalastyle:ignore object.name - def apply[T <: Bits](cond: T)(x: => Unit): Unit = macro impl + def apply[T <: Element](cond: T)(x: => Unit): Unit = macro impl def impl(c: Context)(cond: c.Tree)(x: c.Tree): c.Tree = { import c.universe._ val q"..$body" = x val res = body.foldLeft(q"""new SwitchContext($cond, None, Set.empty)""") { diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala new file mode 100644 index 00000000..98286624 --- /dev/null +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -0,0 +1,430 @@ +// See LICENSE for license details. + +package chiselTests + +import chisel3._ +import chisel3.experimental.ChiselEnum +import chisel3.internal.firrtl.UnknownWidth +import chisel3.util._ +import chisel3.testers.BasicTester +import org.scalatest.{FreeSpec, Matchers} + +object EnumExample extends ChiselEnum { + val e0, e1, e2 = Value + + val e100 = Value(100.U) + val e101 = Value(101.U) + + val litValues = List(0.U, 1.U, 2.U, 100.U, 101.U) +} + +object OtherEnum extends ChiselEnum { + val otherEnum = Value +} + +object NonLiteralEnumType extends ChiselEnum { + val nonLit = Value(UInt()) +} + +object NonIncreasingEnum extends ChiselEnum { + val x = Value(2.U) + val y = Value(2.U) +} + +class SimpleConnector(inType: Data, outType: Data) extends Module { + val io = IO(new Bundle { + val in = Input(inType) + val out = Output(outType) + }) + + io.out := io.in +} + +class CastToUInt extends Module { + val io = IO(new Bundle { + val in = Input(EnumExample()) + val out = Output(UInt()) + }) + + io.out := io.in.asUInt() +} + +class CastFromLit(in: UInt) extends Module { + val io = IO(new Bundle { + val out = Output(EnumExample()) + val valid = Output(Bool()) + }) + + io.out := EnumExample(in) + io.valid := io.out.isValid +} + +class CastFromNonLit extends Module { + val io = IO(new Bundle { + val in = Input(UInt(EnumExample.getWidth.W)) + val out = Output(EnumExample()) + val valid = Output(Bool()) + }) + + io.out := EnumExample(io.in) + io.valid := io.out.isValid +} + +class CastFromNonLitWidth(w: Option[Int] = None) extends Module { + val width = if (w.isDefined) w.get.W else UnknownWidth() + + override val io = IO(new Bundle { + val in = Input(UInt(width)) + val out = Output(EnumExample()) + }) + + io.out := EnumExample(io.in) +} + +class EnumOps(val xType: ChiselEnum, val yType: ChiselEnum) extends Module { + val io = IO(new Bundle { + val x = Input(xType()) + val y = Input(yType()) + + val lt = Output(Bool()) + val le = Output(Bool()) + val gt = Output(Bool()) + val ge = Output(Bool()) + val eq = Output(Bool()) + val ne = Output(Bool()) + }) + + io.lt := io.x < io.y + io.le := io.x <= io.y + io.gt := io.x > io.y + io.ge := io.x >= io.y + io.eq := io.x === io.y + io.ne := io.x =/= io.y +} + +object StrongEnumFSM { + object State extends ChiselEnum { + val sNone, sOne1, sTwo1s = Value + + val correct_annotation_map = Map[String, BigInt]("sNone" -> 0, "sOne1" -> 1, "sTwo1s" -> 2) + } +} + +class StrongEnumFSM extends Module { + import StrongEnumFSM.State + import StrongEnumFSM.State._ + + // This FSM detects two 1's one after the other + val io = IO(new Bundle { + val in = Input(Bool()) + val out = Output(Bool()) + val state = Output(State()) + }) + + val state = RegInit(sNone) + + io.out := (state === sTwo1s) + io.state := state + + switch (state) { + is (sNone) { + when (io.in) { + state := sOne1 + } + } + is (sOne1) { + when (io.in) { + state := sTwo1s + } .otherwise { + state := sNone + } + } + is (sTwo1s) { + when (!io.in) { + state := sNone + } + } + } +} + +class CastToUIntTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastToUInt) + mod.io.in := enum + assert(mod.io.out === lit) + } + stop() +} + +class CastFromLitTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastFromLit(lit)) + assert(mod.io.out === enum) + assert(mod.io.valid === true.B) + } + stop() +} + +class CastFromNonLitTester extends BasicTester { + for ((enum,lit) <- EnumExample.all zip EnumExample.litValues) { + val mod = Module(new CastFromNonLit) + mod.io.in := lit + assert(mod.io.out === enum) + assert(mod.io.valid === true.B) + } + + val invalid_values = (1 until (1 << EnumExample.getWidth)). + filter(!EnumExample.litValues.map(_.litValue).contains(_)). + map(_.U) + + for (invalid_val <- invalid_values) { + val mod = Module(new CastFromNonLit) + mod.io.in := invalid_val + + assert(mod.io.valid === false.B) + } + + stop() +} + +class CastToInvalidEnumTester extends BasicTester { + val invalid_value: UInt = EnumExample.litValues.last + 1.U + Module(new CastFromLit(invalid_value)) +} + +class EnumOpsTester extends BasicTester { + for (x <- EnumExample.all; + y <- EnumExample.all) { + val mod = Module(new EnumOps(EnumExample, EnumExample)) + mod.io.x := x + mod.io.y := y + + assert(mod.io.lt === (x.asUInt() < y.asUInt())) + assert(mod.io.le === (x.asUInt() <= y.asUInt())) + assert(mod.io.gt === (x.asUInt() > y.asUInt())) + assert(mod.io.ge === (x.asUInt() >= y.asUInt())) + assert(mod.io.eq === (x.asUInt() === y.asUInt())) + assert(mod.io.ne === (x.asUInt() =/= y.asUInt())) + } + stop() +} + +class InvalidEnumOpsTester extends BasicTester { + val mod = Module(new EnumOps(EnumExample, OtherEnum)) + mod.io.x := EnumExample.e0 + mod.io.y := OtherEnum.otherEnum +} + +class IsLitTester extends BasicTester { + for (e <- EnumExample.all) { + val wire = WireInit(e) + + assert(e.isLit()) + assert(!wire.isLit()) + } + stop() +} + +class NextTester extends BasicTester { + for ((e,n) <- EnumExample.all.zip(EnumExample.litValues.tail :+ EnumExample.litValues.head)) { + assert(e.next.litValue == n.litValue) + val w = WireInit(e) + assert(w.next === EnumExample(n)) + } + stop() +} + +class WidthTester extends BasicTester { + assert(EnumExample.getWidth == EnumExample.litValues.last.getWidth) + assert(EnumExample.all.forall(_.getWidth == EnumExample.litValues.last.getWidth)) + assert(EnumExample.all.forall{e => + val w = WireInit(e) + w.getWidth == EnumExample.litValues.last.getWidth + }) + stop() +} + +class StrongEnumFSMTester extends BasicTester { + import StrongEnumFSM.State + import StrongEnumFSM.State._ + + val dut = Module(new StrongEnumFSM) + + // Inputs and expected results + val inputs: Vec[Bool] = VecInit(false.B, true.B, false.B, true.B, true.B, true.B, false.B, true.B, true.B, false.B) + val expected: Vec[Bool] = VecInit(false.B, false.B, false.B, false.B, false.B, true.B, true.B, false.B, false.B, true.B) + val expected_state = VecInit(sNone, sNone, sOne1, sNone, sOne1, sTwo1s, sTwo1s, sNone, sOne1, sTwo1s) + + val cntr = Counter(inputs.length) + val cycle = cntr.value + + dut.io.in := inputs(cycle) + assert(dut.io.out === expected(cycle)) + assert(dut.io.state === expected_state(cycle)) + + when(cntr.inc()) { + stop() + } +} + +class StrongEnumSpec extends ChiselFlatSpec { + import chisel3.internal.ChiselException + + behavior of "Strong enum tester" + + it should "fail to instantiate non-literal enums with the Value function" in { + an [ExceptionInInitializerError] should be thrownBy { + elaborate(new SimpleConnector(NonLiteralEnumType(), NonLiteralEnumType())) + } + } + + it should "fail to instantiate non-increasing enums with the Value function" in { + an [ExceptionInInitializerError] should be thrownBy { + elaborate(new SimpleConnector(NonIncreasingEnum(), NonIncreasingEnum())) + } + } + + it should "connect enums of the same type" in { + elaborate(new SimpleConnector(EnumExample(), EnumExample())) + elaborate(new SimpleConnector(EnumExample(), EnumExample.Type())) + } + + it should "fail to connect a strong enum to a UInt" in { + a [ChiselException] should be thrownBy { + elaborate(new SimpleConnector(EnumExample(), UInt())) + } + } + + it should "fail to connect enums of different types" in { + a [ChiselException] should be thrownBy { + elaborate(new SimpleConnector(EnumExample(), OtherEnum())) + } + + a [ChiselException] should be thrownBy { + elaborate(new SimpleConnector(EnumExample.Type(), OtherEnum.Type())) + } + } + + it should "cast enums to UInts correctly" in { + assertTesterPasses(new CastToUIntTester) + } + + it should "cast literal UInts to enums correctly" in { + assertTesterPasses(new CastFromLitTester) + } + + it should "cast non-literal UInts to enums correctly and detect illegal casts" in { + assertTesterPasses(new CastFromNonLitTester) + } + + it should "prevent illegal literal casts to enums" in { + a [ChiselException] should be thrownBy { + elaborate(new CastToInvalidEnumTester) + } + } + + it should "only allow non-literal casts to enums if the width is smaller than or equal to the enum width" in { + for (w <- 0 to EnumExample.getWidth) + elaborate(new CastFromNonLitWidth(Some(w))) + + a [ChiselException] should be thrownBy { + elaborate(new CastFromNonLitWidth) + } + + for (w <- (EnumExample.getWidth+1) to (EnumExample.getWidth+100)) { + a [ChiselException] should be thrownBy { + elaborate(new CastFromNonLitWidth(Some(w))) + } + } + } + + it should "execute enum comparison operations correctly" in { + assertTesterPasses(new EnumOpsTester) + } + + it should "fail to compare enums of different types" in { + a [ChiselException] should be thrownBy { + elaborate(new InvalidEnumOpsTester) + } + } + + it should "correctly check whether or not enums are literal" in { + assertTesterPasses(new IsLitTester) + } + + it should "return the correct next values for enums" in { + assertTesterPasses(new NextTester) + } + + it should "return the correct widths for enums" in { + assertTesterPasses(new WidthTester) + } + + it should "maintain Scala-level type-safety" in { + def foo(e: EnumExample.Type) = {} + + "foo(EnumExample.e1); foo(EnumExample.e1.next)" should compile + "foo(OtherEnum.otherEnum)" shouldNot compile + } + + "StrongEnum FSM" should "work" in { + assertTesterPasses(new StrongEnumFSMTester) + } +} + +class StrongEnumAnnotationSpec extends FreeSpec with Matchers { + import chisel3.experimental.EnumAnnotations._ + import firrtl.annotations.ComponentName + + "Test that strong enums annotate themselves appropriately" in { + + def test() = { + Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match { + case ChiselExecutionSuccess(Some(circuit), emitted, _) => + val annos = circuit.annotations.map(_.toFirrtl) + + val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } + val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } + + // Print the annotations out onto the screen + println("Enum definitions:") + enumDefAnnos.foreach { + case EnumDefAnnotation(enumTypeName, definition) => println(s"\t$enumTypeName: $definition") + } + println("Enum components:") + enumCompAnnos.foreach{ + case EnumComponentAnnotation(target, enumTypeName) => println(s"\t$target => $enumTypeName") + } + + // Check that the global annotation is correct + enumDefAnnos.exists { + case EnumDefAnnotation(name, map) => + name.endsWith("State") && + map.size == StrongEnumFSM.State.correct_annotation_map.size && + map.forall { + case (k, v) => + val correctValue = StrongEnumFSM.State.correct_annotation_map(k) + correctValue == v + } + case _ => false + } should be(true) + + // Check that the component annotations are correct + enumCompAnnos.count { + case EnumComponentAnnotation(target, enumName) => + val ComponentName(targetName, _) = target + (targetName == "state" && enumName.endsWith("State")) || + (targetName == "io.state" && enumName.endsWith("State")) + case _ => false + } should be(2) + + case _ => + assert(false) + } + } + + // We run this test twice, to test for an older bug where only the first circuit would be annotated + test() + test() + } +} diff --git a/src/test/scala/cookbook/FSM.scala b/src/test/scala/cookbook/FSM.scala index 22cf8059..170d110f 100644 --- a/src/test/scala/cookbook/FSM.scala +++ b/src/test/scala/cookbook/FSM.scala @@ -4,39 +4,44 @@ package cookbook import chisel3._ import chisel3.util._ +import chisel3.experimental.ChiselEnum /* ### How do I create a finite state machine? * - * Use Chisel Enum to construct the states and switch & is to construct the FSM + * Use Chisel StrongEnum to construct the states and switch & is to construct the FSM * control logic */ + class DetectTwoOnes extends Module { val io = IO(new Bundle { val in = Input(Bool()) val out = Output(Bool()) }) - val sNone :: sOne1 :: sTwo1s :: Nil = Enum(3) - val state = RegInit(sNone) + object State extends ChiselEnum { + val sNone, sOne1, sTwo1s = Value + } + + val state = RegInit(State.sNone) - io.out := (state === sTwo1s) + io.out := (state === State.sTwo1s) switch (state) { - is (sNone) { + is (State.sNone) { when (io.in) { - state := sOne1 + state := State.sOne1 } } - is (sOne1) { + is (State.sOne1) { when (io.in) { - state := sTwo1s + state := State.sTwo1s } .otherwise { - state := sNone + state := State.sNone } } - is (sTwo1s) { + is (State.sTwo1s) { when (!io.in) { - state := sNone + state := State.sNone } } } -- cgit v1.2.3