From bf1cf3d2db49195d031f89594baebcc9f307659e Mon Sep 17 00:00:00 2001 From: Jack Koenig Date: Fri, 16 Apr 2021 11:41:07 -0700 Subject: Make InferTypes error on enable conditions > 1-bit wide (#2182) --- src/test/scala/firrtlTests/CheckSpec.scala | 100 +++++++++++++++++++++++ src/test/scala/firrtlTests/LowerTypesSpec.scala | 4 +- src/test/scala/firrtlTests/ReplSeqMemTests.scala | 4 +- 3 files changed, 104 insertions(+), 4 deletions(-) (limited to 'src/test') diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index a3efc784..547639d6 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -86,6 +86,106 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } + behavior.of("Check Types") + + def runCheckTypes(input: String) = { + val passes = List(InferTypes, CheckTypes) + val wrapped = "circuit test:\n module test:\n " + input.replaceAll("\n", "\n ") + passes.foldLeft(Parser.parse(wrapped)) { case (c, p) => p.run(c) } + } + + it should "disallow mux enable conditions that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input foo : UInt<8> + |input bar : UInt<8> + |node x = mux(en, foo, bar)""".stripMargin + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow when predicates that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input foo : UInt<8> + |input bar : UInt<8> + |output out : UInt<8> + |when en : + | out <= foo + |else: + | out <= bar""".stripMargin + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow print enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input clock : Clock + |printf(clock, en, "Hello World!\\n")""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow stop enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input clock : Clock + |stop(clock, en, 0)""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow verif node predicates that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input cond : UInt<1> + |input clock : Clock + |assert(clock, en, cond, "Howdy!")""".stripMargin + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow verif node enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : UInt<1> + |input cond : $tpe + |input clock : Clock + |assert(clock, en, cond, "Howdy!")""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + "Instance loops a -> b -> a" should "be detected" in { val input = """ diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 78d03e68..6e774d18 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -466,10 +466,10 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | input a : { b : UInt<1>, flip c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2] | output a_0_b : UInt<1> | input a__0_c_ : { d : UInt<2>, e : UInt<3>}[2] - | a_0_b <= mux(a[UInt(0)].c_1_e, or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e))) + | a_0_b <= mux(bits(a[UInt(0)].c_1_e, 0, 0), or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e))) """.stripMargin val expected = Seq( - "a_0_b <= mux(a___0_c_1_e, or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))" + "a_0_b <= mux(bits(a___0_c_1_e, 0, 0), or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))" ) executeTest(input, expected) diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index d21f80c8..2156e392 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -422,7 +422,7 @@ circuit CustomMemory : circuit CustomMemory : module CustomMemory : input clock : Clock - output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] } smem mem : UInt<8>[2][1024] read mport r = mem[io.raddr], clock @@ -452,7 +452,7 @@ circuit CustomMemory : circuit CustomMemory : module CustomMemory : input clock : Clock - output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] } io.out is invalid -- cgit v1.2.3