diff options
| author | Jack Koenig | 2021-04-16 11:41:07 -0700 |
|---|---|---|
| committer | GitHub | 2021-04-16 11:41:07 -0700 |
| commit | bf1cf3d2db49195d031f89594baebcc9f307659e (patch) | |
| tree | 4a13e03f64c49295dc9cb620f76737d25df08419 /src | |
| parent | e9b2946c962f91a04611e32b1a9d03f78e7edf2b (diff) | |
Make InferTypes error on enable conditions > 1-bit wide (#2182)
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckTypes.scala | 33 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CheckSpec.scala | 100 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/LowerTypesSpec.scala | 4 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 4 |
4 files changed, 125 insertions, 16 deletions
diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala index f70db148..50fbfc2e 100644 --- a/src/main/scala/firrtl/passes/CheckTypes.scala +++ b/src/main/scala/firrtl/passes/CheckTypes.scala @@ -55,9 +55,9 @@ object CheckTypes extends Pass { class RegReqClk(info: Info, mname: String, name: String) extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.") class EnNotUInt(info: Info, mname: String) - extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.") + extends PassException(s"$info: [module $mname] Enable must be a 1-bit UIntType typed signal.") class PredNotUInt(info: Info, mname: String) - extends PassException(s"$info: [module $mname] Predicate not a UIntType.") + extends PassException(s"$info: [module $mname] Predicate not a 1-bit UIntType.") class OpNotGround(info: Info, mname: String, op: String) extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.") class OpNotUInt(info: Info, mname: String, op: String, e: String) @@ -81,7 +81,7 @@ object CheckTypes extends Pass { class MuxPassiveTypes(info: Info, mname: String) extends PassException(s"$info: [module $mname] Must mux between passive types.") class MuxCondUInt(info: Info, mname: String) - extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.") + extends PassException(s"$info: [module $mname] A mux condition must be of type 1-bit UInt.") class MuxClock(info: Info, mname: String) extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.") class ValidIfPassiveTypes(info: Info, mname: String) @@ -120,6 +120,15 @@ object CheckTypes extends Pass { case _ => false } + private def legalCondType(tpe: Type): Boolean = tpe match { + // If width is known, must be 1 + case UIntType(IntWidth(w)) => w == 1 + // Unknown width or variable widths (for width inference) are acceptable (checked in later run) + case UIntType(_) => true + // Any other type is not okay + case _ => false + } + private def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { (t1, t2) match { case (ClockType, ClockType) => flip1 == flip2 @@ -165,7 +174,8 @@ object CheckTypes extends Pass { bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default) //;---------------- Helper Functions -------------- - def ut: UIntType = UIntType(UnknownWidth) + private val UIntUnknown = UIntType(UnknownWidth) + def ut: UIntType = UIntUnknown def st: SIntType = SIntType(UnknownWidth) def run(c: Circuit): Circuit = { @@ -332,9 +342,8 @@ object CheckTypes extends Pass { errors.append(new MuxSameType(info, mname, e.tval.tpe.serialize, e.fval.tpe.serialize)) if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info, mname)) - e.cond.tpe match { - case _: UIntType => - case _ => errors.append(new MuxCondUInt(info, mname)) + if (!legalCondType(e.cond.tpe)) { + errors.append(new MuxCondUInt(info, mname)) } case (e: ValidIf) => if (!passive(e.tpe)) @@ -375,7 +384,7 @@ object CheckTypes extends Pass { if (sx.clock.tpe != ClockType) { errors.append(new RegReqClk(info, mname, sx.name)) } - case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) => + case sx: Conditionally if !legalCondType(sx.pred.tpe) => errors.append(new PredNotUInt(info, mname)) case sx: DefNode => sx.value.tpe match { @@ -396,16 +405,16 @@ object CheckTypes extends Pass { } case sx: Stop => if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) - if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) + if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname)) case sx: Print => if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st))) errors.append(new PrintfArgNotGround(info, mname)) if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) - if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) + if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname)) case sx: Verification => if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname)) - if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname)) - if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname)) + if (!legalCondType(sx.pred.tpe)) errors.append(new PredNotUInt(info, mname)) + if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname)) case sx: DefMemory => sx.dataType match { case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name)) diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index a3efc784..547639d6 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -86,6 +86,106 @@ class CheckSpec extends AnyFlatSpec with Matchers { } } + behavior.of("Check Types") + + def runCheckTypes(input: String) = { + val passes = List(InferTypes, CheckTypes) + val wrapped = "circuit test:\n module test:\n " + input.replaceAll("\n", "\n ") + passes.foldLeft(Parser.parse(wrapped)) { case (c, p) => p.run(c) } + } + + it should "disallow mux enable conditions that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input foo : UInt<8> + |input bar : UInt<8> + |node x = mux(en, foo, bar)""".stripMargin + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow when predicates that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input foo : UInt<8> + |input bar : UInt<8> + |output out : UInt<8> + |when en : + | out <= foo + |else: + | out <= bar""".stripMargin + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow print enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input clock : Clock + |printf(clock, en, "Hello World!\\n")""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow stop enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input clock : Clock + |stop(clock, en, 0)""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow verif node predicates that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : $tpe + |input cond : UInt<1> + |input clock : Clock + |assert(clock, en, cond, "Howdy!")""".stripMargin + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + + it should "disallow verif node enables that are not 1-bit UInts (or unknown width)" in { + def mk(tpe: String) = + s"""|input en : UInt<1> + |input cond : $tpe + |input clock : Clock + |assert(clock, en, cond, "Howdy!")""".stripMargin + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) } + a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) } + runCheckTypes(mk("UInt")) + runCheckTypes(mk("UInt<1>")) + } + "Instance loops a -> b -> a" should "be detected" in { val input = """ diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 78d03e68..6e774d18 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -466,10 +466,10 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec { | input a : { b : UInt<1>, flip c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2] | output a_0_b : UInt<1> | input a__0_c_ : { d : UInt<2>, e : UInt<3>}[2] - | a_0_b <= mux(a[UInt(0)].c_1_e, or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e))) + | a_0_b <= mux(bits(a[UInt(0)].c_1_e, 0, 0), or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e))) """.stripMargin val expected = Seq( - "a_0_b <= mux(a___0_c_1_e, or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))" + "a_0_b <= mux(bits(a___0_c_1_e, 0, 0), or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))" ) executeTest(input, expected) diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index d21f80c8..2156e392 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -422,7 +422,7 @@ circuit CustomMemory : circuit CustomMemory : module CustomMemory : input clock : Clock - output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] } smem mem : UInt<8>[2][1024] read mport r = mem[io.raddr], clock @@ -452,7 +452,7 @@ circuit CustomMemory : circuit CustomMemory : module CustomMemory : input clock : Clock - output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] } io.out is invalid |
