diff options
| author | chick | 2016-07-25 11:39:54 -0700 |
|---|---|---|
| committer | chick | 2016-07-25 11:39:54 -0700 |
| commit | 94b28438d1658d0835122b8c27bbbf3753892475 (patch) | |
| tree | 7eb7939bc64ac99801dd15e0eb2f783398903fd4 /src | |
| parent | ab340febdc7a5418da945f9b79624d36e66e26db (diff) | |
Detects and flags cyclic module loops
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 63 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CheckSpec.scala | 116 |
2 files changed, 178 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 94d509ed..5c0bb251 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -30,7 +30,7 @@ package firrtl.passes import com.typesafe.scalalogging.LazyLogging // Datastructures -import scala.collection.mutable.HashMap +import scala.collection.mutable.{HashMap,HashSet} import scala.collection.mutable.ArrayBuffer import firrtl._ @@ -62,6 +62,60 @@ object CheckHighForm extends Pass with LazyLogging { class BadPrintfException(x: Char) extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: " + "\"%" + x + "\"") class BadPrintfTrailingException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: trailing " + "\"%\"") class BadPrintfIncorrectNumException extends PassException(s"${sinfo}: [module ${mname}] Bad printf format: incorrect number of arguments") + class InstanceLoop(loop: String) extends PassException(s"${sinfo}: [module ${mname}] Has instance loop $loop") + + /** + * Maintains a one to many graph of each modules instantiated child module. + * This graph can be searched for a path from a child module back to one of + * it's parents. If one is found a recursive loop has happened + * The graph is a map between the name of a node to set of names of that nodes children + */ + class ModuleGraph { + val nodes = new HashMap[String, HashSet[String]] + + /** + * Add a child to a parent node + * A parent node is created if it does not already exist + * + * @param parent module that instantiates another module + * @param child module instantiated by parent + * @return a list indicating a path from child to parent, empty if no such path + */ + def add(parent: String, child: String): List[String] = { + val childSet = nodes.getOrElseUpdate(parent, new HashSet[String]) + childSet += child + pathExists(child, parent, List(child, parent)) + } + + /** + * Starting at the name of a given child explore the tree of all children in depth first manner. + * Return the first path (a list of strings) that goes from child to parent, + * or an empty list of no such path is found. + * + * @param child starting name + * @param parent name to find in children (recursively) + * @param path + * @return + */ + def pathExists(child: String, parent: String, path: List[String] = Nil): List[String] = { + nodes.get(child) match { + case Some(children) => + if(children.contains(parent)) { + parent :: path + } + else { + children.foreach { grandchild => + val newPath = pathExists(grandchild, parent, grandchild :: path) + if(newPath.nonEmpty) { + return newPath + } + } + Nil + } + case _ => Nil + } + } + } // Utility functions def hasFlip(t: Type): Boolean = { @@ -88,6 +142,8 @@ object CheckHighForm extends Pass with LazyLogging { private var sinfo: Info = NoInfo def run (c:Circuit): Circuit = { val errors = new Errors() + val moduleGraph = new ModuleGraph + def checkHighFormPrimop(e: DoPrim) = { def correctNum(ne: Option[Int], nc: Int) = { ne match { @@ -217,6 +273,11 @@ object CheckHighForm extends Pass with LazyLogging { case s: WDefInstance => { if (!c.modules.map(_.name).contains(s.module)) errors.append(new ModuleNotDefinedException(s.module)) + // Check to see if a recursive module instantiation has occured + val childToParent = moduleGraph.add(mname, s.module) + if(childToParent.nonEmpty) { + errors.append(new InstanceLoop(childToParent.mkString("->"))) + } } case s: Connect => checkValidLoc(s.loc) case s: PartialConnect => checkValidLoc(s.loc) diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index 69645ddc..65873540 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -26,4 +26,120 @@ class CheckSpec extends FlatSpec with Matchers { } } } + "Instance loops a -> b -> a" should "be detected" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm) + val input = + """ + |circuit Foo : + | module Foo : + | input a : UInt<32> + | output b : UInt<32> + | inst bar of Bar + | bar.a <= a + | b <= bar.b + | + | module Bar : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of Foo + | foo.a <= a + | b <= foo.b + """.stripMargin + intercept[CheckHighForm.InstanceLoop] { + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + } + } + + "Instance loops a -> b -> c -> a" should "be detected" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm) + val input = + """ + |circuit Dog : + | module Dog : + | input a : UInt<32> + | output b : UInt<32> + | inst bar of Cat + | bar.a <= a + | b <= bar.b + | + | module Cat : + | input a : UInt<32> + | output b : UInt<32> + | inst ik of Ik + | ik.a <= a + | b <= ik.b + | + | module Ik : + | input a : UInt<32> + | output b : UInt<32> + | inst foo of Dog + | foo.a <= a + | b <= foo.b + | """.stripMargin + intercept[CheckHighForm.InstanceLoop] { + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + } + } + + "Instance loops a -> a" should "be detected" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm) + val input = + """ + |circuit Apple : + | module Apple : + | input a : UInt<32> + | output b : UInt<32> + | inst recurse_foo of Apple + | recurse_foo.a <= a + | b <= recurse_foo.b + | """.stripMargin + intercept[CheckHighForm.InstanceLoop] { + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + } + } + + "Instance loops should not have false positives" should "be detected" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm) + val input = + """ + |circuit Hammer : + | module Hammer : + | input a : UInt<32> + | output b : UInt<32> + | inst bar of Chisel + | bar.a <= a + | b <= bar.b + | + | module Chisel : + | input a : UInt<32> + | output b : UInt<32> + | inst ik of Saw + | ik.a <= a + | b <= ik.b + | + | module Saw : + | input a : UInt<32> + | output b : UInt<32> + | b <= a + | """.stripMargin + passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + + } + } |
