aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAlbert Magyar2018-12-21 10:41:35 -0800
committerGitHub2018-12-21 10:41:35 -0800
commit3433f8f8a82c3b129456e9512dd2cf442a6042f6 (patch)
treeada800a0c1e66a0c54bcf7ce4d1cf82fecdef8a2
parent93e1f334de0579f513c3ffa03cb5f06c622b4fa8 (diff)
Enhance CheckCombLoops to support annotated ExtModule paths (#962)
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala157
-rw-r--r--src/test/scala/firrtlTests/CheckCombLoopsSpec.scala86
2 files changed, 187 insertions, 56 deletions
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index 7afce210..9016dca4 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -7,6 +7,8 @@ import scala.collection.immutable.HashSet
import scala.collection.immutable.HashMap
import annotation.tailrec
+import Function.tupled
+
import firrtl._
import firrtl.ir._
import firrtl.passes.{Errors, PassException}
@@ -26,6 +28,23 @@ object CheckCombLoops {
case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation
+case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarget) extends Annotation {
+ if (!source.isLocal || !sink.isLocal || source.module != sink.module) {
+ throwInternalError(s"ExtModulePathAnnotation must connect two local targets from the same module")
+ }
+
+ override def getTargets: Seq[ReferenceTarget] = Seq(source, sink)
+
+ override def update(renames: RenameMap): Seq[Annotation] = {
+ val sources = renames.get(source).getOrElse(Seq(source))
+ val sinks = renames.get(sink).getOrElse(Seq(sink))
+ val paths = sources flatMap { s => sinks.map((s, _)) }
+ paths.collect {
+ case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink)
+ }
+ }
+}
+
case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) extends Annotation {
override def update(renames: RenameMap): Seq[Annotation] = {
val newSources: Seq[IsComponent] = sources.flatMap { s => renames.get(s).getOrElse(Seq(s.toTarget)) }.collect {case x: IsComponent if x.isLocal => x}
@@ -36,12 +55,12 @@ case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) e
/** Finds and detects combinational logic loops in a circuit, if any
* exist. Returns the input circuit with no modifications.
- *
+ *
* @throws CombLoopException if a loop is found
* @note Input form: Low FIRRTL
* @note Output form: Low FIRRTL (identity transform)
* @note The pass looks for loops through combinational-read memories
- * @note The pass cannot find loops that pass through ExtModules
+ * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules
* @note The pass will throw exceptions on "false paths"
*/
class CheckCombLoops extends Transform with RegisteredTransform {
@@ -69,6 +88,8 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)
private def toLogicNode(e: Expression): LogicNode = e match {
+ case idx: WSubIndex =>
+ toLogicNode(idx.expr)
case r: WRef =>
LogicNode(r.name)
case s: WSubField =>
@@ -95,29 +116,29 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private def getStmtDeps(
simplifiedModules: mutable.Map[String,DiGraph[LogicNode]],
deps: MutableDiGraph[LogicNode])(s: Statement): Unit = s match {
- case Connect(_,loc,expr) =>
- val lhs = toLogicNode(loc)
- if (deps.contains(lhs)) {
- getExprDeps(deps, lhs)(expr)
- }
- case w: DefWire =>
- deps.addVertex(LogicNode(w.name))
- case n: DefNode =>
- val lhs = LogicNode(n.name)
- deps.addVertex(lhs)
- getExprDeps(deps, lhs)(n.value)
- case m: DefMemory if (m.readLatency == 0) =>
- for (rp <- m.readers) {
- val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
- deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp))))
- deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp))))
- }
- case i: WDefInstance =>
- val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
- iGraph.getVertices.foreach(deps.addVertex(_))
- iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
- case _ =>
- s.foreach(getStmtDeps(simplifiedModules,deps))
+ case Connect(_,loc,expr) =>
+ val lhs = toLogicNode(loc)
+ if (deps.contains(lhs)) {
+ getExprDeps(deps, lhs)(expr)
+ }
+ case w: DefWire =>
+ deps.addVertex(LogicNode(w.name))
+ case n: DefNode =>
+ val lhs = LogicNode(n.name)
+ deps.addVertex(lhs)
+ getExprDeps(deps, lhs)(n.value)
+ case m: DefMemory if (m.readLatency == 0) =>
+ for (rp <- m.readers) {
+ val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp))))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp))))
+ }
+ case i: WDefInstance =>
+ val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
+ iGraph.getVertices.foreach(deps.addVertex(_))
+ iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
+ case _ =>
+ s.foreach(getStmtDeps(simplifiedModules,deps))
}
/*
@@ -129,7 +150,7 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private def expandInstancePaths(
m: String,
moduleGraphs: mutable.Map[String,DiGraph[LogicNode]],
- moduleDeps: Map[String, Map[String,String]],
+ moduleDeps: Map[String, Map[String,String]],
prefix: Seq[String],
path: Seq[LogicNode]): Seq[String] = {
def absNodeName(prefix: Seq[String], n: LogicNode) =
@@ -173,51 +194,65 @@ class CheckCombLoops extends Transform with RegisteredTransform {
* module is converted to a netlist and analyzed locally, with its
* subinstances represented by trivial, simplified subgraphs. The
* overall outline of the process is:
- *
+ *
* 1. Create a graph of module instance dependances
* 2. Linearize this acyclic graph
- *
+ *
* 3. Generate a local netlist; replace any instances with
* simplified subgraphs representing connectivity of their IOs
- *
+ *
* 4. Check for nontrivial strongly connected components
- *
+ *
* 5. Create a reduced representation of the netlist with only the
* module IOs as nodes, where output X (which must be a ground type,
* as only low FIRRTL is supported) will have an edge to input Y if
* and only if it combinationally depends on input Y. Associate this
* reduced graph with the module for future use.
*/
- private def run(c: Circuit): (Circuit, Seq[Annotation]) = {
+ private def run(state: CircuitState) = {
+ val c = state.circuit
val errors = new Errors()
- /* TODO(magyar): deal with exmodules! No pass warnings currently
- * exist. Maybe warn when iterating through modules.
- */
+ val extModulePaths = state.annotations.groupBy {
+ case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module)
+ case ann: Annotation => CircuitTarget(c.main)
+ }
val moduleMap = c.modules.map({m => (m.name,m) }).toMap
val iGraph = new InstanceGraph(c).graph
val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
- for (m <- topoSortedModules) {
- val internalDeps = new MutableDiGraph[LogicNode]
- m.ports.foreach({ p => internalDeps.addVertex(LogicNode(p.name)) })
- m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps))
- val moduleGraph = DiGraph(internalDeps)
- moduleGraphs(m.name) = moduleGraph
- simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify((m.ports map { p => LogicNode(p.name) }).toSet)
- // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs!
- for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) {
- errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name)))
- }
- for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) {
- val sccSubgraph = moduleGraph.subgraph(scc.toSet)
- val cycle = findCycleInSCC(sccSubgraph)
- (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
- val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
- errors.append(new CombLoopException(m.info, m.name, expandedCycle))
- }
+ topoSortedModules.foreach {
+ case em: ExtModule =>
+ val portSet = em.ports.map(p => LogicNode(p.name)).toSet
+ val extModuleDeps = new MutableDiGraph[LogicNode]
+ portSet.foreach(extModuleDeps.addVertex(_))
+ extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect {
+ case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
+ }
+ moduleGraphs(em.name) = DiGraph(extModuleDeps).simplify(portSet)
+ simplifiedModuleGraphs(em.name) = moduleGraphs(em.name)
+ case m: Module =>
+ val portSet = m.ports.map(p => LogicNode(p.name)).toSet
+ val internalDeps = new MutableDiGraph[LogicNode]
+ portSet.foreach(internalDeps.addVertex(_))
+ m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps))
+ val moduleGraph = DiGraph(internalDeps)
+ moduleGraphs(m.name) = moduleGraph
+ simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify(portSet)
+ // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs!
+ for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) {
+ errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name)))
+ }
+ for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) {
+ val sccSubgraph = moduleGraph.subgraph(scc.toSet)
+ val cycle = findCycleInSCC(sccSubgraph)
+ (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
+ val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
+ errors.append(new CombLoopException(m.info, m.name, expandedCycle))
+ }
+ case m => throwInternalError(s"Module ${m.name} has unrecognized type")
}
val mn = ModuleName(c.main, CircuitName(c.main))
val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty =>
@@ -225,8 +260,17 @@ class CheckCombLoops extends Transform with RegisteredTransform {
val sources = tos.map(x => ComponentName(x.name, mn))
CombinationalPath(sink, sources.toSeq)
}
- errors.trigger()
- (c, annos.toSeq)
+ (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs)
+ }
+
+ /**
+ * Returns a Map from Module name to port connectivity
+ */
+ def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = {
+ val (result, errors, connectivity) = run(state)
+ connectivity.map {
+ case (k, v) => (k, v.transformNodes(ln => ln.name))
+ }
}
def execute(state: CircuitState): CircuitState = {
@@ -235,8 +279,9 @@ class CheckCombLoops extends Transform with RegisteredTransform {
logger.warn("Skipping Combinational Loop Detection")
state
} else {
- val (result, annos) = run(state.circuit)
- CircuitState(result, outputForm, state.annotations ++ annos, state.renames)
+ val (result, errors, connectivity) = run(state)
+ errors.trigger()
+ result
}
}
}
diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
index 8fc7dda9..98472f14 100644
--- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
+++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
@@ -165,6 +165,92 @@ class CheckCombLoopsSpec extends SimpleTransformSpec {
}
}
+ "Combinational loop through an annotated ExtModule" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | extmodule blackbox :
+ | input in : UInt<1>
+ | output out : UInt<1>
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | wire z : UInt<1>
+ | c <= b
+ | inst inner of blackbox
+ | inner.in <= y
+ | z <= inner.out
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val mt = ModuleTarget("hasloops", "blackbox")
+ val annos = AnnotationSeq(Seq(ExtModulePathAnnotation(mt.ref("in"), mt.ref("out"))))
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm, annos), writer)
+ }
+ }
+
+ "Loop-free circuit with ExtModulePathAnnotations" should "not throw an exception" in {
+ val input = """circuit hasnoloops :
+ | extmodule blackbox :
+ | input in1 : UInt<1>
+ | input in2 : UInt<1>
+ | output out1 : UInt<1>
+ | output out2 : UInt<1>
+ | module hasnoloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | output b : UInt<1>
+ | wire x : UInt<1>
+ | inst inner of blackbox
+ | inner.in1 <= a
+ | x <= inner.out1
+ | inner.in2 <= x
+ | b <= inner.out2
+ |""".stripMargin
+
+ val mt = ModuleTarget("hasnoloops", "blackbox")
+ val annos = AnnotationSeq(Seq(
+ ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")),
+ ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2"))))
+ val writer = new java.io.StringWriter
+ compile(CircuitState(parse(input), ChirrtlForm, annos), writer)
+ }
+
+ "Combinational loop through an output RHS reference" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | module thru :
+ | input in : UInt<1>
+ | output tmp : UInt<1>
+ | output out : UInt<1>
+ | tmp <= in
+ | out <= tmp
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | wire z : UInt<1>
+ | c <= b
+ | inst inner of thru
+ | inner.in <= y
+ | z <= inner.out
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm), writer)
+ }
+ }
+
"Multiple simple loops in one SCC" should "throw an exception" in {
val input = """circuit hasloops :
| module hasloops :