From 476d0d7475d641bc61d7630c8a7a8966cf61e04c Mon Sep 17 00:00:00 2001 From: Albert Magyar Date: Sun, 26 Feb 2017 16:11:37 -0800 Subject: Add pass to detect combinational loops --- src/main/scala/firrtl/LoweringCompilers.scala | 3 +- src/main/scala/firrtl/passes/CheckCombLoops.scala | 180 +++++++++++++++++++++ .../scala/firrtlTests/CheckCombLoopsSpec.scala | 127 +++++++++++++++ 3 files changed, 309 insertions(+), 1 deletion(-) create mode 100644 src/main/scala/firrtl/passes/CheckCombLoops.scala create mode 100644 src/test/scala/firrtlTests/CheckCombLoopsSpec.scala (limited to 'src') diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index b5808e93..b9042781 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -86,7 +86,8 @@ class MiddleFirrtlToLowFirrtl extends CoreTransform { passes.InferTypes, passes.ResolveGenders, passes.InferWidths, - passes.Legalize) + passes.Legalize, + passes.CheckCombLoops) } /** Runs a series of optimization passes on LowFirrtl diff --git a/src/main/scala/firrtl/passes/CheckCombLoops.scala b/src/main/scala/firrtl/passes/CheckCombLoops.scala new file mode 100644 index 00000000..af2ab666 --- /dev/null +++ b/src/main/scala/firrtl/passes/CheckCombLoops.scala @@ -0,0 +1,180 @@ +// See LICENSE for license details. + +package firrtl.passes + +import scala.collection.mutable +import scala.collection.immutable.HashSet +import scala.collection.immutable.HashMap +import annotation.tailrec + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.Utils.throwInternalError +import firrtl.graph.{MutableDiGraph,DiGraph} +import firrtl.analyses.InstanceGraph + +/** Finds and detects combinational logic loops in a circuit, if any + * exist. Returns the input circuit with no modifications. + * + * @throws a CombLoopException if a loop is found + * @note Input form: Low FIRRTL + * @note Output form: Low FIRRTL (identity transform) + * @note The pass looks for loops through combinational-read memories + * @note The pass cannot find loops that pass through ExtModules + * @note The pass will throw exceptions on "false paths" + */ + +object CheckCombLoops extends Pass { + + class CombLoopException(info: Info, mname: String, cycle: Seq[String]) extends PassException( + s"$info: [module $mname] Combinational loop detected:\n" + cycle.mkString("\n")) + + /* + * A case class that represents a net in the circuit. This is + * necessary since combinational loop checking is an analysis on the + * netlist of the circuit; the fields are specialized for low + * FIRRTL. Since all wires are ground types, a given ground type net + * may only be a subfield of an instance or a memory + * port. Therefore, it is uniquely specified within its module + * context by its name, its optional parent instance (a WDefInstance + * or WDefMemory), and its optional memory port name. + */ + private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None) + + private def toLogicNode(e: Expression): LogicNode = e match { + case r: WRef => + LogicNode(r.name) + case s: WSubField => + s.exp match { + case modref: WRef => + LogicNode(s.name,Some(modref.name)) + case memport: WSubField => + memport.exp match { + case memref: WRef => + LogicNode(s.name,Some(memref.name),Some(memport.name)) + case _ => throwInternalError + } + case _ => throwInternalError + } + } + + private def getExprDeps(deps: mutable.Set[LogicNode])(e: Expression): Expression = e match { + case r: WRef => + deps += toLogicNode(r) + r + case s: WSubField => + deps += toLogicNode(s) + s + case _ => + e map getExprDeps(deps) + } + + private def getStmtDeps( + simplifiedModules: mutable.Map[String,DiGraph[LogicNode]], + deps: MutableDiGraph[LogicNode])(s: Statement): Statement = { + s match { + case Connect(_,loc,expr) => + val lhs = toLogicNode(loc) + if (deps.contains(lhs)) { + getExprDeps(deps.getEdges(lhs))(expr) + } + case w: DefWire => + deps.addVertex(LogicNode(w.name)) + case n: DefNode => + val lhs = LogicNode(n.name) + deps.addVertex(lhs) + getExprDeps(deps.getEdges(lhs))(n.value) + case m: DefMemory if (m.readLatency == 0) => + for (rp <- m.readers) { + val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp)))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp)))) + } + case i: WDefInstance => + val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) + for (v <- iGraph.getVertices) { + deps.addVertex(v) + iGraph.getEdges(v).foreach { deps.addEdge(v,_) } + } + case _ => + s map getStmtDeps(simplifiedModules,deps) + } + s + } + + /* + * Recover the full path of the cycle. Since cycles may pass through + * simplified instances, the hierarchy that the path passes through + * must be recursively recovered. + */ + private def recoverCycle( + m: String, + moduleGraphs: mutable.Map[String,DiGraph[LogicNode]], + moduleDeps: Map[String, Map[String,String]], + prefix: Seq[String], + cycle: Seq[LogicNode]): Seq[String] = { + def absNodeName(prefix: Seq[String], n: LogicNode) = + (prefix ++ n.inst ++ n.memport :+ n.name).mkString(".") + val cycNodes = (cycle zip cycle.tail) map { case (a, b) => + if (a.inst.isDefined && !a.memport.isDefined && a.inst == b.inst) { + val child = moduleDeps(m)(a.inst.get) + val newprefix = prefix :+ a.inst.get + val subpath = moduleGraphs(child).path(b.copy(inst=None),a.copy(inst=None)).tail.reverse + recoverCycle(child,moduleGraphs,moduleDeps,newprefix,subpath) + } else { + Seq(absNodeName(prefix,a)) + } + } + cycNodes.flatten :+ absNodeName(prefix, cycle.last) + } + + /* + * This implementation of combinational loop detection avoids ever + * generating a full netlist from the FIRRTL circuit. Instead, each + * module is converted to a netlist and analyzed locally, with its + * subinstances represented by trivial, simplified subgraphs. The + * overall outline of the process is: + * + * 1. Create a graph of module instance dependances + + * 2. Linearize this acyclic graph + * + * 3. Generate a local netlist; replace any instances with + * simplified subgraphs representing connectivity of their IOs + * + * 4. Check for nontrivial strongly connected components + * + * 5. Create a reduced representation of the netlist with only the + * module IOs as nodes, where output X (which must be a ground type, + * as only low FIRRTL is supported) will have an edge to input Y if + * and only if it combinationally depends on input Y. Associate this + * reduced graph with the module for future use. + */ + def run(c: Circuit): Circuit = { + val errors = new Errors() + /* TODO(magyar): deal with exmodules! No pass warnings currently + * exist. Maybe warn when iterating through modules. + */ + val moduleMap = c.modules.map({m => (m.name,m) }).toMap + val iGraph = new InstanceGraph(c) + val moduleDeps = iGraph.graph.edges.map{ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) } + val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } + val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] + val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] + for (m <- topoSortedModules) { + val internalDeps = new MutableDiGraph[LogicNode] + m.ports foreach { p => internalDeps.addVertex(LogicNode(p.name)) } + m map getStmtDeps(simplifiedModuleGraphs, internalDeps) + moduleGraphs(m.name) = DiGraph(internalDeps) + simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify((m.ports map { p => LogicNode(p.name) }).toSet) + for (scc <- moduleGraphs(m.name).findSCCs.filter(_.length > 1)) { + val cycle = recoverCycle(m.name,moduleGraphs,moduleDeps,Seq(m.name),scc :+ scc.head) + errors.append(new CombLoopException(m.info, m.name, cycle)) + } + } + errors.trigger() + c + } + +} diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala new file mode 100644 index 00000000..16482560 --- /dev/null +++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala @@ -0,0 +1,127 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import firrtl.passes._ +import firrtl.Mappers._ +import annotations._ + +class CheckCombLoopsSpec extends SimpleTransformSpec { + + def emitter = new LowFirrtlEmitter + + def transforms = Seq( + new ChirrtlToHighFirrtl, + new IRToWorkingIR, + new ResolveAndCheck, + new HighFirrtlToMiddleFirrtl, + new MiddleFirrtlToLowFirrtl + ) + + "Simple combinational loop" should "throw an exception" in { + val input = """circuit hasloops : + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | z <= y + | y <= z + | d <= z + |""".stripMargin + + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm, None), writer) + } + } + + "Node combinational loop" should "throw an exception" in { + val input = """circuit hasloops : + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | c <= b + | node z = and(c,y) + | y <= z + | d <= z + |""".stripMargin + + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm, None), writer) + } + } + + "Combinational loop through a combinational memory read port" should "throw an exception" in { + val input = """circuit hasloops : + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | mem m : + | data-type => UInt<1> + | depth => 2 + | read-latency => 0 + | write-latency => 1 + | reader => r + | read-under-write => undefined + | m.r.clk <= clk + | m.r.addr <= y + | m.r.en <= UInt(1) + | z <= m.r.data + | y <= z + | d <= z + |""".stripMargin + + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm, None), writer) + } + } + + "Combination loop through an instance" should "throw an exception" in { + val input = """circuit hasloops : + | module thru : + | input in : UInt<1> + | output out : UInt<1> + | out <= in + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of thru + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin + + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm, None), writer) + } + } + + +} -- cgit v1.2.3