aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorchick2016-07-25 11:39:54 -0700
committerchick2016-07-25 11:39:54 -0700
commit94b28438d1658d0835122b8c27bbbf3753892475 (patch)
tree7eb7939bc64ac99801dd15e0eb2f783398903fd4 /src
parentab340febdc7a5418da945f9b79624d36e66e26db (diff)
Detects and flags cyclic module loops
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala63
-rw-r--r--src/test/scala/firrtlTests/CheckSpec.scala116
2 files changed, 178 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 94d509ed..5c0bb251 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -30,7 +30,7 @@ package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
// Datastructures
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap,HashSet}
import scala.collection.mutable.ArrayBuffer
import firrtl._
@@ -62,6 +62,60 @@ object CheckHighForm extends Pass with LazyLogging {
class BadPrintfException(x: Char) extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: " + "\"%" + x + "\"")
class BadPrintfTrailingException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: trailing " + "\"%\"")
class BadPrintfIncorrectNumException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: incorrect number of arguments")
+ class InstanceLoop(loop: String) extends PassException(s"${sinfo}: [module ${mname}] Has instance loop $loop")
+
+ /**
+ * Maintains a one to many graph of each modules instantiated child module.
+ * This graph can be searched for a path from a child module back to one of
+ * it's parents. If one is found a recursive loop has happened
+ * The graph is a map between the name of a node to set of names of that nodes children
+ */
+ class ModuleGraph {
+ val nodes = new HashMap[String, HashSet[String]]
+
+ /**
+ * Add a child to a parent node
+ * A parent node is created if it does not already exist
+ *
+ * @param parent module that instantiates another module
+ * @param child module instantiated by parent
+ * @return a list indicating a path from child to parent, empty if no such path
+ */
+ def add(parent: String, child: String): List[String] = {
+ val childSet = nodes.getOrElseUpdate(parent, new HashSet[String])
+ childSet += child
+ pathExists(child, parent, List(child, parent))
+ }
+
+ /**
+ * Starting at the name of a given child explore the tree of all children in depth first manner.
+ * Return the first path (a list of strings) that goes from child to parent,
+ * or an empty list of no such path is found.
+ *
+ * @param child starting name
+ * @param parent name to find in children (recursively)
+ * @param path
+ * @return
+ */
+ def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = {
+ nodes.get(child) match {
+ case Some(children) =>
+ if(children.contains(parent)) {
+ parent :: path
+ }
+ else {
+ children.foreach { grandchild =>
+ val newPath = pathExists(grandchild, parent, grandchild :: path)
+ if(newPath.nonEmpty) {
+ return newPath
+ }
+ }
+ Nil
+ }
+ case _ => Nil
+ }
+ }
+ }
// Utility functions
def hasFlip(t: Type): Boolean = {
@@ -88,6 +142,8 @@ object CheckHighForm extends Pass with LazyLogging {
private var sinfo: Info = NoInfo
def run (c:Circuit): Circuit = {
val errors = new Errors()
+ val moduleGraph = new ModuleGraph
+
def checkHighFormPrimop(e: DoPrim) = {
def correctNum(ne: Option[Int], nc: Int) = {
ne match {
@@ -217,6 +273,11 @@ object CheckHighForm extends Pass with LazyLogging {
case s: WDefInstance => {
if (!c.modules.map(_.name).contains(s.module))
errors.append(new ModuleNotDefinedException(s.module))
+ // Check to see if a recursive module instantiation has occured
+ val childToParent = moduleGraph.add(mname, s.module)
+ if(childToParent.nonEmpty) {
+ errors.append(new InstanceLoop(childToParent.mkString("->")))
+ }
}
case s: Connect => checkValidLoc(s.loc)
case s: PartialConnect => checkValidLoc(s.loc)
diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala
index 69645ddc..65873540 100644
--- a/src/test/scala/firrtlTests/CheckSpec.scala
+++ b/src/test/scala/firrtlTests/CheckSpec.scala
@@ -26,4 +26,120 @@ class CheckSpec extends FlatSpec with Matchers {
}
}
}
+ "Instance loops a -> b -> a" should "be detected" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm)
+ val input =
+ """
+ |circuit Foo :
+ | module Foo :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst bar of Bar
+ | bar.a <= a
+ | b <= bar.b
+ |
+ | module Bar :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst foo of Foo
+ | foo.a <= a
+ | b <= foo.b
+ """.stripMargin
+ intercept[CheckHighForm.InstanceLoop] {
+ passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ }
+ }
+
+ "Instance loops a -> b -> c -> a" should "be detected" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm)
+ val input =
+ """
+ |circuit Dog :
+ | module Dog :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst bar of Cat
+ | bar.a <= a
+ | b <= bar.b
+ |
+ | module Cat :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst ik of Ik
+ | ik.a <= a
+ | b <= ik.b
+ |
+ | module Ik :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst foo of Dog
+ | foo.a <= a
+ | b <= foo.b
+ | """.stripMargin
+ intercept[CheckHighForm.InstanceLoop] {
+ passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ }
+ }
+
+ "Instance loops a -> a" should "be detected" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm)
+ val input =
+ """
+ |circuit Apple :
+ | module Apple :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst recurse_foo of Apple
+ | recurse_foo.a <= a
+ | b <= recurse_foo.b
+ | """.stripMargin
+ intercept[CheckHighForm.InstanceLoop] {
+ passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ }
+ }
+
+ "Instance loops should not have false positives" should "be detected" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm)
+ val input =
+ """
+ |circuit Hammer :
+ | module Hammer :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst bar of Chisel
+ | bar.a <= a
+ | b <= bar.b
+ |
+ | module Chisel :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | inst ik of Saw
+ | ik.a <= a
+ | b <= ik.b
+ |
+ | module Saw :
+ | input a : UInt<32>
+ | output b : UInt<32>
+ | b <= a
+ | """.stripMargin
+ passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+
+ }
+
}