diff options
| author | Albert Magyar | 2018-12-21 10:41:35 -0800 |
|---|---|---|
| committer | GitHub | 2018-12-21 10:41:35 -0800 |
| commit | 3433f8f8a82c3b129456e9512dd2cf442a6042f6 (patch) | |
| tree | ada800a0c1e66a0c54bcf7ce4d1cf82fecdef8a2 /src | |
| parent | 93e1f334de0579f513c3ffa03cb5f06c622b4fa8 (diff) | |
Enhance CheckCombLoops to support annotated ExtModule paths (#962)
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/CheckCombLoops.scala | 157 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/CheckCombLoopsSpec.scala | 86 |
2 files changed, 187 insertions, 56 deletions
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index 7afce210..9016dca4 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -7,6 +7,8 @@ import scala.collection.immutable.HashSet import scala.collection.immutable.HashMap import annotation.tailrec +import Function.tupled + import firrtl._ import firrtl.ir._ import firrtl.passes.{Errors, PassException} @@ -26,6 +28,23 @@ object CheckCombLoops { case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation +case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarget) extends Annotation { + if (!source.isLocal || !sink.isLocal || source.module != sink.module) { + throwInternalError(s"ExtModulePathAnnotation must connect two local targets from the same module") + } + + override def getTargets: Seq[ReferenceTarget] = Seq(source, sink) + + override def update(renames: RenameMap): Seq[Annotation] = { + val sources = renames.get(source).getOrElse(Seq(source)) + val sinks = renames.get(sink).getOrElse(Seq(sink)) + val paths = sources flatMap { s => sinks.map((s, _)) } + paths.collect { + case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink) + } + } +} + case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) extends Annotation { override def update(renames: RenameMap): Seq[Annotation] = { val newSources: Seq[IsComponent] = sources.flatMap { s => renames.get(s).getOrElse(Seq(s.toTarget)) }.collect {case x: IsComponent if x.isLocal => x} @@ -36,12 +55,12 @@ case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) e /** Finds and detects combinational logic loops in a circuit, if any * exist. Returns the input circuit with no modifications. - * + * * @throws CombLoopException if a loop is found * @note Input form: Low FIRRTL * @note Output form: Low FIRRTL (identity transform) * @note The pass looks for loops through combinational-read memories - * @note The pass cannot find loops that pass through ExtModules + * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules * @note The pass will throw exceptions on "false paths" */ class CheckCombLoops extends Transform with RegisteredTransform { @@ -69,6 +88,8 @@ class CheckCombLoops extends Transform with RegisteredTransform { private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None) private def toLogicNode(e: Expression): LogicNode = e match { + case idx: WSubIndex => + toLogicNode(idx.expr) case r: WRef => LogicNode(r.name) case s: WSubField => @@ -95,29 +116,29 @@ class CheckCombLoops extends Transform with RegisteredTransform { private def getStmtDeps( simplifiedModules: mutable.Map[String,DiGraph[LogicNode]], deps: MutableDiGraph[LogicNode])(s: Statement): Unit = s match { - case Connect(_,loc,expr) => - val lhs = toLogicNode(loc) - if (deps.contains(lhs)) { - getExprDeps(deps, lhs)(expr) - } - case w: DefWire => - deps.addVertex(LogicNode(w.name)) - case n: DefNode => - val lhs = LogicNode(n.name) - deps.addVertex(lhs) - getExprDeps(deps, lhs)(n.value) - case m: DefMemory if (m.readLatency == 0) => - for (rp <- m.readers) { - val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) - deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp)))) - deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp)))) - } - case i: WDefInstance => - val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) - iGraph.getVertices.foreach(deps.addVertex(_)) - iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) - case _ => - s.foreach(getStmtDeps(simplifiedModules,deps)) + case Connect(_,loc,expr) => + val lhs = toLogicNode(loc) + if (deps.contains(lhs)) { + getExprDeps(deps, lhs)(expr) + } + case w: DefWire => + deps.addVertex(LogicNode(w.name)) + case n: DefNode => + val lhs = LogicNode(n.name) + deps.addVertex(lhs) + getExprDeps(deps, lhs)(n.value) + case m: DefMemory if (m.readLatency == 0) => + for (rp <- m.readers) { + val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp)))) + deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp)))) + } + case i: WDefInstance => + val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) + iGraph.getVertices.foreach(deps.addVertex(_)) + iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) + case _ => + s.foreach(getStmtDeps(simplifiedModules,deps)) } /* @@ -129,7 +150,7 @@ class CheckCombLoops extends Transform with RegisteredTransform { private def expandInstancePaths( m: String, moduleGraphs: mutable.Map[String,DiGraph[LogicNode]], - moduleDeps: Map[String, Map[String,String]], + moduleDeps: Map[String, Map[String,String]], prefix: Seq[String], path: Seq[LogicNode]): Seq[String] = { def absNodeName(prefix: Seq[String], n: LogicNode) = @@ -173,51 +194,65 @@ class CheckCombLoops extends Transform with RegisteredTransform { * module is converted to a netlist and analyzed locally, with its * subinstances represented by trivial, simplified subgraphs. The * overall outline of the process is: - * + * * 1. Create a graph of module instance dependances * 2. Linearize this acyclic graph - * + * * 3. Generate a local netlist; replace any instances with * simplified subgraphs representing connectivity of their IOs - * + * * 4. Check for nontrivial strongly connected components - * + * * 5. Create a reduced representation of the netlist with only the * module IOs as nodes, where output X (which must be a ground type, * as only low FIRRTL is supported) will have an edge to input Y if * and only if it combinationally depends on input Y. Associate this * reduced graph with the module for future use. */ - private def run(c: Circuit): (Circuit, Seq[Annotation]) = { + private def run(state: CircuitState) = { + val c = state.circuit val errors = new Errors() - /* TODO(magyar): deal with exmodules! No pass warnings currently - * exist. Maybe warn when iterating through modules. - */ + val extModulePaths = state.annotations.groupBy { + case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module) + case ann: Annotation => CircuitTarget(c.main) + } val moduleMap = c.modules.map({m => (m.name,m) }).toMap val iGraph = new InstanceGraph(c).graph val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] - for (m <- topoSortedModules) { - val internalDeps = new MutableDiGraph[LogicNode] - m.ports.foreach({ p => internalDeps.addVertex(LogicNode(p.name)) }) - m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps)) - val moduleGraph = DiGraph(internalDeps) - moduleGraphs(m.name) = moduleGraph - simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify((m.ports map { p => LogicNode(p.name) }).toSet) - // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs! - for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) { - errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name))) - } - for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) { - val sccSubgraph = moduleGraph.subgraph(scc.toSet) - val cycle = findCycleInSCC(sccSubgraph) - (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) }) - val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse) - errors.append(new CombLoopException(m.info, m.name, expandedCycle)) - } + topoSortedModules.foreach { + case em: ExtModule => + val portSet = em.ports.map(p => LogicNode(p.name)).toSet + val extModuleDeps = new MutableDiGraph[LogicNode] + portSet.foreach(extModuleDeps.addVertex(_)) + extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect { + case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref)) + } + moduleGraphs(em.name) = DiGraph(extModuleDeps).simplify(portSet) + simplifiedModuleGraphs(em.name) = moduleGraphs(em.name) + case m: Module => + val portSet = m.ports.map(p => LogicNode(p.name)).toSet + val internalDeps = new MutableDiGraph[LogicNode] + portSet.foreach(internalDeps.addVertex(_)) + m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps)) + val moduleGraph = DiGraph(internalDeps) + moduleGraphs(m.name) = moduleGraph + simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify(portSet) + // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs! + for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) { + errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name))) + } + for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) { + val sccSubgraph = moduleGraph.subgraph(scc.toSet) + val cycle = findCycleInSCC(sccSubgraph) + (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) }) + val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse) + errors.append(new CombLoopException(m.info, m.name, expandedCycle)) + } + case m => throwInternalError(s"Module ${m.name} has unrecognized type") } val mn = ModuleName(c.main, CircuitName(c.main)) val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty => @@ -225,8 +260,17 @@ class CheckCombLoops extends Transform with RegisteredTransform { val sources = tos.map(x => ComponentName(x.name, mn)) CombinationalPath(sink, sources.toSeq) } - errors.trigger() - (c, annos.toSeq) + (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs) + } + + /** + * Returns a Map from Module name to port connectivity + */ + def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = { + val (result, errors, connectivity) = run(state) + connectivity.map { + case (k, v) => (k, v.transformNodes(ln => ln.name)) + } } def execute(state: CircuitState): CircuitState = { @@ -235,8 +279,9 @@ class CheckCombLoops extends Transform with RegisteredTransform { logger.warn("Skipping Combinational Loop Detection") state } else { - val (result, annos) = run(state.circuit) - CircuitState(result, outputForm, state.annotations ++ annos, state.renames) + val (result, errors, connectivity) = run(state) + errors.trigger() + result } } } diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala index 8fc7dda9..98472f14 100644 --- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala +++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala @@ -165,6 +165,92 @@ class CheckCombLoopsSpec extends SimpleTransformSpec { } } + "Combinational loop through an annotated ExtModule" should "throw an exception" in { + val input = """circuit hasloops : + | extmodule blackbox : + | input in : UInt<1> + | output out : UInt<1> + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of blackbox + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin + + val mt = ModuleTarget("hasloops", "blackbox") + val annos = AnnotationSeq(Seq(ExtModulePathAnnotation(mt.ref("in"), mt.ref("out")))) + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm, annos), writer) + } + } + + "Loop-free circuit with ExtModulePathAnnotations" should "not throw an exception" in { + val input = """circuit hasnoloops : + | extmodule blackbox : + | input in1 : UInt<1> + | input in2 : UInt<1> + | output out1 : UInt<1> + | output out2 : UInt<1> + | module hasnoloops : + | input clk : Clock + | input a : UInt<1> + | output b : UInt<1> + | wire x : UInt<1> + | inst inner of blackbox + | inner.in1 <= a + | x <= inner.out1 + | inner.in2 <= x + | b <= inner.out2 + |""".stripMargin + + val mt = ModuleTarget("hasnoloops", "blackbox") + val annos = AnnotationSeq(Seq( + ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")), + ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2")))) + val writer = new java.io.StringWriter + compile(CircuitState(parse(input), ChirrtlForm, annos), writer) + } + + "Combinational loop through an output RHS reference" should "throw an exception" in { + val input = """circuit hasloops : + | module thru : + | input in : UInt<1> + | output tmp : UInt<1> + | output out : UInt<1> + | tmp <= in + | out <= tmp + | module hasloops : + | input clk : Clock + | input a : UInt<1> + | input b : UInt<1> + | output c : UInt<1> + | output d : UInt<1> + | wire y : UInt<1> + | wire z : UInt<1> + | c <= b + | inst inner of thru + | inner.in <= y + | z <= inner.out + | y <= z + | d <= z + |""".stripMargin + + val writer = new java.io.StringWriter + intercept[CheckCombLoops.CombLoopException] { + compile(CircuitState(parse(input), ChirrtlForm), writer) + } + } + "Multiple simple loops in one SCC" should "throw an exception" in { val input = """circuit hasloops : | module hasloops : |
