diff options
| -rw-r--r-- | chiselFrontend/src/main/scala/chisel3/core/Data.scala | 2 | ||||
| -rw-r--r-- | chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala | 104 | ||||
| -rw-r--r-- | src/test/scala/chiselTests/StrongEnum.scala | 213 |
3 files changed, 251 insertions, 68 deletions
diff --git a/chiselFrontend/src/main/scala/chisel3/core/Data.scala b/chiselFrontend/src/main/scala/chisel3/core/Data.scala index 64c84c05..7cf005c0 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/Data.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/Data.scala @@ -266,7 +266,7 @@ abstract class Data extends HasId with NamedComponent with SourceInfoDoc { // sc // perform checks in Chisel, where more informative error messages are possible. private var _binding: Option[Binding] = None // Only valid after node is bound (synthesizable), crashes otherwise - protected def binding: Option[Binding] = _binding + protected[core] def binding: Option[Binding] = _binding protected def binding_=(target: Binding) { if (_binding.isDefined) { throw Binding.RebindingException(s"Attempted reassignment of binding to $this") diff --git a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala index f9414901..3439cf16 100644 --- a/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala +++ b/chiselFrontend/src/main/scala/chisel3/core/StrongEnum.scala @@ -5,7 +5,6 @@ package chisel3.core import scala.language.experimental.macros import scala.reflect.macros.blackbox.Context import scala.collection.mutable - import chisel3.internal.Builder.pushOp import chisel3.internal.firrtl.PrimOp._ import chisel3.internal.firrtl._ @@ -15,6 +14,11 @@ import firrtl.annotations._ object EnumAnnotations { + /** An annotation for strong enum instances that are ''not'' inside of Vecs + * + * @param target the enum instance being annotated + * @param typeName the name of the enum's type (e.g. ''"mypackage.MyEnum"'') + */ case class EnumComponentAnnotation(target: Named, enumTypeName: String) extends SingleTargetAnnotation[Named] { def duplicate(n: Named): EnumComponentAnnotation = this.copy(target = n) } @@ -23,16 +27,50 @@ object EnumAnnotations { def toFirrtl: EnumComponentAnnotation = EnumComponentAnnotation(target.toNamed, enumTypeName) } - case class EnumDefAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation + /** An annotation for Vecs of strong enums. + * + * The ''fields'' parameter deserves special attention, since it may be difficult to understand. Suppose you create a the following Vec: + + * {{{ + * VecInit(new Bundle { + * val e = MyEnum() + * val b = new Bundle { + * val inner_e = MyEnum() + * } + * val v = Vec(3, MyEnum()) + * } + * }}} + * + * Then, the ''fields'' parameter will be: ''Seq(Seq("e"), Seq("b", "inner_e"), Seq("v"))''. Note that for any Vec that doesn't contain Bundles, this field will simply be an empty Seq. + * + * @param target the Vec being annotated + * @param typeName the name of the enum's type (e.g. ''"mypackage.MyEnum"'') + * @param fields a list of all chains of elements leading from the Vec instance to its inner enum fields. + * + */ + case class EnumVecAnnotation(target: Named, typeName: String, fields: Seq[Seq[String]]) extends SingleTargetAnnotation[Named] { + def duplicate(n: Named) = this.copy(target = n) + } + + case class EnumVecChiselAnnotation(target: InstanceId, typeName: String, fields: Seq[Seq[String]]) extends ChiselAnnotation { + override def toFirrtl = EnumVecAnnotation(target.toNamed, typeName, fields) + } + + /** An annotation for enum types (rather than enum ''instances''). + * + * @param typeName the name of the enum's type (e.g. ''"mypackage.MyEnum"'') + * @param definition a map describing which integer values correspond to which enum names + */ + case class EnumDefAnnotation(typeName: String, definition: Map[String, BigInt]) extends NoTargetAnnotation - case class EnumDefChiselAnnotation(enumTypeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation { - override def toFirrtl: Annotation = EnumDefAnnotation(enumTypeName, definition) + case class EnumDefChiselAnnotation(typeName: String, definition: Map[String, BigInt]) extends ChiselAnnotation { + override def toFirrtl: Annotation = EnumDefAnnotation(typeName, definition) } } import EnumAnnotations._ -abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = false) extends Element { +abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolean = true) extends Element { override def toString: String = { val bindingString = litOption match { case Some(value) => factory.nameOfValue(value) match { @@ -119,18 +157,53 @@ abstract class EnumType(private val factory: EnumFactory, selfAnnotating: Boolea lit.bindLitArg(this) } - override def bind(target: Binding, parentDirection: SpecifiedDirection): Unit = { + override private[chisel3] def bind(target: Binding, parentDirection: SpecifiedDirection = SpecifiedDirection.Unspecified): Unit = { super.bind(target, parentDirection) - // If we try to annotate something that is bound to a literal, we get a FIRRTL annotation exception. - // To workaround that, we only annotate enums that are not bound to literals. + // Make sure we only annotate hardware and not literals if (selfAnnotating && litOption.isEmpty) { annotateEnum() } } + // This function conducts a depth-wise search to find all enum-type fields within a vector or bundle (or vector of bundles) + private def enumFields(d: Aggregate): Seq[Seq[String]] = d match { + case v: Vec[_] => v.sample_element match { + case b: Bundle => enumFields (b) + case _ => Seq () + } + case b: Bundle => + b.elements.collect { + case (name, e: EnumType) if this.typeEquivalent(e) => Seq(Seq(name)) + case (name, v: Vec[_]) if this.typeEquivalent(v.sample_element) => Seq(Seq(name)) + case (name, b2: Bundle) => enumFields(b2).map(name +: _) + }.flatten.toSeq + } + + private def outerMostVec(d: Data = this): Option[Vec[_]] = { + val currentVecOpt = d match { + case v: Vec[_] => Some(v) + case _ => None + } + + d.binding match { + case Some(ChildBinding(parent)) => outerMostVec(parent) match { + case outer @ Some(_) => outer + case None => currentVecOpt + } + case _ => currentVecOpt + } + } + private def annotateEnum(): Unit = { - annotate(EnumComponentChiselAnnotation(this, enumTypeName)) + val anno = outerMostVec() match { + case Some(v) => EnumVecChiselAnnotation(v, enumTypeName, enumFields(v)) + case None => EnumComponentChiselAnnotation(this, enumTypeName) + } + + if (!Builder.annotations.contains(anno)) { + annotate(anno) + } if (!Builder.annotations.contains(factory.globalAnnotation)) { annotate(factory.globalAnnotation) @@ -209,19 +282,16 @@ abstract class EnumFactory { def apply(n: UInt)(implicit sourceInfo: SourceInfo, connectionCompileOptions: CompileOptions): Type = { // scalastyle:off line.size.limit if (n.litOption.isDefined) { - val result = enumInstances.find(_.litValue == n.litValue) - - if (result.isEmpty) { - throwException(s"${n.litValue}.U is not a valid value for $enumTypeName") - } else { - result.get + enumInstances.find(_.litValue == n.litValue) match { + case Some(result) => result + case None => throwException(s"${n.litValue} is not a valid value for $enumTypeName") } } else if (!n.isWidthKnown) { throwException(s"Non-literal UInts being cast to $enumTypeName must have a defined width") } else if (n.getWidth > this.getWidth) { throwException(s"The UInt being cast to $enumTypeName is wider than $enumTypeName's width ($getWidth)") } else { - Builder.warning(s"A non-literal UInt is being cast to $enumTypeName. You can check that its value is legal by calling isValid") + Builder.warning(s"Casting non-literal UInt to $enumTypeName. You can check that its value is legal by calling isValid") val glue = Wire(new UnsafeEnum(width)) glue := n @@ -254,7 +324,7 @@ private[core] object EnumMacros { val names = c.enclosingClass.collect { case ValDef(_, name, _, rhs) - if rhs.pos == c.macroApplication.pos => name.decoded + if rhs.pos == c.macroApplication.pos => name.decodedName.toString } if (names.isEmpty) { diff --git a/src/test/scala/chiselTests/StrongEnum.scala b/src/test/scala/chiselTests/StrongEnum.scala index f5c3bc2f..6c87aee3 100644 --- a/src/test/scala/chiselTests/StrongEnum.scala +++ b/src/test/scala/chiselTests/StrongEnum.scala @@ -367,64 +367,177 @@ class StrongEnumSpec extends ChiselFlatSpec { "foo(OtherEnum.otherEnum)" shouldNot compile } + it should "prevent enums from being declared without names" in { + "object UnnamedEnum extends ChiselEnum { Value }" shouldNot compile + } + "StrongEnum FSM" should "work" in { assertTesterPasses(new StrongEnumFSMTester) } } -class StrongEnumAnnotationSpec extends ChiselFlatSpec { +class StrongEnumAnnotator extends Module { + import EnumExample._ + + val io = IO(new Bundle{ + val in = Input(EnumExample()) + val out = Output(EnumExample()) + val other = Output(OtherEnum()) + }) + + class Bund extends Bundle { + val field = EnumExample() + val other = OtherEnum() + val vec = Vec(5, EnumExample()) + val inner_bundle1 = new Bundle { + val x = UInt(4.W) + val y = Vec(3, UInt(4.W)) + val e = EnumExample() + val v = Vec(3, EnumExample()) + } + val inner_bundle2 = new Bundle {} + val inner_bundle3 = new Bundle { + val x = Bool() + } + val inner_bundle4 = new Bundle { + val inner_inner_bundle = new Bundle {} + } + } + + val simple = Wire(EnumExample()) + val vec = VecInit(e0, e1, e2) + val vec_of_vecs = VecInit(VecInit(e0, e1), VecInit(e100, e101)) + + val bund = Wire(new Bund()) + val vec_of_bundles = Wire(Vec(5, new Bund())) + + io.out := e101 + io.other := OtherEnum.otherEnum + simple := e100 + bund := DontCare + vec_of_bundles := DontCare + + // Make sure that dynamically indexing into a Vec of enums will not cause an elaboration error. + // The components created here will not be annotated. + val cycle = RegInit(0.U) + cycle := cycle + 1.U + + val indexed1 = vec_of_vecs(cycle)(cycle) + val indexed2 = vec_of_bundles(cycle) +} + +class StrongEnumAnnotationSpec extends FreeSpec with Matchers { import chisel3.experimental.EnumAnnotations._ - import firrtl.annotations.ComponentName - - ignore should "Test that strong enums annotate themselves appropriately" in { - - // scalastyle:off regex - def test(): Assertion = {// scalastyle:ignore cyclomatic.complexity - Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumFSM) match { - case ChiselExecutionSuccess(Some(circuit), emitted, _) => - val annos = circuit.annotations.map(_.toFirrtl) - - val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } - val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } - - // Print the annotations out onto the screen - println("Enum definitions:") - enumDefAnnos.foreach { - case EnumDefAnnotation(enumTypeName, definition) => println(s"\t$enumTypeName: $definition") - } - println("Enum components:") - enumCompAnnos.foreach{ - case EnumComponentAnnotation(target, enumTypeName) => println(s"\t$target => $enumTypeName") - } - - // Check that the global annotation is correct - enumDefAnnos.exists { - case EnumDefAnnotation(name, map) => - name.endsWith("State") && - map.size == StrongEnumFSM.State.correct_annotation_map.size && - map.forall { - case (k, v) => - val correctValue = StrongEnumFSM.State.correct_annotation_map(k) - correctValue == v - } - case _ => false - } should be(true) - - // Check that the component annotations are correct - enumCompAnnos.count { - case EnumComponentAnnotation(target, enumName) => - val ComponentName(targetName, _) = target - (targetName == "state" && enumName.endsWith("State")) || - (targetName == "io.state" && enumName.endsWith("State")) - case _ => false - } should be(2) - - case _ => - assert(false) - } - // scalastyle:on regex + import firrtl.annotations.{ComponentName, Annotation} + + val enumExampleName = "EnumExample" + val otherEnumName = "OtherEnum" + + case class CorrectDefAnno(typeName: String, definition: Map[String, BigInt]) + case class CorrectCompAnno(targetName: String, typeName: String) + case class CorrectVecAnno(targetName: String, typeName: String, fields: Set[Seq[String]]) + + val correctDefAnnos = Seq( + CorrectDefAnno(otherEnumName, Map("otherEnum" -> 0)), + CorrectDefAnno(enumExampleName, Map("e0" -> 0, "e1" -> 1, "e2" -> 2, "e100" -> 100, "e101" -> 101)) + ) + + val correctCompAnnos = Seq( + CorrectCompAnno("io.other", otherEnumName), + CorrectCompAnno("io.out", enumExampleName), + CorrectCompAnno("io.in", enumExampleName), + CorrectCompAnno("simple", enumExampleName), + CorrectCompAnno("bund.field", enumExampleName), + CorrectCompAnno("bund.other", otherEnumName), + CorrectCompAnno("bund.inner_bundle1.e", enumExampleName) + ) + + val correctVecAnnos = Seq( + CorrectVecAnno("vec", enumExampleName, Set()), + CorrectVecAnno("vec_of_vecs", enumExampleName, Set()), + CorrectVecAnno("vec_of_bundles", enumExampleName, Set(Seq("field"), Seq("vec"), Seq("inner_bundle1", "e"), Seq("inner_bundle1", "v"))), + CorrectVecAnno("vec_of_bundles", otherEnumName, Set(Seq("other"))), + CorrectVecAnno("bund.vec", enumExampleName, Set()), + CorrectVecAnno("bund.inner_bundle1.v", enumExampleName, Set()) + ) + + // scalastyle:off regex + def printAnnos(annos: Seq[Annotation]) { + println("Enum definitions:") + annos.foreach { + case EnumDefAnnotation(enumTypeName, definition) => println(s"\t$enumTypeName: $definition") + case _ => + } + println("Enum components:") + annos.foreach{ + case EnumComponentAnnotation(target, enumTypeName) => println(s"\t$target => $enumTypeName") + case _ => } + println("Enum vecs:") + annos.foreach{ + case EnumVecAnnotation(target, enumTypeName, fields) => println(s"\t$target[$fields] => $enumTypeName") + case _ => + } + } + // scalastyle:on regex + + def isCorrect(anno: EnumDefAnnotation, correct: CorrectDefAnno): Boolean = { + (anno.typeName == correct.typeName || + anno.typeName.endsWith("." + correct.typeName)) && + anno.definition == correct.definition + } + + def isCorrect(anno: EnumComponentAnnotation, correct: CorrectCompAnno): Boolean = { + (anno.target match { + case ComponentName(name, _) => name == correct.targetName + case _ => throw new Exception("Unknown target type in EnumComponentAnnotation") + }) && + (anno.enumTypeName == correct.typeName || anno.enumTypeName.endsWith("." + correct.typeName)) + } + + def isCorrect(anno: EnumVecAnnotation, correct: CorrectVecAnno): Boolean = { + (anno.target match { + case ComponentName(name, _) => name == correct.targetName + case _ => throw new Exception("Unknown target type in EnumVecAnnotation") + }) && + (anno.typeName == correct.typeName || anno.typeName.endsWith("." + correct.typeName)) && + anno.fields.map(_.toSeq).toSet == correct.fields + } + + def allCorrectDefs(annos: Seq[EnumDefAnnotation], corrects: Seq[CorrectDefAnno]): Boolean = { + corrects.forall(c => annos.exists(isCorrect(_, c))) && + correctDefAnnos.length == annos.length + } + + // Because temporary variables might be formed and annotated, we do not check that every component or vector + // annotation is accounted for in the correct results listed above + def allCorrectComps(annos: Seq[EnumComponentAnnotation], corrects: Seq[CorrectCompAnno]): Boolean = + corrects.forall(c => annos.exists(isCorrect(_, c))) + + def allCorrectVecs(annos: Seq[EnumVecAnnotation], corrects: Seq[CorrectVecAnno]): Boolean = + corrects.forall(c => annos.exists(isCorrect(_, c))) + + def test() { + Driver.execute(Array("--target-dir", "test_run_dir"), () => new StrongEnumAnnotator) match { + case ChiselExecutionSuccess(Some(circuit), emitted, _) => + val annos = circuit.annotations.map(_.toFirrtl) + + printAnnos(annos) + + val enumDefAnnos = annos.collect { case a: EnumDefAnnotation => a } + val enumCompAnnos = annos.collect { case a: EnumComponentAnnotation => a } + val enumVecAnnos = annos.collect { case a: EnumVecAnnotation => a } + + allCorrectDefs(enumDefAnnos, correctDefAnnos) should be(true) + allCorrectComps(enumCompAnnos, correctCompAnnos) should be(true) + allCorrectVecs(enumVecAnnos, correctVecAnnos) should be(true) + + case _ => + assert(false) + } + } + "Test that strong enums annotate themselves appropriately" in { // We run this test twice, to test for an older bug where only the first circuit would be annotated test() test() |
