aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAlbert Magyar2017-02-26 16:11:37 -0800
committerAlbert Magyar2017-03-23 23:59:20 -0700
commit476d0d7475d641bc61d7630c8a7a8966cf61e04c (patch)
tree7e49f80f45855c8014bab614018902742c92a947 /src
parent67eb4e2de6166b8f1eb5190215640117b82e8c48 (diff)
Add pass to detect combinational loops
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala3
-rw-r--r--src/main/scala/firrtl/passes/CheckCombLoops.scala180
-rw-r--r--src/test/scala/firrtlTests/CheckCombLoopsSpec.scala127
3 files changed, 309 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index b5808e93..b9042781 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -86,7 +86,8 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform {
passes.InferTypes,
passes.ResolveGenders,
passes.InferWidths,
- passes.Legalize)
+ passes.Legalize,
+ passes.CheckCombLoops)
}
/** Runs a series of optimization passes on LowFirrtl
diff --git a/src/main/scala/firrtl/passes/CheckCombLoops.scala b/src/main/scala/firrtl/passes/CheckCombLoops.scala
new file mode 100644
index 00000000..af2ab666
--- /dev/null
+++ b/src/main/scala/firrtl/passes/CheckCombLoops.scala
@@ -0,0 +1,180 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+
+import scala.collection.mutable
+import scala.collection.immutable.HashSet
+import scala.collection.immutable.HashMap
+import annotation.tailrec
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.Utils.throwInternalError
+import firrtl.graph.{MutableDiGraph,DiGraph}
+import firrtl.analyses.InstanceGraph
+
+/** Finds and detects combinational logic loops in a circuit, if any
+ * exist. Returns the input circuit with no modifications.
+ *
+ * @throws a CombLoopException if a loop is found
+ * @note Input form: Low FIRRTL
+ * @note Output form: Low FIRRTL (identity transform)
+ * @note The pass looks for loops through combinational-read memories
+ * @note The pass cannot find loops that pass through ExtModules
+ * @note The pass will throw exceptions on "false paths"
+ */
+
+object CheckCombLoops extends Pass {
+
+ class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException(
+ s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n"))
+
+ /*
+ * A case class that represents a net in the circuit. This is
+ * necessary since combinational loop checking is an analysis on the
+ * netlist of the circuit; the fields are specialized for low
+ * FIRRTL. Since all wires are ground types, a given ground type net
+ * may only be a subfield of an instance or a memory
+ * port. Therefore, it is uniquely specified within its module
+ * context by its name, its optional parent instance (a WDefInstance
+ * or WDefMemory), and its optional memory port name.
+ */
+ private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)
+
+ private def toLogicNode(e: Expression): LogicNode = e match {
+ case r: WRef =>
+ LogicNode(r.name)
+ case s: WSubField =>
+ s.exp match {
+ case modref: WRef =>
+ LogicNode(s.name,Some(modref.name))
+ case memport: WSubField =>
+ memport.exp match {
+ case memref: WRef =>
+ LogicNode(s.name,Some(memref.name),Some(memport.name))
+ case _ => throwInternalError
+ }
+ case _ => throwInternalError
+ }
+ }
+
+ private def getExprDeps(deps: mutable.Set[LogicNode])(e: Expression): Expression = e match {
+ case r: WRef =>
+ deps += toLogicNode(r)
+ r
+ case s: WSubField =>
+ deps += toLogicNode(s)
+ s
+ case _ =>
+ e map getExprDeps(deps)
+ }
+
+ private def getStmtDeps(
+ simplifiedModules: mutable.Map[String,DiGraph[LogicNode]],
+ deps: MutableDiGraph[LogicNode])(s: Statement): Statement = {
+ s match {
+ case Connect(_,loc,expr) =>
+ val lhs = toLogicNode(loc)
+ if (deps.contains(lhs)) {
+ getExprDeps(deps.getEdges(lhs))(expr)
+ }
+ case w: DefWire =>
+ deps.addVertex(LogicNode(w.name))
+ case n: DefNode =>
+ val lhs = LogicNode(n.name)
+ deps.addVertex(lhs)
+ getExprDeps(deps.getEdges(lhs))(n.value)
+ case m: DefMemory if (m.readLatency == 0) =>
+ for (rp <- m.readers) {
+ val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp))))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp))))
+ }
+ case i: WDefInstance =>
+ val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
+ for (v <- iGraph.getVertices) {
+ deps.addVertex(v)
+ iGraph.getEdges(v).foreach { deps.addEdge(v,_) }
+ }
+ case _ =>
+ s map getStmtDeps(simplifiedModules,deps)
+ }
+ s
+ }
+
+ /*
+ * Recover the full path of the cycle. Since cycles may pass through
+ * simplified instances, the hierarchy that the path passes through
+ * must be recursively recovered.
+ */
+ private def recoverCycle(
+ m: String,
+ moduleGraphs: mutable.Map[String,DiGraph[LogicNode]],
+ moduleDeps: Map[String, Map[String,String]],
+ prefix: Seq[String],
+ cycle: Seq[LogicNode]): Seq[String] = {
+ def absNodeName(prefix: Seq[String], n: LogicNode) =
+ (prefix ++ n.inst ++ n.memport :+ n.name).mkString(".")
+ val cycNodes = (cycle zip cycle.tail) map { case (a, b) =>
+ if (a.inst.isDefined && !a.memport.isDefined && a.inst == b.inst) {
+ val child = moduleDeps(m)(a.inst.get)
+ val newprefix = prefix :+ a.inst.get
+ val subpath = moduleGraphs(child).path(b.copy(inst=None),a.copy(inst=None)).tail.reverse
+ recoverCycle(child,moduleGraphs,moduleDeps,newprefix,subpath)
+ } else {
+ Seq(absNodeName(prefix,a))
+ }
+ }
+ cycNodes.flatten :+ absNodeName(prefix, cycle.last)
+ }
+
+ /*
+ * This implementation of combinational loop detection avoids ever
+ * generating a full netlist from the FIRRTL circuit. Instead, each
+ * module is converted to a netlist and analyzed locally, with its
+ * subinstances represented by trivial, simplified subgraphs. The
+ * overall outline of the process is:
+ *
+ * 1. Create a graph of module instance dependances
+
+ * 2. Linearize this acyclic graph
+ *
+ * 3. Generate a local netlist; replace any instances with
+ * simplified subgraphs representing connectivity of their IOs
+ *
+ * 4. Check for nontrivial strongly connected components
+ *
+ * 5. Create a reduced representation of the netlist with only the
+ * module IOs as nodes, where output X (which must be a ground type,
+ * as only low FIRRTL is supported) will have an edge to input Y if
+ * and only if it combinationally depends on input Y. Associate this
+ * reduced graph with the module for future use.
+ */
+ def run(c: Circuit): Circuit = {
+ val errors = new Errors()
+ /* TODO(magyar): deal with exmodules! No pass warnings currently
+ * exist. Maybe warn when iterating through modules.
+ */
+ val moduleMap = c.modules.map({m => (m.name,m) }).toMap
+ val iGraph = new InstanceGraph(c)
+ val moduleDeps = iGraph.graph.edges.map{ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }
+ val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
+ val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
+ val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
+ for (m <- topoSortedModules) {
+ val internalDeps = new MutableDiGraph[LogicNode]
+ m.ports foreach { p => internalDeps.addVertex(LogicNode(p.name)) }
+ m map getStmtDeps(simplifiedModuleGraphs, internalDeps)
+ moduleGraphs(m.name) = DiGraph(internalDeps)
+ simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify((m.ports map { p => LogicNode(p.name) }).toSet)
+ for (scc <- moduleGraphs(m.name).findSCCs.filter(_.length > 1)) {
+ val cycle = recoverCycle(m.name,moduleGraphs,moduleDeps,Seq(m.name),scc :+ scc.head)
+ errors.append(new CombLoopException(m.info, m.name, cycle))
+ }
+ }
+ errors.trigger()
+ c
+ }
+
+}
diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
new file mode 100644
index 00000000..16482560
--- /dev/null
+++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
@@ -0,0 +1,127 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import firrtl._
+import firrtl.ir._
+import firrtl.passes._
+import firrtl.Mappers._
+import annotations._
+
+class CheckCombLoopsSpec extends SimpleTransformSpec {
+
+ def emitter = new LowFirrtlEmitter
+
+ def transforms = Seq(
+ new ChirrtlToHighFirrtl,
+ new IRToWorkingIR,
+ new ResolveAndCheck,
+ new HighFirrtlToMiddleFirrtl,
+ new MiddleFirrtlToLowFirrtl
+ )
+
+ "Simple combinational loop" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | wire z : UInt<1>
+ | c <= b
+ | z <= y
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm, None), writer)
+ }
+ }
+
+ "Node combinational loop" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | c <= b
+ | node z = and(c,y)
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm, None), writer)
+ }
+ }
+
+ "Combinational loop through a combinational memory read port" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | wire z : UInt<1>
+ | c <= b
+ | mem m :
+ | data-type => UInt<1>
+ | depth => 2
+ | read-latency => 0
+ | write-latency => 1
+ | reader => r
+ | read-under-write => undefined
+ | m.r.clk <= clk
+ | m.r.addr <= y
+ | m.r.en <= UInt(1)
+ | z <= m.r.data
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm, None), writer)
+ }
+ }
+
+ "Combination loop through an instance" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | module thru :
+ | input in : UInt<1>
+ | output out : UInt<1>
+ | out <= in
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | wire z : UInt<1>
+ | c <= b
+ | inst inner of thru
+ | inner.in <= y
+ | z <= inner.out
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm, None), writer)
+ }
+ }
+
+
+}