diff options
| author | Albert Magyar | 2019-08-07 15:13:57 -0700 |
|---|---|---|
| committer | mergify[bot] | 2019-08-07 22:13:57 +0000 |
| commit | 23a104d3409385718a960427f1576f508e3f473b (patch) | |
| tree | 1bde53f80a0fe0f0c32eb0a1432f413989fd75b4 /src | |
| parent | 0fe6aad23a4aee50119b9fe2645ba2ff833f65bb (diff) | |
DRY check chirrtl (#1148)
* Avoid redundancy between CheckChirrtl and CheckHighForm, add more checks
* Add test case for illegal Chirrtl memory in HighForm
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckChirrtl.scala | 118 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 49 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CheckSpec.scala | 80 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ChirrtlSpec.scala | 7 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InternalErrorSpec.scala | 46 |
5 files changed, 67 insertions, 233 deletions
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index af44e1e6..08237ab2 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -3,121 +3,7 @@ package firrtl.passes import firrtl.ir._ -import firrtl.Utils._ -import firrtl.traversals.Foreachers._ -object CheckChirrtl extends Pass { - type NameSet = collection.mutable.HashSet[String] - - class NotUniqueException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Reference $name does not have a unique name.") - class InvalidLOCException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Invalid connect to an expression that is not a reference or a WritePort.") - class UndeclaredReferenceException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Reference $name is not declared.") - class MemWithFlipException(info: Info, mname: String, name: String) extends PassException( - s"$info: [module $mname] Memory $name cannot be a bundle type with flips.") - class InvalidAccessException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Invalid access to non-reference.") - class ModuleNotDefinedException(info: Info, mname: String, name: String) extends PassException( - s"$info: Module $name is not defined.") - class NegWidthException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Width cannot be negative or zero.") - class NegVecSizeException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Vector type size cannot be negative.") - class NegMemSizeException(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Memory size cannot be negative or zero.") - class NoTopModuleException(info: Info, name: String) extends PassException( - s"$info: A single module must be named $name.") - - def run (c: Circuit): Circuit = { - val errors = new Errors() - val moduleNames = (c.modules map (_.name)).toSet - - def checkValidLoc(info: Info, mname: String, e: Expression) = e match { - case _: UIntLiteral | _: SIntLiteral | _: DoPrim => - errors append new InvalidLOCException(info, mname) - case _ => // Do Nothing - } - def checkChirrtlW(info: Info, mname: String)(w: Width): Unit = w match { - case w: IntWidth if (w.width < BigInt(0)) => errors.append(new NegWidthException(info, mname)) - case _ => - } - - def checkChirrtlT(info: Info, mname: String)(t: Type): Unit = { - t.foreach(checkChirrtlT(info, mname)) - t match { - case t: VectorType if t.size < 0 => - errors append new NegVecSizeException(info, mname) - t.foreach(checkChirrtlW(info, mname)) - //case FixedType(width, point) => FixedType(checkChirrtlW(width), point) - case _ => t.foreach(checkChirrtlW(info, mname)) - } - } - - def validSubexp(info: Info, mname: String)(e: Expression): Unit = e match { - case _: Reference | _: SubField | _: SubIndex | _: SubAccess | - _: Mux | _: ValidIf => // No error - case _ => errors append new InvalidAccessException(info, mname) - } - - def checkChirrtlE(info: Info, mname: String, names: NameSet)(e: Expression): Unit = { - e match { - case _: DoPrim | _:Mux | _:ValidIf | _: UIntLiteral => - case ex: Reference if !names(ex.name) => - errors append new UndeclaredReferenceException(info, mname, ex.name) - case ex: SubAccess => validSubexp(info, mname)(ex.expr) - case ex => ex.foreach(validSubexp(info, mname)) - } - e.foreach(checkChirrtlW(info, mname)) - e.foreach(checkChirrtlT(info, mname)) - e.foreach(checkChirrtlE(info, mname, names)) - } - - def checkName(info: Info, mname: String, names: NameSet)(name: String): Unit = { - if (names(name)) - errors append new NotUniqueException(info, mname, name) - names += name - } - - def checkChirrtlS(minfo: Info, mname: String, names: NameSet)(s: Statement): Unit = { - val info = get_info(s) match {case NoInfo => minfo case x => x} - s.foreach(checkName(info, mname, names)) - s match { - case sx: DefMemory => - if (hasFlip(sx.dataType)) errors append new MemWithFlipException(info, mname, sx.name) - if (sx.depth <= 0) errors append new NegMemSizeException(info, mname) - case sx: DefInstance if !moduleNames(sx.module) => - errors append new ModuleNotDefinedException(info, mname, sx.module) - case sx: Connect => checkValidLoc(info, mname, sx.loc) - case sx: PartialConnect => checkValidLoc(info, mname, sx.loc) - case _ => // Do Nothing - } - s.foreach(checkChirrtlT(info, mname)) - s.foreach(checkChirrtlE(info, mname, names)) - s.foreach(checkChirrtlS(info, mname, names)) - } - - def checkChirrtlP(mname: String, names: NameSet)(p: Port): Unit = { - if (names(p.name)) - errors append new NotUniqueException(NoInfo, mname, p.name) - names += p.name - p.tpe.foreach(checkChirrtlT(p.info, mname)) - p.tpe.foreach(checkChirrtlW(p.info, mname)) - } - - def checkChirrtlM(m: DefModule): Unit = { - val names = new NameSet - m.foreach(checkChirrtlP(m.name, names)) - m.foreach(checkChirrtlS(m.info, m.name, names)) - } - - c.modules.foreach(checkChirrtlM) - c.modules count (_.name == c.main) match { - case 1 => - case _ => errors append new NoTopModuleException(c.info, c.main) - } - errors.trigger() - c - } +object CheckChirrtl extends Pass with CheckHighFormLike { + def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = None } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index c1415b19..471fe216 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -9,7 +9,7 @@ import firrtl.Utils._ import firrtl.traversals.Foreachers._ import firrtl.WrappedType._ -object CheckHighForm extends Pass { +trait CheckHighFormLike { type NameSet = collection.mutable.HashSet[String] // Custom Exceptions @@ -23,6 +23,8 @@ object CheckHighForm extends Pass { s"$info: [module $mname] Reference $name is not declared.") class PoisonWithFlipException(info: Info, mname: String, name: String) extends PassException( s"$info: [module $mname] Poison $name cannot be a bundle type with flips.") + class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException( + s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") class MemWithFlipException(info: Info, mname: String, name: String) extends PassException( s"$info: [module $mname] Memory $name cannot be a bundle type with flips.") class RegWithFlipException(info: Info, mname: String, name: String) extends PassException( @@ -58,6 +60,9 @@ object CheckHighForm extends Pass { class NonLiteralAsyncResetValueException(info: Info, mname: String, reg: String, init: String) extends PassException( s"$info: [module $mname] AsyncReset Reg '$reg' reset to non-literal '$init'") + // Is Chirrtl allowed for this check? If not, return an error + def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] + def run(c: Circuit): Circuit = { val errors = new Errors() val moduleGraph = new ModuleGraph @@ -84,9 +89,8 @@ object CheckHighForm extends Pass { correctNum(Option(1), 1) case Shl | Shr => correctNum(Option(1), 1) - val amount = e.consts.head.toInt - if (amount < 0) { - errors.append(new NegArgException(info, mname, e.op.toString, amount)) + val amount = e.consts.map(_.toInt).filter(_ < 0).foreach { + c => errors.append(new NegArgException(info, mname, e.op.toString, c)) } case Bits => correctNum(Option(1), 2) @@ -137,6 +141,7 @@ object CheckHighForm extends Pass { def validSubexp(info: Info, mname: String)(e: Expression): Unit = { e match { + case _: Reference | _: SubField | _: SubIndex | _: SubAccess => // No error case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess | _: Mux | _: ValidIf => // No error case _ => errors.append(new InvalidAccessException(info, mname)) } @@ -144,12 +149,15 @@ object CheckHighForm extends Pass { def checkHighFormE(info: Info, mname: String, names: NameSet)(e: Expression): Unit = { e match { + case ex: Reference if !names(ex.name) => + errors.append(new UndeclaredReferenceException(info, mname, ex.name)) case ex: WRef if !names(ex.name) => errors.append(new UndeclaredReferenceException(info, mname, ex.name)) case ex: UIntLiteral if ex.value < 0 => errors.append(new NegUIntException(info, mname)) case ex: DoPrim => checkHighFormPrimop(info, mname, ex) - case _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => + case _: Reference | _: WRef | _: UIntLiteral | _: Mux | _: ValidIf => + case ex: SubAccess => validSubexp(info, mname)(ex.expr) case ex: WSubAccess => validSubexp(info, mname)(ex.expr) case ex => ex foreach validSubexp(info, mname) } @@ -164,6 +172,15 @@ object CheckHighForm extends Pass { names += name } + def checkInstance(info: Info, child: String, parent: String): Unit = { + if (!moduleNames(child)) + errors.append(new ModuleNotDefinedException(info, parent, child)) + // Check to see if a recursive module instantiation has occured + val childToParent = moduleGraph add (parent, child) + if (childToParent.nonEmpty) + errors.append(new InstanceLoop(info, parent, childToParent mkString "->")) + } + def checkHighFormS(minfo: Info, mname: String, names: NameSet)(s: Statement): Unit = { val info = get_info(s) match {case NoInfo => minfo case x => x} s foreach checkName(info, mname, names) @@ -178,16 +195,12 @@ object CheckHighForm extends Pass { errors.append(new MemWithFlipException(info, mname, sx.name)) if (sx.depth <= 0) errors.append(new NegMemSizeException(info, mname)) - case sx: WDefInstance => - if (!moduleNames(sx.module)) - errors.append(new ModuleNotDefinedException(info, mname, sx.module)) - // Check to see if a recursive module instantiation has occured - val childToParent = moduleGraph add (mname, sx.module) - if (childToParent.nonEmpty) - errors.append(new InstanceLoop(info, mname, childToParent mkString "->")) + case sx: DefInstance => checkInstance(info, mname, sx.module) + case sx: WDefInstance => checkInstance(info, mname, sx.module) case sx: Connect => checkValidLoc(info, mname, sx.loc) case sx: PartialConnect => checkValidLoc(info, mname, sx.loc) case sx: Print => checkFstring(info, mname, sx.string, sx.args.length) + case _: CDefMemory | _: CDefMPort => errorOnChirrtl(info, mname, s).foreach { e => errors.append(e) } case sx => // Do Nothing } s foreach checkHighFormT(info, mname) @@ -219,6 +232,18 @@ object CheckHighForm extends Pass { } } +object CheckHighForm extends Pass with CheckHighFormLike { + class IllegalChirrtlMemException(info: Info, mname: String, name: String) extends PassException( + s"$info: [module $mname] Memory $name has not been properly lowered from Chirrtl IR.") + + def errorOnChirrtl(info: Info, mname: String, s: Statement): Option[PassException] = { + val memName = s match { + case cm: CDefMemory => cm.name + case cp: CDefMPort => cp.mem + } + Some(new IllegalChirrtlMemException(info, mname, memName)) + } +} object CheckTypes extends Pass { // Custom Exceptions diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index af16ec03..93bc2cab 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -2,18 +2,32 @@ package firrtlTests -import java.io._ import org.scalatest._ -import org.scalatest.prop._ import firrtl.{Parser, CircuitState, UnknownForm, Transform} import firrtl.ir.Circuit import firrtl.passes.{Pass,ToWorkingIR,CheckHighForm,ResolveKinds,InferTypes,CheckTypes,PassException,InferWidths,CheckWidths,ResolveGenders,CheckGenders} class CheckSpec extends FlatSpec with Matchers { + val defaultPasses = Seq(ToWorkingIR, CheckHighForm) + def checkHighInput(input: String) = { + defaultPasses.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + } + + "CheckHighForm" should "disallow Chirrtl-style memories" in { + val input = + """circuit foo : + | module foo : + | input clock : Clock + | input addr : UInt<2> + | smem mem : UInt<1>[4]""".stripMargin + intercept[CheckHighForm.IllegalChirrtlMemException] { + checkHighInput(input) + } + } + "Memories with flip in the data type" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) val input = """circuit Unit : | module Unit : @@ -23,16 +37,11 @@ class CheckSpec extends FlatSpec with Matchers { | read-latency => 0 | write-latency => 1""".stripMargin intercept[CheckHighForm.MemWithFlipException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + checkHighInput(input) } } "Registers with flip in the type" should "throw an exception" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) val input = """circuit Unit : | module Unit : @@ -42,16 +51,11 @@ class CheckSpec extends FlatSpec with Matchers { | reg r : {a : UInt<32>, flip b : UInt<32>}, clk | out <= in""".stripMargin intercept[CheckHighForm.RegWithFlipException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + checkHighInput(input) } } "Instance loops a -> b -> a" should "be detected" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) val input = """ |circuit Foo : @@ -70,16 +74,11 @@ class CheckSpec extends FlatSpec with Matchers { | b <= foo.b """.stripMargin intercept[CheckHighForm.InstanceLoop] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + checkHighInput(input) } } "Instance loops a -> b -> c -> a" should "be detected" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) val input = """ |circuit Dog : @@ -105,16 +104,11 @@ class CheckSpec extends FlatSpec with Matchers { | b <= foo.b | """.stripMargin intercept[CheckHighForm.InstanceLoop] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + checkHighInput(input) } } "Instance loops a -> a" should "be detected" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) val input = """ |circuit Apple : @@ -126,16 +120,11 @@ class CheckSpec extends FlatSpec with Matchers { | b <= recurse_foo.b | """.stripMargin intercept[CheckHighForm.InstanceLoop] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + checkHighInput(input) } } "Instance loops should not have false positives" should "be detected" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm) val input = """ |circuit Hammer : @@ -158,10 +147,7 @@ class CheckSpec extends FlatSpec with Matchers { | output b : UInt<32> | b <= a | """.stripMargin - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } - + checkHighInput(input) } "Clock Types" should "be connectable" in { @@ -264,10 +250,6 @@ class CheckSpec extends FlatSpec with Matchers { for (op <- List("shl", "shr")) { s"$op by negative amount" should "result in an error" in { - val passes = Seq( - ToWorkingIR, - CheckHighForm - ) val amount = -1 val input = s"""circuit Unit : @@ -276,9 +258,7 @@ class CheckSpec extends FlatSpec with Matchers { | output z: UInt | z <= $op(x, $amount)""".stripMargin val exception = intercept[PassException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + checkHighInput(input) } exception.getMessage should include (s"Primop $op argument $amount < 0") } @@ -292,14 +272,8 @@ class CheckSpec extends FlatSpec with Matchers { | output foo : UInt | foo <= bits(in, 3, 4) | """.stripMargin - val passes = Seq( - ToWorkingIR, - CheckHighForm - ) val exception = intercept[PassException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) - } + checkHighInput(input) } } diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index 9344b861..aeb70c8d 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -2,13 +2,8 @@ package firrtlTests -import java.io._ -import org.scalatest._ -import org.scalatest.prop._ -import firrtl.{Parser, CircuitState, UnknownForm, Transform} -import firrtl.ir.Circuit -import firrtl.passes._ import firrtl._ +import firrtl.passes._ class ChirrtlSpec extends FirrtlFlatSpec { def transforms = Seq( diff --git a/src/test/scala/firrtlTests/InternalErrorSpec.scala b/src/test/scala/firrtlTests/InternalErrorSpec.scala deleted file mode 100644 index 85c9c67d..00000000 --- a/src/test/scala/firrtlTests/InternalErrorSpec.scala +++ /dev/null @@ -1,46 +0,0 @@ -// See LICENSE for license details. - -package firrtlTests - -import java.io.File - -import firrtl._ -import firrtl.Utils.getThrowable -import firrtl.util.BackendCompilationUtilities -import org.scalatest.{FreeSpec, Matchers} - - -class InternalErrorSpec extends FreeSpec with Matchers with BackendCompilationUtilities { - "Unexpected exceptions" - { - val input = - """ - |circuit Dummy : - | module Dummy : - | input clock : Clock - | input x : UInt<1> - | output y : UInt<1> - | output io : { flip in : UInt<16>, out : UInt<16> } - | y <= shr(x, UInt(1)); this should generate an exception in PrimOps.scala:127. - | """.stripMargin - - var exception: Exception = null - "should throw a FIRRTLException" in { - val manager = new ExecutionOptionsManager("test") with HasFirrtlOptions { - commonOptions = CommonOptions(topName = "Dummy") - firrtlOptions = FirrtlExecutionOptions(firrtlSource = Some(input), compilerName = "low") - } - exception = intercept[FIRRTLException] { - firrtl.Driver.execute(manager) - } - } - - "should contain the expected string" in { - assert(exception.getMessage.contains("Internal Error! Please file an issue")) - } - - "should contain the name of the file originating the exception in the stack trace" in { - val first = true - assert(getThrowable(Some(exception), first).getStackTrace exists (_.getFileName.contains("PrimOps.scala"))) - } - } -} |
