aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/transforms
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/transforms')
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala27
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala4
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala40
3 files changed, 35 insertions, 36 deletions
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index d50a027c..bb2ffea9 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -77,15 +77,16 @@ class CheckCombLoops extends Transform {
}
}
- private def getExprDeps(deps: mutable.Set[LogicNode])(e: Expression): Expression = e match {
+
+ private def getExprDeps(deps: MutableDiGraph[LogicNode], v: LogicNode)(e: Expression): Expression = e match {
case r: WRef =>
- deps += toLogicNode(r)
+ deps.addEdgeIfValid(v, toLogicNode(r))
r
case s: WSubField =>
- deps += toLogicNode(s)
+ deps.addEdgeIfValid(v, toLogicNode(s))
s
case _ =>
- e map getExprDeps(deps)
+ e map getExprDeps(deps, v)
}
private def getStmtDeps(
@@ -95,14 +96,14 @@ class CheckCombLoops extends Transform {
case Connect(_,loc,expr) =>
val lhs = toLogicNode(loc)
if (deps.contains(lhs)) {
- getExprDeps(deps.getEdges(lhs))(expr)
+ getExprDeps(deps, lhs)(expr)
}
case w: DefWire =>
deps.addVertex(LogicNode(w.name))
case n: DefNode =>
val lhs = LogicNode(n.name)
deps.addVertex(lhs)
- getExprDeps(deps.getEdges(lhs))(n.value)
+ getExprDeps(deps, lhs)(n.value)
case m: DefMemory if (m.readLatency == 0) =>
for (rp <- m.readers) {
val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
@@ -111,10 +112,8 @@ class CheckCombLoops extends Transform {
}
case i: WDefInstance =>
val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
- for (v <- iGraph.getVertices) {
- deps.addVertex(v)
- iGraph.getEdges(v).foreach { deps.addEdge(v,_) }
- }
+ iGraph.getVertices.foreach(deps.addVertex(_))
+ iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
case _ =>
s map getStmtDeps(simplifiedModules,deps)
}
@@ -196,9 +195,9 @@ class CheckCombLoops extends Transform {
* exist. Maybe warn when iterating through modules.
*/
val moduleMap = c.modules.map({m => (m.name,m) }).toMap
- val iGraph = new InstanceGraph(c)
- val moduleDeps = iGraph.graph.edges.map{ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }
- val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
+ val iGraph = new InstanceGraph(c).graph
+ val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
+ val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
for (m <- topoSortedModules) {
@@ -212,7 +211,7 @@ class CheckCombLoops extends Transform {
val sccSubgraph = moduleGraphs(m.name).subgraph(scc.toSet)
val cycle = findCycleInSCC(sccSubgraph)
(cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
- val expandedCycle = expandInstancePaths(m.name,moduleGraphs,moduleDeps,Seq(m.name),cycle.reverse)
+ val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
errors.append(new CombLoopException(m.info, m.name, expandedCycle))
}
}
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index 69502911..84b63e3d 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -398,9 +398,9 @@ class ConstantPropagation extends Transform {
private def run(c: Circuit, dontTouchMap: Map[String, Set[String]]): Circuit = {
val iGraph = (new InstanceGraph(c)).graph
- val moduleDeps = iGraph.edges.map { case (mod, children) =>
+ val moduleDeps = iGraph.getEdgeMap.map({ case (mod, children) =>
mod.module -> children.map(i => i.name -> i.module).toMap
- }
+ })
// Module name to number of instances
val instCount: Map[String, Int] = iGraph.getVertices.groupBy(_.module).mapValues(_.size)
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index 24c1c51c..578de264 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -95,11 +95,11 @@ class DeadCodeElimination extends Transform {
case DefRegister(_, name, _, clock, reset, init) =>
val node = LogicNode(mod.name, name)
depGraph.addVertex(node)
- Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(node, ref))
+ Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref))
case DefNode(_, name, value) =>
val node = LogicNode(mod.name, name)
depGraph.addVertex(node)
- getDeps(value).foreach(ref => depGraph.addEdge(node, ref))
+ getDeps(value).foreach(ref => depGraph.addPairWithEdge(node, ref))
case DefWire(_, name, _) =>
depGraph.addVertex(LogicNode(mod.name, name))
case mem: DefMemory =>
@@ -111,23 +111,23 @@ class DeadCodeElimination extends Transform {
val sinks = exprs.getOrElse(FEMALE, List.empty).flatMap(getDeps(_))
val memNode = getDeps(memRef) match { case Seq(node) => node }
depGraph.addVertex(memNode)
- sinks.foreach(sink => depGraph.addEdge(sink, memNode))
- sources.foreach(source => depGraph.addEdge(memNode, source))
+ sinks.foreach(sink => depGraph.addPairWithEdge(sink, memNode))
+ sources.foreach(source => depGraph.addPairWithEdge(memNode, source))
case Attach(_, exprs) => // Add edge between each expression
exprs.flatMap(getDeps(_)).toSet.subsets(2).map(_.toList).foreach {
case Seq(a, b) =>
- depGraph.addEdge(a, b)
- depGraph.addEdge(b, a)
+ depGraph.addPairWithEdge(a, b)
+ depGraph.addPairWithEdge(b, a)
}
case Connect(_, loc, expr) =>
// This match enforces the low Firrtl requirement of expanded connections
val node = getDeps(loc) match { case Seq(elt) => elt }
- getDeps(expr).foreach(ref => depGraph.addEdge(node, ref))
+ getDeps(expr).foreach(ref => depGraph.addPairWithEdge(node, ref))
// Simulation constructs are treated as top-level outputs
case Stop(_,_, clk, en) =>
- Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ Seq(clk, en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
case Print(_, _, args, clk, en) =>
- (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addEdge(circuitSink, ref))
+ (args :+ clk :+ en).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(circuitSink, ref))
case Block(stmts) => stmts.foreach(onStmt(_))
case ignore @ (_: IsInvalid | _: WDefInstance | EmptyStmt) => // do nothing
case other => throw new Exception(s"Unexpected Statement $other")
@@ -153,30 +153,30 @@ class DeadCodeElimination extends Transform {
val node = LogicNode(ext)
// Don't touch external modules *unless* they are specifically marked as doTouch
// Simply marking the extmodule itself is sufficient to prevent inputs from being removed
- if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, node)
+ if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, node)
ext.ports.foreach {
case Port(_, pname, _, AnalogType(_)) =>
- depGraph.addEdge(LogicNode(ext.name, pname), node)
- depGraph.addEdge(node, LogicNode(ext.name, pname))
+ depGraph.addPairWithEdge(LogicNode(ext.name, pname), node)
+ depGraph.addPairWithEdge(node, LogicNode(ext.name, pname))
case Port(_, pname, Output, _) =>
val portNode = LogicNode(ext.name, pname)
- depGraph.addEdge(portNode, node)
+ depGraph.addPairWithEdge(portNode, node)
// Also mark all outputs as circuit sinks (unless marked doTouch obviously)
- if (!doTouchExtMods.contains(ext.name)) depGraph.addEdge(circuitSink, portNode)
- case Port(_, pname, Input, _) => depGraph.addEdge(node, LogicNode(ext.name, pname))
+ if (!doTouchExtMods.contains(ext.name)) depGraph.addPairWithEdge(circuitSink, portNode)
+ case Port(_, pname, Input, _) => depGraph.addPairWithEdge(node, LogicNode(ext.name, pname))
}
}
// Connect circuitSink to ALL top-level ports (we don't want to change the top-level interface)
val topModule = c.modules.find(_.name == c.main).get
val topOutputs = topModule.ports.foreach { port =>
- depGraph.addEdge(circuitSink, LogicNode(c.main, port.name))
+ depGraph.addPairWithEdge(circuitSink, LogicNode(c.main, port.name))
}
depGraph
}
private def deleteDeadCode(instMap: collection.Map[String, String],
- deadNodes: Set[LogicNode],
+ deadNodes: collection.Set[LogicNode],
moduleMap: collection.Map[String, DefModule],
renames: RenameMap)
(mod: DefModule): Option[DefModule] = {
@@ -243,9 +243,9 @@ class DeadCodeElimination extends Transform {
val c = state.circuit
val moduleMap = c.modules.map(m => m.name -> m).toMap
val iGraph = new InstanceGraph(c)
- val moduleDeps = iGraph.graph.edges.map { case (k,v) =>
+ val moduleDeps = iGraph.graph.getEdgeMap.map({ case (k,v) =>
k.module -> v.map(i => i.name -> i.module).toMap
- }
+ })
val topoSortedModules = iGraph.graph.transformNodes(_.module).linearize.reverse.map(moduleMap(_))
val depGraph = {
@@ -255,7 +255,7 @@ class DeadCodeElimination extends Transform {
dontTouches.foreach { dontTouch =>
// Ensure that they are actually found
if (vertices.contains(dontTouch)) {
- dGraph.addEdge(circuitSink, dontTouch)
+ dGraph.addPairWithEdge(circuitSink, dontTouch)
} else {
val (root, tail) = Utils.splitRef(dontTouch.e1)
DontTouchAnnotation.errorNotFound(root.serialize, tail.serialize)