aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorDavid Biancolin2019-01-21 18:50:51 -0500
committerGitHub2019-01-21 18:50:51 -0500
commit10586d6a141859b843057ec9979011e26ad207f1 (patch)
treeff23c30013159cdd1879b1e5c3dd5baca5bf4867 /src/main
parent73ae6257fce586ac145b6ab348ce1b47634e7a46 (diff)
parentdf3a34f01d227ff9ad0e63a41ff10001ac01c01d (diff)
Merge branch 'master' into top-wiring-aggregates
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/Compiler.scala4
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala4
-rw-r--r--src/main/scala/firrtl/WIR.scala5
-rw-r--r--src/main/scala/firrtl/annotations/AnnotationUtils.scala4
-rw-r--r--src/main/scala/firrtl/ir/IR.scala2
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala6
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala35
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala157
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala24
-rw-r--r--src/main/scala/firrtl/transforms/GroupComponents.scala13
-rw-r--r--src/main/scala/firrtl/transforms/IdentityTransform.scala17
-rw-r--r--src/main/scala/firrtl/util/ClassUtils.scala19
-rw-r--r--src/main/scala/firrtl/util/TestOptions.scala13
13 files changed, 224 insertions, 79 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index 80ba42c4..87662800 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -373,6 +373,10 @@ trait Compiler extends LazyLogging {
*/
def transforms: Seq[Transform]
+ require(transforms.size >= 1,
+ s"Compiler transforms for '${this.getClass.getName}' must have at least ONE Transform! " +
+ "Use IdentityTransform if you need an identity/no-op transform.")
+
// Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms
def inputForm: CircuitForm = transforms.head.inputForm
def outputForm: CircuitForm = transforms.last.outputForm
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 350fd433..eab928c2 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -2,6 +2,8 @@
package firrtl
+import firrtl.transforms.IdentityTransform
+
sealed abstract class CoreTransform extends SeqTransform
/** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting
@@ -132,7 +134,7 @@ import firrtl.transforms.BlackBoxSourceHelper
*/
class NoneCompiler extends Compiler {
def emitter = new ChirrtlEmitter
- def transforms: Seq[Transform] = Seq.empty
+ def transforms: Seq[Transform] = Seq(new IdentityTransform(ChirrtlForm))
}
/** Emits input circuit
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index b96cd253..a5d2571d 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -40,6 +40,11 @@ object WRef {
def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UNKNOWNGENDER)
/** Creates a WRef from a Node */
def apply(node: DefNode): WRef = new WRef(node.name, node.value.tpe, NodeKind, MALE)
+ /** Creates a WRef from a Port */
+ def apply(port: Port): WRef = new WRef(port.name, port.tpe, PortKind, UNKNOWNGENDER)
+ /** Creates a WRef from a WDefInstance */
+ def apply(wi: WDefInstance): WRef = new WRef(wi.name, wi.tpe, InstanceKind, UNKNOWNGENDER)
+ /** Creates a WRef from an arbitrary string name */
def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = new WRef(n, t, k, UNKNOWNGENDER)
}
case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender) extends Expression {
diff --git a/src/main/scala/firrtl/annotations/AnnotationUtils.scala b/src/main/scala/firrtl/annotations/AnnotationUtils.scala
index ba9220f7..72765ab7 100644
--- a/src/main/scala/firrtl/annotations/AnnotationUtils.scala
+++ b/src/main/scala/firrtl/annotations/AnnotationUtils.scala
@@ -51,8 +51,8 @@ object AnnotationUtils {
case Some(_) =>
val i = s.indexWhere(c => "[].".contains(c))
s.slice(0, i) match {
- case "" => Seq(s(i).toString) ++ tokenize(s.drop(i + 1))
- case x => Seq(x, s(i).toString) ++ tokenize(s.drop(i + 1))
+ case "" => s(i).toString +: tokenize(s.drop(i + 1))
+ case x => x +: s(i).toString +: tokenize(s.drop(i + 1))
}
case None if s == "" => Nil
case None => Seq(s)
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index cdf8e194..19ee56ca 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -607,7 +607,7 @@ case object UnknownType extends Type {
}
/** [[Port]] Direction */
-abstract class Direction extends FirrtlNode
+sealed abstract class Direction extends FirrtlNode
case object Input extends Direction {
def serialize: String = "input"
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index cc69be6f..5a9a60f8 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -103,7 +103,11 @@ object RemoveCHIRRTL extends Transform {
rds map (_.name), wrs map (_.name), rws map (_.name))
Block(mem +: stmts)
case sx: CDefMPort =>
- types(sx.name) = types(sx.mem)
+ types.get(sx.mem) match {
+ case Some(mem) => types(sx.name) = mem
+ case None =>
+ throw new PassException(s"Undefined memory ${sx.mem} referenced by mport ${sx.name}")
+ }
val addrs = ArrayBuffer[String]()
val clks = ArrayBuffer[String]()
val ens = ArrayBuffer[String]()
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index 2b6fa55d..c13fa261 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -3,14 +3,16 @@
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
-import scala.annotation.tailrec
+import scala.annotation.tailrec
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import MemPortUtils.memType
+import scala.collection.mutable
+
/** Resolve name collisions that would occur in [[LowerTypes]]
*
* @note Must be run after [[InferTypes]] because [[ir.DefNode]]s need type
@@ -78,36 +80,43 @@ object Uniquify extends Transform {
t: BundleType,
namespace: collection.mutable.HashSet[String])
(implicit sinfo: Info, mname: String): BundleType = {
- def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): Type = t match {
+ def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): (Type, Seq[String]) = t match {
case tx: BundleType =>
// First add everything
- val newFields = tx.fields map { f =>
+ val newFieldsAndElts = tx.fields map { f =>
val newName = findValidPrefix(f.name, Seq(""), namespace)
namespace += newName
Field(newName, f.flip, f.tpe)
} map { f => f.tpe match {
- case _: GroundType => f
+ case _: GroundType => (f, Seq[String](f.name))
case _ =>
- val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet())
- val elts = enumerateNames(tpe)
+ val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet())
// Need leading _ for findValidPrefix, it doesn't add _ for checks
- val eltsNames = elts map (e => "_" + LowerTypes.loweredName(e))
+ val eltsNames: Seq[String] = eltsx map (e => "_" + e)
val prefix = findValidPrefix(f.name, eltsNames, namespace)
// We added f.name in previous map, delete if we change it
if (prefix != f.name) {
namespace -= f.name
namespace += prefix
}
- namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e)))
- Field(prefix, f.flip, tpe)
+ val newElts: Seq[String] = eltsx map (e => LowerTypes.loweredName(prefix +: Seq(e)))
+ namespace ++= newElts
+ (Field(prefix, f.flip, tpe), prefix +: newElts)
}
}
- BundleType(newFields)
+ val (newFields, elts) = newFieldsAndElts.unzip
+ (BundleType(newFields), elts.flatten)
case tx: VectorType =>
- VectorType(recUniquifyNames(tx.tpe, namespace), tx.size)
- case tx => tx
+ val (tpe, elts) = recUniquifyNames(tx.tpe, namespace)
+ val newElts = ((0 until tx.size) map (i => i.toString)) ++
+ ((0 until tx.size) flatMap { i =>
+ elts map (e => LowerTypes.loweredName(Seq(i.toString, e)))
+ })
+ (VectorType(tpe, tx.size), newElts)
+ case tx => (tx, Nil)
}
- recUniquifyNames(t, namespace) match {
+ val (tpe, _) = recUniquifyNames(t, namespace)
+ tpe match {
case tx: BundleType => tx
case tx => throwInternalError(s"uniquifyNames: shouldn't be here - $tx")
}
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index 7afce210..9016dca4 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -7,6 +7,8 @@ import scala.collection.immutable.HashSet
import scala.collection.immutable.HashMap
import annotation.tailrec
+import Function.tupled
+
import firrtl._
import firrtl.ir._
import firrtl.passes.{Errors, PassException}
@@ -26,6 +28,23 @@ object CheckCombLoops {
case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation
+case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarget) extends Annotation {
+ if (!source.isLocal || !sink.isLocal || source.module != sink.module) {
+ throwInternalError(s"ExtModulePathAnnotation must connect two local targets from the same module")
+ }
+
+ override def getTargets: Seq[ReferenceTarget] = Seq(source, sink)
+
+ override def update(renames: RenameMap): Seq[Annotation] = {
+ val sources = renames.get(source).getOrElse(Seq(source))
+ val sinks = renames.get(sink).getOrElse(Seq(sink))
+ val paths = sources flatMap { s => sinks.map((s, _)) }
+ paths.collect {
+ case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink)
+ }
+ }
+}
+
case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) extends Annotation {
override def update(renames: RenameMap): Seq[Annotation] = {
val newSources: Seq[IsComponent] = sources.flatMap { s => renames.get(s).getOrElse(Seq(s.toTarget)) }.collect {case x: IsComponent if x.isLocal => x}
@@ -36,12 +55,12 @@ case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) e
/** Finds and detects combinational logic loops in a circuit, if any
* exist. Returns the input circuit with no modifications.
- *
+ *
* @throws CombLoopException if a loop is found
* @note Input form: Low FIRRTL
* @note Output form: Low FIRRTL (identity transform)
* @note The pass looks for loops through combinational-read memories
- * @note The pass cannot find loops that pass through ExtModules
+ * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules
* @note The pass will throw exceptions on "false paths"
*/
class CheckCombLoops extends Transform with RegisteredTransform {
@@ -69,6 +88,8 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)
private def toLogicNode(e: Expression): LogicNode = e match {
+ case idx: WSubIndex =>
+ toLogicNode(idx.expr)
case r: WRef =>
LogicNode(r.name)
case s: WSubField =>
@@ -95,29 +116,29 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private def getStmtDeps(
simplifiedModules: mutable.Map[String,DiGraph[LogicNode]],
deps: MutableDiGraph[LogicNode])(s: Statement): Unit = s match {
- case Connect(_,loc,expr) =>
- val lhs = toLogicNode(loc)
- if (deps.contains(lhs)) {
- getExprDeps(deps, lhs)(expr)
- }
- case w: DefWire =>
- deps.addVertex(LogicNode(w.name))
- case n: DefNode =>
- val lhs = LogicNode(n.name)
- deps.addVertex(lhs)
- getExprDeps(deps, lhs)(n.value)
- case m: DefMemory if (m.readLatency == 0) =>
- for (rp <- m.readers) {
- val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
- deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp))))
- deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp))))
- }
- case i: WDefInstance =>
- val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
- iGraph.getVertices.foreach(deps.addVertex(_))
- iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
- case _ =>
- s.foreach(getStmtDeps(simplifiedModules,deps))
+ case Connect(_,loc,expr) =>
+ val lhs = toLogicNode(loc)
+ if (deps.contains(lhs)) {
+ getExprDeps(deps, lhs)(expr)
+ }
+ case w: DefWire =>
+ deps.addVertex(LogicNode(w.name))
+ case n: DefNode =>
+ val lhs = LogicNode(n.name)
+ deps.addVertex(lhs)
+ getExprDeps(deps, lhs)(n.value)
+ case m: DefMemory if (m.readLatency == 0) =>
+ for (rp <- m.readers) {
+ val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp))))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp))))
+ }
+ case i: WDefInstance =>
+ val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
+ iGraph.getVertices.foreach(deps.addVertex(_))
+ iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
+ case _ =>
+ s.foreach(getStmtDeps(simplifiedModules,deps))
}
/*
@@ -129,7 +150,7 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private def expandInstancePaths(
m: String,
moduleGraphs: mutable.Map[String,DiGraph[LogicNode]],
- moduleDeps: Map[String, Map[String,String]],
+ moduleDeps: Map[String, Map[String,String]],
prefix: Seq[String],
path: Seq[LogicNode]): Seq[String] = {
def absNodeName(prefix: Seq[String], n: LogicNode) =
@@ -173,51 +194,65 @@ class CheckCombLoops extends Transform with RegisteredTransform {
* module is converted to a netlist and analyzed locally, with its
* subinstances represented by trivial, simplified subgraphs. The
* overall outline of the process is:
- *
+ *
* 1. Create a graph of module instance dependances
* 2. Linearize this acyclic graph
- *
+ *
* 3. Generate a local netlist; replace any instances with
* simplified subgraphs representing connectivity of their IOs
- *
+ *
* 4. Check for nontrivial strongly connected components
- *
+ *
* 5. Create a reduced representation of the netlist with only the
* module IOs as nodes, where output X (which must be a ground type,
* as only low FIRRTL is supported) will have an edge to input Y if
* and only if it combinationally depends on input Y. Associate this
* reduced graph with the module for future use.
*/
- private def run(c: Circuit): (Circuit, Seq[Annotation]) = {
+ private def run(state: CircuitState) = {
+ val c = state.circuit
val errors = new Errors()
- /* TODO(magyar): deal with exmodules! No pass warnings currently
- * exist. Maybe warn when iterating through modules.
- */
+ val extModulePaths = state.annotations.groupBy {
+ case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module)
+ case ann: Annotation => CircuitTarget(c.main)
+ }
val moduleMap = c.modules.map({m => (m.name,m) }).toMap
val iGraph = new InstanceGraph(c).graph
val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
- for (m <- topoSortedModules) {
- val internalDeps = new MutableDiGraph[LogicNode]
- m.ports.foreach({ p => internalDeps.addVertex(LogicNode(p.name)) })
- m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps))
- val moduleGraph = DiGraph(internalDeps)
- moduleGraphs(m.name) = moduleGraph
- simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify((m.ports map { p => LogicNode(p.name) }).toSet)
- // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs!
- for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) {
- errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name)))
- }
- for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) {
- val sccSubgraph = moduleGraph.subgraph(scc.toSet)
- val cycle = findCycleInSCC(sccSubgraph)
- (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
- val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
- errors.append(new CombLoopException(m.info, m.name, expandedCycle))
- }
+ topoSortedModules.foreach {
+ case em: ExtModule =>
+ val portSet = em.ports.map(p => LogicNode(p.name)).toSet
+ val extModuleDeps = new MutableDiGraph[LogicNode]
+ portSet.foreach(extModuleDeps.addVertex(_))
+ extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect {
+ case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
+ }
+ moduleGraphs(em.name) = DiGraph(extModuleDeps).simplify(portSet)
+ simplifiedModuleGraphs(em.name) = moduleGraphs(em.name)
+ case m: Module =>
+ val portSet = m.ports.map(p => LogicNode(p.name)).toSet
+ val internalDeps = new MutableDiGraph[LogicNode]
+ portSet.foreach(internalDeps.addVertex(_))
+ m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps))
+ val moduleGraph = DiGraph(internalDeps)
+ moduleGraphs(m.name) = moduleGraph
+ simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify(portSet)
+ // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs!
+ for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) {
+ errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name)))
+ }
+ for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) {
+ val sccSubgraph = moduleGraph.subgraph(scc.toSet)
+ val cycle = findCycleInSCC(sccSubgraph)
+ (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
+ val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
+ errors.append(new CombLoopException(m.info, m.name, expandedCycle))
+ }
+ case m => throwInternalError(s"Module ${m.name} has unrecognized type")
}
val mn = ModuleName(c.main, CircuitName(c.main))
val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty =>
@@ -225,8 +260,17 @@ class CheckCombLoops extends Transform with RegisteredTransform {
val sources = tos.map(x => ComponentName(x.name, mn))
CombinationalPath(sink, sources.toSeq)
}
- errors.trigger()
- (c, annos.toSeq)
+ (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs)
+ }
+
+ /**
+ * Returns a Map from Module name to port connectivity
+ */
+ def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = {
+ val (result, errors, connectivity) = run(state)
+ connectivity.map {
+ case (k, v) => (k, v.transformNodes(ln => ln.name))
+ }
}
def execute(state: CircuitState): CircuitState = {
@@ -235,8 +279,9 @@ class CheckCombLoops extends Transform with RegisteredTransform {
logger.warn("Skipping Combinational Loop Detection")
state
} else {
- val (result, annos) = run(state.circuit)
- CircuitState(result, outputForm, state.annotations ++ annos, state.renames)
+ val (result, errors, connectivity) = run(state)
+ errors.trigger()
+ result
}
}
}
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index da7f1a46..8a273476 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
object FoldADD extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = (c1, c2) match {
+ def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
}
@@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shl), e.tpe)
+ case _ => e
+ }
+
private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
case x => e.args.head match {
@@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftRight(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shr), e.tpe)
+ case _ => e
+ }
+
+
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
def isUInt(e: Expression): Boolean = e.tpe match {
@@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
private def constPropPrim(e: DoPrim): Expression = e.op match {
case Shl => foldShiftLeft(e)
+ case Dshl => foldDynamicShiftLeft(e)
case Shr => foldShiftRight(e)
+ case Dshr => foldDynamicShiftRight(e)
case Cat => foldConcat(e)
case Add => FoldADD(e)
case And => FoldAND(e)
@@ -277,6 +294,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+
private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
val propagated = old match {
@@ -290,7 +308,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
}
- propagated
+ // We're done when the Expression no longer changes
+ if (propagated eq old) propagated
+ else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
}
/** Constant propagate a Module
diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala
index 55828e0a..8c36bb6d 100644
--- a/src/main/scala/firrtl/transforms/GroupComponents.scala
+++ b/src/main/scala/firrtl/transforms/GroupComponents.scala
@@ -4,7 +4,7 @@ import firrtl._
import firrtl.Mappers._
import firrtl.ir._
import firrtl.annotations.{Annotation, ComponentName}
-import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils}
+import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils, ResolveKinds}
import firrtl.Utils.kind
import firrtl.graph.{DiGraph, MutableDiGraph}
@@ -62,7 +62,7 @@ class GroupComponents extends firrtl.Transform {
case other => Seq(other)
}
val cs = state.copy(circuit = state.circuit.copy(modules = newModules))
- val csx = InferTypes.execute(cs)
+ val csx = ResolveKinds.execute(InferTypes.execute(cs))
csx
}
@@ -119,6 +119,11 @@ class GroupComponents extends firrtl.Transform {
}
}
+ // Unused nodes are not reachable from any group nor the root--add them to root group
+ for ((v, _) <- deps.getEdgeMap) {
+ reachableNodes.getOrElseUpdate(v, mutable.Set(""))
+ }
+
// Add nodes who are reached by a single group, to that group
reachableNodes.foreach { case (node, membership) =>
if(membership.size == 1) {
@@ -307,7 +312,9 @@ class GroupComponents extends firrtl.Transform {
}
def onStmt(stmt: Statement): Unit = stmt match {
case w: WDefInstance =>
- case h: IsDeclaration => h map onExpr(WRef(h.name))
+ case h: IsDeclaration =>
+ bidirGraph.addVertex(h.name)
+ h map onExpr(WRef(h.name))
case Attach(_, exprs) => // Add edge between each expression
exprs.tail map onExpr(getWRef(exprs.head))
case Connect(_, loc, expr) =>
diff --git a/src/main/scala/firrtl/transforms/IdentityTransform.scala b/src/main/scala/firrtl/transforms/IdentityTransform.scala
new file mode 100644
index 00000000..a39ca4b7
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/IdentityTransform.scala
@@ -0,0 +1,17 @@
+// See LICENSE for license details.
+
+package firrtl.transforms
+
+import firrtl.{CircuitForm, CircuitState, Transform}
+
+/** Transform that applies an identity function. This returns an unmodified [[CircuitState]].
+ * @param form the input and output [[CircuitForm]]
+ */
+class IdentityTransform(form: CircuitForm) extends Transform {
+
+ final override def inputForm: CircuitForm = form
+ final override def outputForm: CircuitForm = form
+
+ final def execute(state: CircuitState): CircuitState = state
+
+}
diff --git a/src/main/scala/firrtl/util/ClassUtils.scala b/src/main/scala/firrtl/util/ClassUtils.scala
new file mode 100644
index 00000000..1b388035
--- /dev/null
+++ b/src/main/scala/firrtl/util/ClassUtils.scala
@@ -0,0 +1,19 @@
+package firrtl.util
+
+object ClassUtils {
+ /** Determine if a named class is loaded.
+ *
+ * @param name - name of the class: "foo.bar" or "org.foo.bar"
+ * @return true if the class has been loaded (is accessible), false otherwise.
+ */
+ def isClassLoaded(name: String): Boolean = {
+ val found = try {
+ Class.forName(name, false, getClass.getClassLoader) != null
+ } catch {
+ case e: ClassNotFoundException => false
+ case x: Throwable => throw x
+ }
+// println(s"isClassLoaded: %s $name".format(if (found) "found" else "didn't find"))
+ found
+ }
+}
diff --git a/src/main/scala/firrtl/util/TestOptions.scala b/src/main/scala/firrtl/util/TestOptions.scala
new file mode 100644
index 00000000..9ee99f8c
--- /dev/null
+++ b/src/main/scala/firrtl/util/TestOptions.scala
@@ -0,0 +1,13 @@
+package firrtl.util
+
+import ClassUtils.isClassLoaded
+
+object TestOptions {
+ // Our timing is inaccurate if we're running tests under coverage.
+ // If any of the classes known to be associated with evaluating coverage are loaded,
+ // assume we're running tests under coverage.
+ // NOTE: We assume we need only ask the class loader that loaded us.
+ // If it was loaded by another class loader (outside of our hierarchy), it wouldn't be available to us.
+ val coverageClasses = List("scoverage.Platform", "com.intellij.rt.coverage.instrumentation.TouchCounter")
+ val accurateTiming = !coverageClasses.exists(isClassLoaded(_))
+}