aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms/CheckCombLoops.scala
blob: 6bd62cfa2982005653165e0089e326b19bed6900 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
// See LICENSE for license details.

package firrtl.transforms

import scala.collection.mutable
import scala.collection.immutable.HashSet
import scala.collection.immutable.HashMap
import annotation.tailrec

import firrtl._
import firrtl.ir._
import firrtl.passes.{Errors, PassException}
import firrtl.Mappers._
import firrtl.annotations._
import firrtl.Utils.throwInternalError
import firrtl.graph.{MutableDiGraph,DiGraph}
import firrtl.analyses.InstanceGraph

object CheckCombLoops {
  /** Error raised for one combinational loop found in a module.
    *
    * @param info source locator printed at the start of the message
    * @param mname name of the module in which the loop was found
    * @param cycle names of the nets forming the loop, printed one per line
    */
  class CombLoopException(info: Info, mname: String, cycle: Seq[String])
      extends PassException(
        // Header line first, then every net of the cycle on its own line
        cycle.mkString(s"$info: [module $mname] Combinational loop detected:\n", "\n", ""))
}

/** Annotation that switches off combinational loop checking: when it
  * is present in the circuit state's annotations, [[CheckCombLoops]]
  * logs a warning and returns the state without analyzing it.
  */
case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation

/** Finds and detects combinational logic loops in a circuit, if any
  * exist. Returns the input circuit with no modifications.
  * 
  * @throws CombLoopException if a loop is found
  * @note Input form: Low FIRRTL
  * @note Output form: Low FIRRTL (identity transform)
  * @note The pass looks for loops through combinational-read memories
  * @note The pass cannot find loops that pass through ExtModules
  * @note The pass will throw exceptions on "false paths"
  */
class CheckCombLoops extends Transform {
  // Circuit form this transform consumes: low FIRRTL
  def inputForm = LowForm
  // Circuit form this transform produces: low FIRRTL, the same as its input form
  def outputForm = LowForm

  import CheckCombLoops._

  /*
   * A case class that represents a net in the circuit. This is
   * necessary since combinational loop checking is an analysis on the
   * netlist of the circuit; the fields are specialized for low
   * FIRRTL. Since all wires are ground types, a given ground type net
   * may only be a subfield of an instance or a memory
   * port. Therefore, it is uniquely specified within its module
   * context by its name, its optional parent instance (a WDefInstance
   * or WDefMemory), and its optional memory port name.
   */
  // name: the net's own (leaf) name
  // inst: name of the enclosing instance or memory, if the net is one of its fields
  // memport: name of the memory port, set only for fields of a memory port
  private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)

  /*
   * Converts a reference expression into the LogicNode naming the
   * net it refers to. Three shapes are accepted: a plain reference
   * (`x`), a field of a named parent (`parent.x`), and a field of a
   * port of a named parent (`parent.port.x`). Any other expression
   * is reported through throwInternalError rather than surfacing as
   * a raw scala.MatchError.
   */
  private def toLogicNode(e: Expression): LogicNode = e match {
    case r: WRef =>
      LogicNode(r.name)
    case s: WSubField =>
      s.expr match {
        // parent.x
        case modref: WRef =>
          LogicNode(s.name,Some(modref.name))
        // parent.port.x
        case memport: WSubField =>
          memport.expr match {
            case memref: WRef =>
              LogicNode(s.name,Some(memref.name),Some(memport.name))
            case _ => throwInternalError(Some(s"toLogicNode: unrecognized subsubfield expression - $memport"))
          }
        case _ => throwInternalError(Some(s"toLogicNode: unrecognized subfield expression - $s"))
      }
    // Keep the failure mode consistent with the nested matches above
    case _ => throwInternalError(Some(s"toLogicNode: unrecognized expression - $e"))
  }


  /*
   * Walks expression `e` and, for every reference or subfield it
   * contains, asks the graph to add an edge from net `v` to the
   * referenced net. Compound expressions are traversed recursively.
   * The expression is returned unchanged so that this function can
   * be used with the expression mapper.
   */
  private def getExprDeps(deps: MutableDiGraph[LogicNode], v: LogicNode)(e: Expression): Expression = {
    e match {
      case _: WRef | _: WSubField =>
        // NOTE(review): addEdgeIfValid presumably skips targets that are
        // not vertices of the graph - confirm against MutableDiGraph
        deps.addEdgeIfValid(v, toLogicNode(e))
        e
      case compound =>
        compound map getExprDeps(deps, v)
    }
  }

  /*
   * Adds to `deps` the vertices and edges contributed by statement
   * `s`, recursing into nested statements. Instances are represented
   * by the already-simplified graph of their module, looked up in
   * `simplifiedModules`. The statement is returned unchanged so that
   * this function can be used with the statement mapper.
   */
  private def getStmtDeps(
    simplifiedModules: mutable.Map[String,DiGraph[LogicNode]],
    deps: MutableDiGraph[LogicNode])(s: Statement): Statement = {
    s match {
      case Connect(_, sink, source) =>
        val sinkNode = toLogicNode(sink)
        // Only sinks already registered as vertices are tracked
        if (deps.contains(sinkNode)) getExprDeps(deps, sinkNode)(source)
      case wire: DefWire =>
        deps.addVertex(LogicNode(wire.name))
      case node: DefNode =>
        val nodeVertex = LogicNode(node.name)
        deps.addVertex(nodeVertex)
        getExprDeps(deps, nodeVertex)(node.value)
      case mem: DefMemory if mem.readLatency == 0 =>
        // With zero read latency, a read port's data depends on its addr and en
        mem.readers.foreach { reader =>
          def portField(field: String) =
            deps.addVertex(LogicNode(field, Some(mem.name), Some(reader)))
          val data = portField("data")
          deps.addEdge(data, portField("addr"))
          deps.addEdge(data, portField("en"))
        }
      case instance: WDefInstance =>
        // Re-home the child module's simplified graph under this instance's name
        val childGraph = simplifiedModules(instance.module)
          .transformNodes(port => port.copy(inst = Some(instance.name)))
        val childVertices = childGraph.getVertices
        childVertices.foreach(vertex => deps.addVertex(vertex))
        for (src <- childVertices; dst <- childGraph.getEdges(src)) {
          deps.addEdge(src, dst)
        }
      case other =>
        other map getStmtDeps(simplifiedModules, deps)
    }
    s
  }

  /*
   * Recover the full path from a path passing through simplified
   * instances. Since edges may pass through simplified instances, the
   * hierarchy that the path passes through must be recursively
   * recovered.
   */
  private def expandInstancePaths(
    m: String,
    moduleGraphs: mutable.Map[String,DiGraph[LogicNode]],
    moduleDeps: Map[String, Map[String,String]],
    prefix: Seq[String],
    path: Seq[LogicNode]): Seq[String] = {
    // Dot-separated name of a node underneath the given hierarchy
    def qualifiedName(hierarchy: Seq[String], node: LogicNode): String =
      (hierarchy ++ node.inst.toSeq ++ node.memport.toSeq :+ node.name).mkString(".")

    // Consecutive pairs of nodes along the path
    val hops = path zip path.tail
    val expandedHops = hops flatMap { case (from, to) =>
      (from.inst, from.memport) match {
        // Both ends are ports of the same instance: the hop crosses that
        // instance, so recover the path inside the child module's full graph
        case (Some(instName), None) if to.inst == from.inst =>
          val childModule = moduleDeps(m)(instName)
          val childPath = moduleGraphs(childModule)
            .path(to.copy(inst = None), from.copy(inst = None))
            .tail
            .reverse
          expandInstancePaths(childModule, moduleGraphs, moduleDeps, prefix :+ instName, childPath)
        case _ =>
          Seq(qualifiedName(prefix, from))
      }
    }
    // The final node is not the start of any hop, so append it explicitly
    expandedHops :+ qualifiedName(prefix, path.last)
  }

  /*
   * An SCC may contain more than one loop. In this case, the sequence
   * of nodes forming the SCC cannot be interpreted as a simple
   * cycle. However, it is desirable to print an error consisting of a
   * loop rather than an arbitrary ordering of the SCC. This function
   * operates on a pruned subgraph composed only of the SCC and finds
   * a simple cycle by performing an arbitrary walk.
   */
  private def findCycleInSCC[T](sccGraph: DiGraph[T]): Seq[T] = {
    // Keep following the first outgoing edge until a vertex repeats. The
    // result is the part of the trail from the first visit of that vertex
    // onwards, closed by the vertex itself.
    @tailrec
    def follow(node: T, trail: Vector[T], seen: Set[T]): Seq[T] =
      if (seen.contains(node)) {
        trail.drop(trail.indexOf(node)) :+ node
      } else {
        follow(sccGraph.getEdges(node).head, trail :+ node, seen + node)
      }
    follow(sccGraph.getVertices.head, Vector.empty[T], Set.empty[T])
  }

  /*
   * This implementation of combinational loop detection avoids ever
   * generating a full netlist from the FIRRTL circuit. Instead, each
   * module is converted to a netlist and analyzed locally, with its
   * subinstances represented by trivial, simplified subgraphs. The
   * overall outline of the process is:
   * 
   * 1. Create a graph of module instance dependencies
   *
   * 2. Linearize this acyclic graph
   * 
   * 3. Generate a local netlist; replace any instances with
   * simplified subgraphs representing connectivity of their IOs
   * 
   * 4. Check for nontrivial strongly connected components
   * 
   * 5. Create a reduced representation of the netlist with only the
   * module IOs as nodes, where output X (which must be a ground type,
   * as only low FIRRTL is supported) will have an edge to input Y if
   * and only if it combinationally depends on input Y. Associate this
   * reduced graph with the module for future use.
   */
  private def run(c: Circuit): Circuit = {
    // Loops are collected for every module and reported together at the end
    val errors = new Errors()
    /* TODO(magyar): deal with exmodules! No pass warnings currently
     *  exist. Maybe warn when iterating through modules.
     */
    val moduleMap = c.modules.map({m => (m.name,m) }).toMap
    val iGraph = new InstanceGraph(c).graph
    // module name -> (instance name -> name of the instantiated module)
    val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
    // Reversed linearization, so that a module's simplified graph exists
    // before any module instantiating it is analyzed
    val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
    val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
    val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
    for (m <- topoSortedModules) {
      val internalDeps = new MutableDiGraph[LogicNode]
      m.ports.foreach({ p => internalDeps.addVertex(LogicNode(p.name)) })
      m map getStmtDeps(simplifiedModuleGraphs, internalDeps)
      val moduleGraph = DiGraph(internalDeps)
      moduleGraphs(m.name) = moduleGraph
      simplifiedModuleGraphs(m.name) = moduleGraph.simplify((m.ports map { p => LogicNode(p.name) }).toSet)
      // A net that depends directly on itself is a combinational loop too,
      // but it forms an SCC of size one and so is missed by the SCC check
      // below; look for self-edges explicitly.
      for (selfLoop <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) {
        val nodeName = (Seq(m.name) ++ selfLoop.inst ++ selfLoop.memport :+ selfLoop.name).mkString(".")
        // Listed twice so the report is a closed path, like the multi-node cycles
        errors.append(new CombLoopException(m.info, m.name, Seq(nodeName, nodeName)))
      }
      for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) {
        val sccSubgraph = moduleGraph.subgraph(scc.toSet)
        val cycle = findCycleInSCC(sccSubgraph)
        // Sanity check: every step of the reported cycle is a real edge
        (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
        val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
        errors.append(new CombLoopException(m.info, m.name, expandedCycle))
      }
    }
    errors.trigger()
    c
  }

  /** Runs the combinational loop check on the circuit unless it has
    * been disabled with [[DontCheckCombLoopsAnnotation]].
    *
    * @param state the circuit state to check
    * @return the input state itself when the check is skipped; otherwise
    *         a state holding the checked circuit with the same
    *         annotations and renames
    */
  def execute(state: CircuitState): CircuitState =
    if (state.annotations.contains(DontCheckCombLoopsAnnotation)) {
      logger.warn("Skipping Combinational Loop Detection")
      state
    } else {
      CircuitState(run(state.circuit), outputForm, state.annotations, state.renames)
    }
}