diff options
Diffstat (limited to 'src/main/scala/firrtl/transforms')
3 files changed, 35 insertions, 36 deletions
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala index d50a027c..bb2ffea9 100644 --- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala +++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala @@ -77,15 +77,16 @@ class CheckCombLoops extends Transform { } } - private def getExprDeps(deps: mutable.Set[LogicNode])(e: Expression): Expression = e match { + + private def getExprDeps(deps: MutableDiGraph[LogicNode], v: LogicNode)(e: Expression): Expression = e match { case r: WRef => - deps += toLogicNode(r) + deps.addEdgeIfValid(v, toLogicNode(r)) r case s: WSubField => - deps += toLogicNode(s) + deps.addEdgeIfValid(v, toLogicNode(s)) s case _ => - e map getExprDeps(deps) + e map getExprDeps(deps, v) } private def getStmtDeps( @@ -95,14 +96,14 @@ class CheckCombLoops extends Transform { case Connect(_,loc,expr) => val lhs = toLogicNode(loc) if (deps.contains(lhs)) { - getExprDeps(deps.getEdges(lhs))(expr) + getExprDeps(deps, lhs)(expr) } case w: DefWire => deps.addVertex(LogicNode(w.name)) case n: DefNode => val lhs = LogicNode(n.name) deps.addVertex(lhs) - getExprDeps(deps.getEdges(lhs))(n.value) + getExprDeps(deps, lhs)(n.value) case m: DefMemory if (m.readLatency == 0) => for (rp <- m.readers) { val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp))) @@ -111,10 +112,8 @@ class CheckCombLoops extends Transform { } case i: WDefInstance => val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name))) - for (v <- iGraph.getVertices) { - deps.addVertex(v) - iGraph.getEdges(v).foreach { deps.addEdge(v,_) } - } + iGraph.getVertices.foreach(deps.addVertex(_)) + iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } }) case _ => s map getStmtDeps(simplifiedModules,deps) } @@ -196,9 +195,9 @@ class CheckCombLoops extends Transform { * exist. Maybe warn when iterating through modules. */ val moduleMap = c.modules.map({m => (m.name,m) }).toMap - val iGraph = new InstanceGraph(c) - val moduleDeps = iGraph.graph.edges.map{ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) } - val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } + val iGraph = new InstanceGraph(c).graph + val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap + val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) } val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]] for (m <- topoSortedModules) { @@ -212,7 +211,7 @@ class CheckCombLoops extends Transform { val sccSubgraph = moduleGraphs(m.name).subgraph(scc.toSet) val cycle = findCycleInSCC(sccSubgraph) (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) }) - val expandedCycle = expandInstancePaths(m.name,moduleGraphs,moduleDeps,Seq(m.name),cycle.reverse) + val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse) errors.append(new CombLoopException(m.info, m.name, expandedCycle)) } } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 69502911..84b63e3d 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -398,9 +398,9 @@ class ConstantPropagation extends Transform { private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = { val iGraph = (new InstanceGraph(c)).graph - val moduleDeps = iGraph.edges.map { case (mod, children) => + val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) => mod.module -> children.map(i => i.name -> i.module).toMap - } + }) // Module name to number of instances val instCount: Map[String, Int] = iGraph.getVertices.groupBy(_.module).mapValues(_.size) diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index 24c1c51c..578de264 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -95,11 +95,11 @@ class DeadCodeElimination extends Transform { case DefRegister(_, name, _, clock, reset, init) => val node = LogicNode(mod.name, name) depGraph.addVertex(node) - Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref)) + Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref)) case DefNode(_, name, value) => val node = LogicNode(mod.name, name) depGraph.addVertex(node) - getDeps(value).foreach(ref => depGraph.addEdge(node, ref)) + getDeps(value).foreach(ref => depGraph.addPairWithEdge(node, ref)) case DefWire(_, name, _) => depGraph.addVertex(LogicNode(mod.name, name)) case mem: DefMemory => @@ -111,23 +111,23 @@ class DeadCodeElimination extends Transform { val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_)) val memNode = getDeps(memRef) match { case Seq(node) => node } depGraph.addVertex(memNode) - sinks.foreach(sink => depGraph.addEdge(sink, memNode)) - sources.foreach(source => depGraph.addEdge(memNode, source)) + sinks.foreach(sink => depGraph.addPairWithEdge(sink, memNode)) + sources.foreach(source => depGraph.addPairWithEdge(memNode, source)) case Attach(_, exprs) => // Add edge between each expression exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach { case Seq(a, b) => - depGraph.addEdge(a, b) - depGraph.addEdge(b, a) + depGraph.addPairWithEdge(a, b) + depGraph.addPairWithEdge(b, a) } case Connect(_, loc, expr) => // This match enforces the low Firrtl requirement of expanded connections val node = getDeps(loc) match { case Seq(elt) => elt } - getDeps(expr).foreach(ref => depGraph.addEdge(node, ref)) + getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref)) // Simulation constructs are treated as top-level outputs case Stop(_,_, clk, en) => - Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) + Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) case Print(_, _, args, clk, en) => - (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref)) + (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref)) case Block(stmts) => stmts.foreach(onStmt(_)) case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing case other => throw new Exception(s"Unexpected Statement $other") @@ -153,30 +153,30 @@ class DeadCodeElimination extends Transform { val node = LogicNode(ext) // Don't touch external modules *unless* they are specifically marked as doTouch // Simply marking the extmodule itself is sufficient to prevent inputs from being removed - if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, node) + if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, node) ext.ports.foreach { case Port(_, pname, _, AnalogType(_)) => - depGraph.addEdge(LogicNode(ext.name, pname), node) - depGraph.addEdge(node, LogicNode(ext.name, pname)) + depGraph.addPairWithEdge(LogicNode(ext.name, pname), node) + depGraph.addPairWithEdge(node, LogicNode(ext.name, pname)) case Port(_, pname, Output, _) => val portNode = LogicNode(ext.name, pname) - depGraph.addEdge(portNode, node) + depGraph.addPairWithEdge(portNode, node) // Also mark all outputs as circuit sinks (unless marked doTouch obviously) - if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode) - case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname)) + if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, portNode) + case Port(_, pname, Input, _) => depGraph.addPairWithEdge(node, LogicNode(ext.name, pname)) } } // Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface) val topModule = c.modules.find(_.name == c.main).get val topOutputs = topModule.ports.foreach { port => - depGraph.addEdge(circuitSink, LogicNode(c.main, port.name)) + depGraph.addPairWithEdge(circuitSink, LogicNode(c.main, port.name)) } depGraph } private def deleteDeadCode(instMap: collection.Map[String, String], - deadNodes: Set[LogicNode], + deadNodes: collection.Set[LogicNode], moduleMap: collection.Map[String, DefModule], renames: RenameMap) (mod: DefModule): Option[DefModule] = { @@ -243,9 +243,9 @@ class DeadCodeElimination extends Transform { val c = state.circuit val moduleMap = c.modules.map(m => m.name -> m).toMap val iGraph = new InstanceGraph(c) - val moduleDeps = iGraph.graph.edges.map { case (k,v) => + val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) => k.module -> v.map(i => i.name -> i.module).toMap - } + }) val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_)) val depGraph = { @@ -255,7 +255,7 @@ class DeadCodeElimination extends Transform { dontTouches.foreach { dontTouch => // Ensure that they are actually found if (vertices.contains(dontTouch)) { - dGraph.addEdge(circuitSink, dontTouch) + dGraph.addPairWithEdge(circuitSink, dontTouch) } else { val (root, tail) = Utils.splitRef(dontTouch.e1) DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize) |
