aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJack Koenig2021-04-16 11:41:07 -0700
committerGitHub2021-04-16 11:41:07 -0700
commitbf1cf3d2db49195d031f89594baebcc9f307659e (patch)
tree4a13e03f64c49295dc9cb620f76737d25df08419
parente9b2946c962f91a04611e32b1a9d03f78e7edf2b (diff)
Make InferTypes error on enable conditions > 1-bit wide (#2182)
-rw-r--r--src/main/scala/firrtl/passes/CheckTypes.scala33
-rw-r--r--src/test/scala/firrtlTests/CheckSpec.scala100
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala4
4 files changed, 125 insertions, 16 deletions
diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala
index f70db148..50fbfc2e 100644
--- a/src/main/scala/firrtl/passes/CheckTypes.scala
+++ b/src/main/scala/firrtl/passes/CheckTypes.scala
@@ -55,9 +55,9 @@ object CheckTypes extends Pass {
class RegReqClk(info: Info, mname: String, name: String)
extends PassException(s"$info: [module $mname] Register $name requires a clock typed signal.")
class EnNotUInt(info: Info, mname: String)
- extends PassException(s"$info: [module $mname] Enable must be a UIntType typed signal.")
+ extends PassException(s"$info: [module $mname] Enable must be a 1-bit UIntType typed signal.")
class PredNotUInt(info: Info, mname: String)
- extends PassException(s"$info: [module $mname] Predicate not a UIntType.")
+ extends PassException(s"$info: [module $mname] Predicate not a 1-bit UIntType.")
class OpNotGround(info: Info, mname: String, op: String)
extends PassException(s"$info: [module $mname] Primop $op cannot operate on non-ground types.")
class OpNotUInt(info: Info, mname: String, op: String, e: String)
@@ -81,7 +81,7 @@ object CheckTypes extends Pass {
class MuxPassiveTypes(info: Info, mname: String)
extends PassException(s"$info: [module $mname] Must mux between passive types.")
class MuxCondUInt(info: Info, mname: String)
- extends PassException(s"$info: [module $mname] A mux condition must be of type UInt.")
+ extends PassException(s"$info: [module $mname] A mux condition must be of type 1-bit UInt.")
class MuxClock(info: Info, mname: String)
extends PassException(s"$info: [module $mname] Firrtl does not support muxing clocks.")
class ValidIfPassiveTypes(info: Info, mname: String)
@@ -120,6 +120,15 @@ object CheckTypes extends Pass {
case _ => false
}
+ private def legalCondType(tpe: Type): Boolean = tpe match {
+ // If width is known, must be 1
+ case UIntType(IntWidth(w)) => w == 1
+ // Unknown width or variable widths (for width inference) are acceptable (checked in later run)
+ case UIntType(_) => true
+ // Any other type is not okay
+ case _ => false
+ }
+
private def bulk_equals(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = {
(t1, t2) match {
case (ClockType, ClockType) => flip1 == flip2
@@ -165,7 +174,8 @@ object CheckTypes extends Pass {
bulk_equals(con.loc.tpe, con.expr.tpe, Default, Default)
//;---------------- Helper Functions --------------
- def ut: UIntType = UIntType(UnknownWidth)
+ private val UIntUnknown = UIntType(UnknownWidth)
+ def ut: UIntType = UIntUnknown
def st: SIntType = SIntType(UnknownWidth)
def run(c: Circuit): Circuit = {
@@ -332,9 +342,8 @@ object CheckTypes extends Pass {
errors.append(new MuxSameType(info, mname, e.tval.tpe.serialize, e.fval.tpe.serialize))
if (!passive(e.tpe))
errors.append(new MuxPassiveTypes(info, mname))
- e.cond.tpe match {
- case _: UIntType =>
- case _ => errors.append(new MuxCondUInt(info, mname))
+ if (!legalCondType(e.cond.tpe)) {
+ errors.append(new MuxCondUInt(info, mname))
}
case (e: ValidIf) =>
if (!passive(e.tpe))
@@ -375,7 +384,7 @@ object CheckTypes extends Pass {
if (sx.clock.tpe != ClockType) {
errors.append(new RegReqClk(info, mname, sx.name))
}
- case sx: Conditionally if wt(sx.pred.tpe) != wt(ut) =>
+ case sx: Conditionally if !legalCondType(sx.pred.tpe) =>
errors.append(new PredNotUInt(info, mname))
case sx: DefNode =>
sx.value.tpe match {
@@ -396,16 +405,16 @@ object CheckTypes extends Pass {
}
case sx: Stop =>
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
- if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
+ if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname))
case sx: Print =>
if (sx.args.exists(x => wt(x.tpe) != wt(ut) && wt(x.tpe) != wt(st)))
errors.append(new PrintfArgNotGround(info, mname))
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
- if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
+ if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname))
case sx: Verification =>
if (wt(sx.clk.tpe) != wt(ClockType)) errors.append(new ReqClk(info, mname))
- if (wt(sx.pred.tpe) != wt(ut)) errors.append(new PredNotUInt(info, mname))
- if (wt(sx.en.tpe) != wt(ut)) errors.append(new EnNotUInt(info, mname))
+ if (!legalCondType(sx.pred.tpe)) errors.append(new PredNotUInt(info, mname))
+ if (!legalCondType(sx.en.tpe)) errors.append(new EnNotUInt(info, mname))
case sx: DefMemory =>
sx.dataType match {
case AnalogType(w) => errors.append(new IllegalAnalogDeclaration(info, mname, sx.name))
diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala
index a3efc784..547639d6 100644
--- a/src/test/scala/firrtlTests/CheckSpec.scala
+++ b/src/test/scala/firrtlTests/CheckSpec.scala
@@ -86,6 +86,106 @@ class CheckSpec extends AnyFlatSpec with Matchers {
}
}
+ behavior.of("Check Types")
+
+ def runCheckTypes(input: String) = {
+ val passes = List(InferTypes, CheckTypes)
+ val wrapped = "circuit test:\n module test:\n " + input.replaceAll("\n", "\n ")
+ passes.foldLeft(Parser.parse(wrapped)) { case (c, p) => p.run(c) }
+ }
+
+ it should "disallow mux enable conditions that are not 1-bit UInts (or unknown width)" in {
+ def mk(tpe: String) =
+ s"""|input en : $tpe
+ |input foo : UInt<8>
+ |input bar : UInt<8>
+ |node x = mux(en, foo, bar)""".stripMargin
+ a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
+ a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
+ a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
+ a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
+ a[CheckTypes.MuxCondUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
+ runCheckTypes(mk("UInt"))
+ runCheckTypes(mk("UInt<1>"))
+ }
+
+ it should "disallow when predicates that are not 1-bit UInts (or unknown width)" in {
+ def mk(tpe: String) =
+ s"""|input en : $tpe
+ |input foo : UInt<8>
+ |input bar : UInt<8>
+ |output out : UInt<8>
+ |when en :
+ | out <= foo
+ |else:
+ | out <= bar""".stripMargin
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
+ runCheckTypes(mk("UInt"))
+ runCheckTypes(mk("UInt<1>"))
+ }
+
+ it should "disallow print enables that are not 1-bit UInts (or unknown width)" in {
+ def mk(tpe: String) =
+ s"""|input en : $tpe
+ |input clock : Clock
+ |printf(clock, en, "Hello World!\\n")""".stripMargin
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
+ runCheckTypes(mk("UInt"))
+ runCheckTypes(mk("UInt<1>"))
+ }
+
+ it should "disallow stop enables that are not 1-bit UInts (or unknown width)" in {
+ def mk(tpe: String) =
+ s"""|input en : $tpe
+ |input clock : Clock
+ |stop(clock, en, 0)""".stripMargin
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
+ runCheckTypes(mk("UInt"))
+ runCheckTypes(mk("UInt<1>"))
+ }
+
+ it should "disallow verif node predicates that are not 1-bit UInts (or unknown width)" in {
+ def mk(tpe: String) =
+ s"""|input en : $tpe
+ |input cond : UInt<1>
+ |input clock : Clock
+ |assert(clock, en, cond, "Howdy!")""".stripMargin
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
+ a[CheckTypes.PredNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
+ runCheckTypes(mk("UInt"))
+ runCheckTypes(mk("UInt<1>"))
+ }
+
+ it should "disallow verif node enables that are not 1-bit UInts (or unknown width)" in {
+ def mk(tpe: String) =
+ s"""|input en : UInt<1>
+ |input cond : $tpe
+ |input clock : Clock
+ |assert(clock, en, cond, "Howdy!")""".stripMargin
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt<1>")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("SInt")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("UInt<3>")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("Clock")) }
+ a[CheckTypes.EnNotUInt] shouldBe thrownBy { runCheckTypes(mk("AsyncReset")) }
+ runCheckTypes(mk("UInt"))
+ runCheckTypes(mk("UInt<1>"))
+ }
+
"Instance loops a -> b -> a" should "be detected" in {
val input =
"""
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 78d03e68..6e774d18 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -466,10 +466,10 @@ class LowerTypesUniquifySpec extends FirrtlFlatSpec {
| input a : { b : UInt<1>, flip c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2]
| output a_0_b : UInt<1>
| input a__0_c_ : { d : UInt<2>, e : UInt<3>}[2]
- | a_0_b <= mux(a[UInt(0)].c_1_e, or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e)))
+ | a_0_b <= mux(bits(a[UInt(0)].c_1_e, 0, 0), or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e)))
""".stripMargin
val expected = Seq(
- "a_0_b <= mux(a___0_c_1_e, or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))"
+ "a_0_b <= mux(bits(a___0_c_1_e, 0, 0), or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))"
)
executeTest(input, expected)
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index d21f80c8..2156e392 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -422,7 +422,7 @@ circuit CustomMemory :
circuit CustomMemory :
module CustomMemory :
input clock : Clock
- output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] }
smem mem : UInt<8>[2][1024]
read mport r = mem[io.raddr], clock
@@ -452,7 +452,7 @@ circuit CustomMemory :
circuit CustomMemory :
module CustomMemory :
input clock : Clock
- output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<1>[2] }
io.out is invalid