aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDavid Biancolin2019-01-21 18:50:51 -0500
committerGitHub2019-01-21 18:50:51 -0500
commit10586d6a141859b843057ec9979011e26ad207f1 (patch)
treeff23c30013159cdd1879b1e5c3dd5baca5bf4867 /src
parent73ae6257fce586ac145b6ab348ce1b47634e7a46 (diff)
parentdf3a34f01d227ff9ad0e63a41ff10001ac01c01d (diff)
Merge branch 'master' into top-wiring-aggregates
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Compiler.scala4
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala4
-rw-r--r--src/main/scala/firrtl/WIR.scala5
-rw-r--r--src/main/scala/firrtl/annotations/AnnotationUtils.scala4
-rw-r--r--src/main/scala/firrtl/ir/IR.scala2
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala6
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala35
-rw-r--r--src/main/scala/firrtl/transforms/CheckCombLoops.scala157
-rw-r--r--src/main/scala/firrtl/transforms/ConstantPropagation.scala24
-rw-r--r--src/main/scala/firrtl/transforms/GroupComponents.scala13
-rw-r--r--src/main/scala/firrtl/transforms/IdentityTransform.scala17
-rw-r--r--src/main/scala/firrtl/util/ClassUtils.scala19
-rw-r--r--src/main/scala/firrtl/util/TestOptions.scala13
-rw-r--r--src/test/scala/firrtlTests/CheckCombLoopsSpec.scala86
-rw-r--r--src/test/scala/firrtlTests/ChirrtlMemSpec.scala17
-rw-r--r--src/test/scala/firrtlTests/ConstantPropagationTests.scala89
-rw-r--r--src/test/scala/firrtlTests/UniquifySpec.scala28
-rw-r--r--src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala73
18 files changed, 517 insertions, 79 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index 80ba42c4..87662800 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -373,6 +373,10 @@ trait Compiler extends LazyLogging {
*/
def transforms: Seq[Transform]
+ require(transforms.size >= 1,
+ s"Compiler transforms for '${this.getClass.getName}' must have at least ONE Transform! " +
+ "Use IdentityTransform if you need an identity/no-op transform.")
+
// Similar to (input|output)Form on [[Transform]] but derived from this Compiler's transforms
def inputForm: CircuitForm = transforms.head.inputForm
def outputForm: CircuitForm = transforms.last.outputForm
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 350fd433..eab928c2 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -2,6 +2,8 @@
package firrtl
+import firrtl.transforms.IdentityTransform
+
sealed abstract class CoreTransform extends SeqTransform
/** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting
@@ -132,7 +134,7 @@ import firrtl.transforms.BlackBoxSourceHelper
*/
class NoneCompiler extends Compiler {
def emitter = new ChirrtlEmitter
- def transforms: Seq[Transform] = Seq.empty
+ def transforms: Seq[Transform] = Seq(new IdentityTransform(ChirrtlForm))
}
/** Emits input circuit
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index b96cd253..a5d2571d 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -40,6 +40,11 @@ object WRef {
def apply(reg: DefRegister): WRef = new WRef(reg.name, reg.tpe, RegKind, UNKNOWNGENDER)
/** Creates a WRef from a Node */
def apply(node: DefNode): WRef = new WRef(node.name, node.value.tpe, NodeKind, MALE)
+ /** Creates a WRef from a Port */
+ def apply(port: Port): WRef = new WRef(port.name, port.tpe, PortKind, UNKNOWNGENDER)
+ /** Creates a WRef from a WDefInstance */
+ def apply(wi: WDefInstance): WRef = new WRef(wi.name, wi.tpe, InstanceKind, UNKNOWNGENDER)
+ /** Creates a WRef from an arbitrary string name */
def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind): WRef = new WRef(n, t, k, UNKNOWNGENDER)
}
case class WSubField(expr: Expression, name: String, tpe: Type, gender: Gender) extends Expression {
diff --git a/src/main/scala/firrtl/annotations/AnnotationUtils.scala b/src/main/scala/firrtl/annotations/AnnotationUtils.scala
index ba9220f7..72765ab7 100644
--- a/src/main/scala/firrtl/annotations/AnnotationUtils.scala
+++ b/src/main/scala/firrtl/annotations/AnnotationUtils.scala
@@ -51,8 +51,8 @@ object AnnotationUtils {
case Some(_) =>
val i = s.indexWhere(c => "[].".contains(c))
s.slice(0, i) match {
- case "" => Seq(s(i).toString) ++ tokenize(s.drop(i + 1))
- case x => Seq(x, s(i).toString) ++ tokenize(s.drop(i + 1))
+ case "" => s(i).toString +: tokenize(s.drop(i + 1))
+ case x => x +: s(i).toString +: tokenize(s.drop(i + 1))
}
case None if s == "" => Nil
case None => Seq(s)
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index cdf8e194..19ee56ca 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -607,7 +607,7 @@ case object UnknownType extends Type {
}
/** [[Port]] Direction */
-abstract class Direction extends FirrtlNode
+sealed abstract class Direction extends FirrtlNode
case object Input extends Direction {
def serialize: String = "input"
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index cc69be6f..5a9a60f8 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -103,7 +103,11 @@ object RemoveCHIRRTL extends Transform {
rds map (_.name), wrs map (_.name), rws map (_.name))
Block(mem +: stmts)
case sx: CDefMPort =>
- types(sx.name) = types(sx.mem)
+ types.get(sx.mem) match {
+ case Some(mem) => types(sx.name) = mem
+ case None =>
+ throw new PassException(s"Undefined memory ${sx.mem} referenced by mport ${sx.name}")
+ }
val addrs = ArrayBuffer[String]()
val clks = ArrayBuffer[String]()
val ens = ArrayBuffer[String]()
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index 2b6fa55d..c13fa261 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -3,14 +3,16 @@
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
-import scala.annotation.tailrec
+import scala.annotation.tailrec
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
import MemPortUtils.memType
+import scala.collection.mutable
+
/** Resolve name collisions that would occur in [[LowerTypes]]
*
* @note Must be run after [[InferTypes]] because [[ir.DefNode]]s need type
@@ -78,36 +80,43 @@ object Uniquify extends Transform {
t: BundleType,
namespace: collection.mutable.HashSet[String])
(implicit sinfo: Info, mname: String): BundleType = {
- def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): Type = t match {
+ def recUniquifyNames(t: Type, namespace: collection.mutable.HashSet[String]): (Type, Seq[String]) = t match {
case tx: BundleType =>
// First add everything
- val newFields = tx.fields map { f =>
+ val newFieldsAndElts = tx.fields map { f =>
val newName = findValidPrefix(f.name, Seq(""), namespace)
namespace += newName
Field(newName, f.flip, f.tpe)
} map { f => f.tpe match {
- case _: GroundType => f
+ case _: GroundType => (f, Seq[String](f.name))
case _ =>
- val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet())
- val elts = enumerateNames(tpe)
+ val (tpe, eltsx) = recUniquifyNames(f.tpe, collection.mutable.HashSet())
// Need leading _ for findValidPrefix, it doesn't add _ for checks
- val eltsNames = elts map (e => "_" + LowerTypes.loweredName(e))
+ val eltsNames: Seq[String] = eltsx map (e => "_" + e)
val prefix = findValidPrefix(f.name, eltsNames, namespace)
// We added f.name in previous map, delete if we change it
if (prefix != f.name) {
namespace -= f.name
namespace += prefix
}
- namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e)))
- Field(prefix, f.flip, tpe)
+ val newElts: Seq[String] = eltsx map (e => LowerTypes.loweredName(prefix +: Seq(e)))
+ namespace ++= newElts
+ (Field(prefix, f.flip, tpe), prefix +: newElts)
}
}
- BundleType(newFields)
+ val (newFields, elts) = newFieldsAndElts.unzip
+ (BundleType(newFields), elts.flatten)
case tx: VectorType =>
- VectorType(recUniquifyNames(tx.tpe, namespace), tx.size)
- case tx => tx
+ val (tpe, elts) = recUniquifyNames(tx.tpe, namespace)
+ val newElts = ((0 until tx.size) map (i => i.toString)) ++
+ ((0 until tx.size) flatMap { i =>
+ elts map (e => LowerTypes.loweredName(Seq(i.toString, e)))
+ })
+ (VectorType(tpe, tx.size), newElts)
+ case tx => (tx, Nil)
}
- recUniquifyNames(t, namespace) match {
+ val (tpe, _) = recUniquifyNames(t, namespace)
+ tpe match {
case tx: BundleType => tx
case tx => throwInternalError(s"uniquifyNames: shouldn't be here - $tx")
}
diff --git a/src/main/scala/firrtl/transforms/CheckCombLoops.scala b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
index 7afce210..9016dca4 100644
--- a/src/main/scala/firrtl/transforms/CheckCombLoops.scala
+++ b/src/main/scala/firrtl/transforms/CheckCombLoops.scala
@@ -7,6 +7,8 @@ import scala.collection.immutable.HashSet
import scala.collection.immutable.HashMap
import annotation.tailrec
+import Function.tupled
+
import firrtl._
import firrtl.ir._
import firrtl.passes.{Errors, PassException}
@@ -26,6 +28,23 @@ object CheckCombLoops {
case object DontCheckCombLoopsAnnotation extends NoTargetAnnotation
+case class ExtModulePathAnnotation(source: ReferenceTarget, sink: ReferenceTarget) extends Annotation {
+ if (!source.isLocal || !sink.isLocal || source.module != sink.module) {
+ throwInternalError(s"ExtModulePathAnnotation must connect two local targets from the same module")
+ }
+
+ override def getTargets: Seq[ReferenceTarget] = Seq(source, sink)
+
+ override def update(renames: RenameMap): Seq[Annotation] = {
+ val sources = renames.get(source).getOrElse(Seq(source))
+ val sinks = renames.get(sink).getOrElse(Seq(sink))
+ val paths = sources flatMap { s => sinks.map((s, _)) }
+ paths.collect {
+ case (source: ReferenceTarget, sink: ReferenceTarget) => ExtModulePathAnnotation(source, sink)
+ }
+ }
+}
+
case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) extends Annotation {
override def update(renames: RenameMap): Seq[Annotation] = {
val newSources: Seq[IsComponent] = sources.flatMap { s => renames.get(s).getOrElse(Seq(s.toTarget)) }.collect {case x: IsComponent if x.isLocal => x}
@@ -36,12 +55,12 @@ case class CombinationalPath(sink: ComponentName, sources: Seq[ComponentName]) e
/** Finds and detects combinational logic loops in a circuit, if any
* exist. Returns the input circuit with no modifications.
- *
+ *
* @throws CombLoopException if a loop is found
* @note Input form: Low FIRRTL
* @note Output form: Low FIRRTL (identity transform)
* @note The pass looks for loops through combinational-read memories
- * @note The pass cannot find loops that pass through ExtModules
+ * @note The pass relies on ExtModulePathAnnotations to find loops through ExtModules
* @note The pass will throw exceptions on "false paths"
*/
class CheckCombLoops extends Transform with RegisteredTransform {
@@ -69,6 +88,8 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private case class LogicNode(name: String, inst: Option[String] = None, memport: Option[String] = None)
private def toLogicNode(e: Expression): LogicNode = e match {
+ case idx: WSubIndex =>
+ toLogicNode(idx.expr)
case r: WRef =>
LogicNode(r.name)
case s: WSubField =>
@@ -95,29 +116,29 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private def getStmtDeps(
simplifiedModules: mutable.Map[String,DiGraph[LogicNode]],
deps: MutableDiGraph[LogicNode])(s: Statement): Unit = s match {
- case Connect(_,loc,expr) =>
- val lhs = toLogicNode(loc)
- if (deps.contains(lhs)) {
- getExprDeps(deps, lhs)(expr)
- }
- case w: DefWire =>
- deps.addVertex(LogicNode(w.name))
- case n: DefNode =>
- val lhs = LogicNode(n.name)
- deps.addVertex(lhs)
- getExprDeps(deps, lhs)(n.value)
- case m: DefMemory if (m.readLatency == 0) =>
- for (rp <- m.readers) {
- val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
- deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp))))
- deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp))))
- }
- case i: WDefInstance =>
- val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
- iGraph.getVertices.foreach(deps.addVertex(_))
- iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
- case _ =>
- s.foreach(getStmtDeps(simplifiedModules,deps))
+ case Connect(_,loc,expr) =>
+ val lhs = toLogicNode(loc)
+ if (deps.contains(lhs)) {
+ getExprDeps(deps, lhs)(expr)
+ }
+ case w: DefWire =>
+ deps.addVertex(LogicNode(w.name))
+ case n: DefNode =>
+ val lhs = LogicNode(n.name)
+ deps.addVertex(lhs)
+ getExprDeps(deps, lhs)(n.value)
+ case m: DefMemory if (m.readLatency == 0) =>
+ for (rp <- m.readers) {
+ val dataNode = deps.addVertex(LogicNode("data",Some(m.name),Some(rp)))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("addr",Some(m.name),Some(rp))))
+ deps.addEdge(dataNode, deps.addVertex(LogicNode("en",Some(m.name),Some(rp))))
+ }
+ case i: WDefInstance =>
+ val iGraph = simplifiedModules(i.module).transformNodes(n => n.copy(inst = Some(i.name)))
+ iGraph.getVertices.foreach(deps.addVertex(_))
+ iGraph.getVertices.foreach({ v => iGraph.getEdges(v).foreach { deps.addEdge(v,_) } })
+ case _ =>
+ s.foreach(getStmtDeps(simplifiedModules,deps))
}
/*
@@ -129,7 +150,7 @@ class CheckCombLoops extends Transform with RegisteredTransform {
private def expandInstancePaths(
m: String,
moduleGraphs: mutable.Map[String,DiGraph[LogicNode]],
- moduleDeps: Map[String, Map[String,String]],
+ moduleDeps: Map[String, Map[String,String]],
prefix: Seq[String],
path: Seq[LogicNode]): Seq[String] = {
def absNodeName(prefix: Seq[String], n: LogicNode) =
@@ -173,51 +194,65 @@ class CheckCombLoops extends Transform with RegisteredTransform {
* module is converted to a netlist and analyzed locally, with its
* subinstances represented by trivial, simplified subgraphs. The
* overall outline of the process is:
- *
+ *
* 1. Create a graph of module instance dependances
* 2. Linearize this acyclic graph
- *
+ *
* 3. Generate a local netlist; replace any instances with
* simplified subgraphs representing connectivity of their IOs
- *
+ *
* 4. Check for nontrivial strongly connected components
- *
+ *
* 5. Create a reduced representation of the netlist with only the
* module IOs as nodes, where output X (which must be a ground type,
* as only low FIRRTL is supported) will have an edge to input Y if
* and only if it combinationally depends on input Y. Associate this
* reduced graph with the module for future use.
*/
- private def run(c: Circuit): (Circuit, Seq[Annotation]) = {
+ private def run(state: CircuitState) = {
+ val c = state.circuit
val errors = new Errors()
- /* TODO(magyar): deal with exmodules! No pass warnings currently
- * exist. Maybe warn when iterating through modules.
- */
+ val extModulePaths = state.annotations.groupBy {
+ case ann: ExtModulePathAnnotation => ModuleTarget(c.main, ann.source.module)
+ case ann: Annotation => CircuitTarget(c.main)
+ }
val moduleMap = c.modules.map({m => (m.name,m) }).toMap
val iGraph = new InstanceGraph(c).graph
val moduleDeps = iGraph.getEdgeMap.map({ case (k,v) => (k.module, (v map { i => (i.name, i.module) }).toMap) }).toMap
val topoSortedModules = iGraph.transformNodes(_.module).linearize.reverse map { moduleMap(_) }
val moduleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
val simplifiedModuleGraphs = new mutable.HashMap[String,DiGraph[LogicNode]]
- for (m <- topoSortedModules) {
- val internalDeps = new MutableDiGraph[LogicNode]
- m.ports.foreach({ p => internalDeps.addVertex(LogicNode(p.name)) })
- m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps))
- val moduleGraph = DiGraph(internalDeps)
- moduleGraphs(m.name) = moduleGraph
- simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify((m.ports map { p => LogicNode(p.name) }).toSet)
- // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs!
- for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) {
- errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name)))
- }
- for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) {
- val sccSubgraph = moduleGraph.subgraph(scc.toSet)
- val cycle = findCycleInSCC(sccSubgraph)
- (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
- val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
- errors.append(new CombLoopException(m.info, m.name, expandedCycle))
- }
+ topoSortedModules.foreach {
+ case em: ExtModule =>
+ val portSet = em.ports.map(p => LogicNode(p.name)).toSet
+ val extModuleDeps = new MutableDiGraph[LogicNode]
+ portSet.foreach(extModuleDeps.addVertex(_))
+ extModulePaths.getOrElse(ModuleTarget(c.main, em.name), Nil).collect {
+ case a: ExtModulePathAnnotation => extModuleDeps.addPairWithEdge(LogicNode(a.sink.ref), LogicNode(a.source.ref))
+ }
+ moduleGraphs(em.name) = DiGraph(extModuleDeps).simplify(portSet)
+ simplifiedModuleGraphs(em.name) = moduleGraphs(em.name)
+ case m: Module =>
+ val portSet = m.ports.map(p => LogicNode(p.name)).toSet
+ val internalDeps = new MutableDiGraph[LogicNode]
+ portSet.foreach(internalDeps.addVertex(_))
+ m.foreach(getStmtDeps(simplifiedModuleGraphs, internalDeps))
+ val moduleGraph = DiGraph(internalDeps)
+ moduleGraphs(m.name) = moduleGraph
+ simplifiedModuleGraphs(m.name) = moduleGraphs(m.name).simplify(portSet)
+ // Find combinational nodes with self-edges; this is *NOT* the same as length-1 SCCs!
+ for (unitLoopNode <- moduleGraph.getVertices.filter(v => moduleGraph.getEdges(v).contains(v))) {
+ errors.append(new CombLoopException(m.info, m.name, Seq(unitLoopNode.name)))
+ }
+ for (scc <- moduleGraph.findSCCs.filter(_.length > 1)) {
+ val sccSubgraph = moduleGraph.subgraph(scc.toSet)
+ val cycle = findCycleInSCC(sccSubgraph)
+ (cycle zip cycle.tail).foreach({ case (a,b) => require(moduleGraph.getEdges(a).contains(b)) })
+ val expandedCycle = expandInstancePaths(m.name, moduleGraphs, moduleDeps, Seq(m.name), cycle.reverse)
+ errors.append(new CombLoopException(m.info, m.name, expandedCycle))
+ }
+ case m => throwInternalError(s"Module ${m.name} has unrecognized type")
}
val mn = ModuleName(c.main, CircuitName(c.main))
val annos = simplifiedModuleGraphs(c.main).getEdgeMap.collect { case (from, tos) if tos.nonEmpty =>
@@ -225,8 +260,17 @@ class CheckCombLoops extends Transform with RegisteredTransform {
val sources = tos.map(x => ComponentName(x.name, mn))
CombinationalPath(sink, sources.toSeq)
}
- errors.trigger()
- (c, annos.toSeq)
+ (state.copy(annotations = state.annotations ++ annos), errors, simplifiedModuleGraphs)
+ }
+
+ /**
+ * Returns a Map from Module name to port connectivity
+ */
+ def analyze(state: CircuitState): collection.Map[String,DiGraph[String]] = {
+ val (result, errors, connectivity) = run(state)
+ connectivity.map {
+ case (k, v) => (k, v.transformNodes(ln => ln.name))
+ }
}
def execute(state: CircuitState): CircuitState = {
@@ -235,8 +279,9 @@ class CheckCombLoops extends Transform with RegisteredTransform {
logger.warn("Skipping Combinational Loop Detection")
state
} else {
- val (result, annos) = run(state.circuit)
- CircuitState(result, outputForm, state.annotations ++ annos, state.renames)
+ val (result, errors, connectivity) = run(state)
+ errors.trigger()
+ result
}
}
}
diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
index da7f1a46..8a273476 100644
--- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala
+++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala
@@ -67,7 +67,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
object FoldADD extends FoldCommutativeOp {
- def fold(c1: Literal, c2: Literal) = (c1, c2) match {
+ def fold(c1: Literal, c2: Literal) = ((c1, c2): @unchecked) match {
case (_: UIntLiteral, _: UIntLiteral) => UIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
case (_: SIntLiteral, _: SIntLiteral) => SIntLiteral(c1.value + c2.value, (c1.width max c2.width) + IntWidth(1))
}
@@ -137,6 +137,13 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftLeft(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shl = DoPrim(Shl, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shl), e.tpe)
+ case _ => e
+ }
+
private def foldShiftRight(e: DoPrim) = e.consts.head.toInt match {
case 0 => e.args.head
case x => e.args.head match {
@@ -148,6 +155,14 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
}
}
+ private def foldDynamicShiftRight(e: DoPrim) = e.args.last match {
+ case UIntLiteral(v, IntWidth(w)) =>
+ val shr = DoPrim(Shr, Seq(e.args.head), Seq(v), UnknownType)
+ pad(PrimOps.set_primop_type(shr), e.tpe)
+ case _ => e
+ }
+
+
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
def isUInt(e: Expression): Boolean = e.tpe match {
@@ -221,7 +236,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
private def constPropPrim(e: DoPrim): Expression = e.op match {
case Shl => foldShiftLeft(e)
+ case Dshl => foldDynamicShiftLeft(e)
case Shr => foldShiftRight(e)
+ case Dshr => foldDynamicShiftRight(e)
case Cat => foldConcat(e)
case Add => FoldADD(e)
case And => FoldAND(e)
@@ -277,6 +294,7 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e)
+
private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = {
val old = e map constPropExpression(nodeMap, instMap, constSubOutputs)
val propagated = old match {
@@ -290,7 +308,9 @@ class ConstantPropagation extends Transform with ResolvedAnnotationPaths {
constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref)
case x => x
}
- propagated
+ // We're done when the Expression no longer changes
+ if (propagated eq old) propagated
+ else constPropExpression(nodeMap, instMap, constSubOutputs)(propagated)
}
/** Constant propagate a Module
diff --git a/src/main/scala/firrtl/transforms/GroupComponents.scala b/src/main/scala/firrtl/transforms/GroupComponents.scala
index 55828e0a..8c36bb6d 100644
--- a/src/main/scala/firrtl/transforms/GroupComponents.scala
+++ b/src/main/scala/firrtl/transforms/GroupComponents.scala
@@ -4,7 +4,7 @@ import firrtl._
import firrtl.Mappers._
import firrtl.ir._
import firrtl.annotations.{Annotation, ComponentName}
-import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils}
+import firrtl.passes.{InferTypes, LowerTypes, MemPortUtils, ResolveKinds}
import firrtl.Utils.kind
import firrtl.graph.{DiGraph, MutableDiGraph}
@@ -62,7 +62,7 @@ class GroupComponents extends firrtl.Transform {
case other => Seq(other)
}
val cs = state.copy(circuit = state.circuit.copy(modules = newModules))
- val csx = InferTypes.execute(cs)
+ val csx = ResolveKinds.execute(InferTypes.execute(cs))
csx
}
@@ -119,6 +119,11 @@ class GroupComponents extends firrtl.Transform {
}
}
+ // Unused nodes are not reachable from any group nor the root--add them to root group
+ for ((v, _) <- deps.getEdgeMap) {
+ reachableNodes.getOrElseUpdate(v, mutable.Set(""))
+ }
+
// Add nodes who are reached by a single group, to that group
reachableNodes.foreach { case (node, membership) =>
if(membership.size == 1) {
@@ -307,7 +312,9 @@ class GroupComponents extends firrtl.Transform {
}
def onStmt(stmt: Statement): Unit = stmt match {
case w: WDefInstance =>
- case h: IsDeclaration => h map onExpr(WRef(h.name))
+ case h: IsDeclaration =>
+ bidirGraph.addVertex(h.name)
+ h map onExpr(WRef(h.name))
case Attach(_, exprs) => // Add edge between each expression
exprs.tail map onExpr(getWRef(exprs.head))
case Connect(_, loc, expr) =>
diff --git a/src/main/scala/firrtl/transforms/IdentityTransform.scala b/src/main/scala/firrtl/transforms/IdentityTransform.scala
new file mode 100644
index 00000000..a39ca4b7
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/IdentityTransform.scala
@@ -0,0 +1,17 @@
+// See LICENSE for license details.
+
+package firrtl.transforms
+
+import firrtl.{CircuitForm, CircuitState, Transform}
+
+/** Transform that applies an identity function. This returns an unmodified [[CircuitState]].
+ * @param form the input and output [[CircuitForm]]
+ */
+class IdentityTransform(form: CircuitForm) extends Transform {
+
+ final override def inputForm: CircuitForm = form
+ final override def outputForm: CircuitForm = form
+
+ final def execute(state: CircuitState): CircuitState = state
+
+}
diff --git a/src/main/scala/firrtl/util/ClassUtils.scala b/src/main/scala/firrtl/util/ClassUtils.scala
new file mode 100644
index 00000000..1b388035
--- /dev/null
+++ b/src/main/scala/firrtl/util/ClassUtils.scala
@@ -0,0 +1,19 @@
+package firrtl.util
+
+object ClassUtils {
+ /** Determine if a named class is loaded.
+ *
+ * @param name - name of the class: "foo.bar" or "org.foo.bar"
+ * @return true if the class has been loaded (is accessible), false otherwise.
+ */
+ def isClassLoaded(name: String): Boolean = {
+ val found = try {
+ Class.forName(name, false, getClass.getClassLoader) != null
+ } catch {
+ case e: ClassNotFoundException => false
+ case x: Throwable => throw x
+ }
+// println(s"isClassLoaded: %s $name".format(if (found) "found" else "didn't find"))
+ found
+ }
+}
diff --git a/src/main/scala/firrtl/util/TestOptions.scala b/src/main/scala/firrtl/util/TestOptions.scala
new file mode 100644
index 00000000..9ee99f8c
--- /dev/null
+++ b/src/main/scala/firrtl/util/TestOptions.scala
@@ -0,0 +1,13 @@
+package firrtl.util
+
+import ClassUtils.isClassLoaded
+
+object TestOptions {
+ // Our timing is inaccurate if we're running tests under coverage.
+ // If any of the classes known to be associated with evaluating coverage are loaded,
+ // assume we're running tests under coverage.
+ // NOTE: We assume we need only ask the class loader that loaded us.
+ // If it was loaded by another class loader (outside of our hierarchy), it wouldn't be available to us.
+ val coverageClasses = List("scoverage.Platform", "com.intellij.rt.coverage.instrumentation.TouchCounter")
+ val accurateTiming = !coverageClasses.exists(isClassLoaded(_))
+}
diff --git a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
index 8fc7dda9..98472f14 100644
--- a/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
+++ b/src/test/scala/firrtlTests/CheckCombLoopsSpec.scala
@@ -165,6 +165,92 @@ class CheckCombLoopsSpec extends SimpleTransformSpec {
}
}
+ "Combinational loop through an annotated ExtModule" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | extmodule blackbox :
+ | input in : UInt<1>
+ | output out : UInt<1>
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | wire z : UInt<1>
+ | c <= b
+ | inst inner of blackbox
+ | inner.in <= y
+ | z <= inner.out
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val mt = ModuleTarget("hasloops", "blackbox")
+ val annos = AnnotationSeq(Seq(ExtModulePathAnnotation(mt.ref("in"), mt.ref("out"))))
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm, annos), writer)
+ }
+ }
+
+ "Loop-free circuit with ExtModulePathAnnotations" should "not throw an exception" in {
+ val input = """circuit hasnoloops :
+ | extmodule blackbox :
+ | input in1 : UInt<1>
+ | input in2 : UInt<1>
+ | output out1 : UInt<1>
+ | output out2 : UInt<1>
+ | module hasnoloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | output b : UInt<1>
+ | wire x : UInt<1>
+ | inst inner of blackbox
+ | inner.in1 <= a
+ | x <= inner.out1
+ | inner.in2 <= x
+ | b <= inner.out2
+ |""".stripMargin
+
+ val mt = ModuleTarget("hasnoloops", "blackbox")
+ val annos = AnnotationSeq(Seq(
+ ExtModulePathAnnotation(mt.ref("in1"), mt.ref("out1")),
+ ExtModulePathAnnotation(mt.ref("in2"), mt.ref("out2"))))
+ val writer = new java.io.StringWriter
+ compile(CircuitState(parse(input), ChirrtlForm, annos), writer)
+ }
+
+ "Combinational loop through an output RHS reference" should "throw an exception" in {
+ val input = """circuit hasloops :
+ | module thru :
+ | input in : UInt<1>
+ | output tmp : UInt<1>
+ | output out : UInt<1>
+ | tmp <= in
+ | out <= tmp
+ | module hasloops :
+ | input clk : Clock
+ | input a : UInt<1>
+ | input b : UInt<1>
+ | output c : UInt<1>
+ | output d : UInt<1>
+ | wire y : UInt<1>
+ | wire z : UInt<1>
+ | c <= b
+ | inst inner of thru
+ | inner.in <= y
+ | z <= inner.out
+ | y <= z
+ | d <= z
+ |""".stripMargin
+
+ val writer = new java.io.StringWriter
+ intercept[CheckCombLoops.CombLoopException] {
+ compile(CircuitState(parse(input), ChirrtlForm), writer)
+ }
+ }
+
"Multiple simple loops in one SCC" should "throw an exception" in {
val input = """circuit hasloops :
| module hasloops :
diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
index 74d39286..a4473fe7 100644
--- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala
@@ -108,6 +108,23 @@ circuit foo :
parse(res.getEmittedCircuit.value)
}
+ "An mport that refers to an undefined memory" should "have a helpful error message" in {
+ val input =
+ """circuit testTestModule :
+ | module testTestModule :
+ | input clock : Clock
+ | input reset : UInt<1>
+ | output io : {flip in : UInt<10>, out : UInt<10>}
+ |
+ | node _T_10 = bits(io.in, 1, 0)
+ | read mport _T_11 = m[_T_10], clock
+ | io.out <= _T_11""".stripMargin
+
+ intercept[PassException]{
+ (new LowFirrtlCompiler).compile(CircuitState(parse(input), ChirrtlForm), Seq()).circuit
+ }.getMessage should startWith ("Undefined memory m referenced by mport _T_11")
+ }
+
ignore should "Memories should not have validif on port clocks when declared in a when" in {
val input =
""";buildInfoPackage: chisel3, version: 3.0-SNAPSHOT, scalaVersion: 2.11.11, sbtVersion: 0.13.16, builtAtString: 2017-10-06 20:55:20.367, builtAtMillis: 1507323320367
diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
index 603ddc25..8a69fcaa 100644
--- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala
+++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala
@@ -734,6 +734,24 @@ class ConstantPropagationSingleModule extends ConstantPropagationSpec {
""".stripMargin
(parse(exec(input))) should be(parse(check))
}
+
+ // Optimizing this mux gives: z <= pad(UInt<2>(0), 4)
+ // Thus this checks that we then optimize that pad
+ "ConstProp" should "optimize nested Expressions" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<4>
+ | z <= mux(UInt(1), UInt<2>(0), UInt<4>(0))
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<4>
+ | z <= UInt<4>("h0")
+ """.stripMargin
+ (parse(exec(input))) should be(parse(check))
+ }
}
// More sophisticated tests of the full compiler
@@ -1104,6 +1122,77 @@ class ConstantPropagationIntegrationSpec extends LowTransformSpec {
| z <= _T_61""".stripMargin
execute(input, check, Seq.empty)
}
+
+ behavior of "ConstProp"
+
+ it should "optimize shl of constants" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<7>
+ | z <= shl(UInt(5), 4)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<7>
+ | z <= UInt<7>("h50")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ it should "optimize shr of constants" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<1>
+ | z <= shr(UInt(5), 2)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<1>
+ | z <= UInt<1>("h1")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ // Due to #866, we need dshl optimized away or it'll become a dshlw and error in parsing
+ // Include cat to verify width is correct
+ it should "optimize dshl of constant" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | node n = dshl(UInt<1>(0), UInt<2>(0))
+ | z <= cat(UInt<4>("hf"), n)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | z <= UInt<8>("hf0")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
+
+ // Include cat and constants to verify width is correct
+ it should "optimize dshr of constant" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | node n = dshr(UInt<4>(0), UInt<2>(2))
+ | z <= cat(UInt<4>("hf"), n)
+ """.stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | output z : UInt<8>
+ | z <= UInt<8>("hf0")
+ """.stripMargin
+ execute(input, check, Seq.empty)
+ }
}
diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala
index 43d1e733..561f0a84 100644
--- a/src/test/scala/firrtlTests/UniquifySpec.scala
+++ b/src/test/scala/firrtlTests/UniquifySpec.scala
@@ -12,6 +12,7 @@ import firrtl._
import firrtl.annotations._
import firrtl.annotations.TargetToken._
import firrtl.transforms.DontTouchAnnotation
+import firrtl.util.TestOptions
class UniquifySpec extends FirrtlFlatSpec {
@@ -283,4 +284,31 @@ class UniquifySpec extends FirrtlFlatSpec {
executeTest(input, expected)
}
+
+ it should "quickly rename deep bundles" in {
+ // We use a fixed time to determine if this test passed or failed.
+ // This test would pass under normal conditions, but would fail during coverage tests.
+ // Since executions times vary significantly under coverage testing, we check a global
+ // to see if timing measurements are accurate enough to enforce the timing checks.
+ val maxMs = 8000.0
+
+ def mkType(i: Int): String = {
+ if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}"
+ }
+
+ val depth = 500
+
+ val input =
+ s"""circuit Test:
+ | module Test :
+ | input in: ${mkType(depth)}
+ | output out: ${mkType(depth)}
+ | out <= in
+ |""".stripMargin
+
+ val (renameMs, _) = Utils.time(compileToVerilog(input))
+
+ if (TestOptions.accurateTiming)
+ renameMs shouldBe < (maxMs)
+ }
}
diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
index 3a32ec71..c54e02e3 100644
--- a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
+++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
@@ -3,6 +3,10 @@ package transforms
import firrtl.annotations.{CircuitName, ComponentName, ModuleName}
import firrtl.transforms.{GroupAnnotation, GroupComponents}
+import firrtl._
+import firrtl.ir._
+
+import FirrtlCheckers._
class GroupComponentsSpec extends LowTransformSpec {
def transform = new GroupComponents()
@@ -42,6 +46,43 @@ class GroupComponentsSpec extends LowTransformSpec {
""".stripMargin
execute(input, check, groups)
}
+ "Grouping" should "work even when there are unused nodes" in {
+ val input =
+ s"""circuit $top :
+ | module $top :
+ | input in: UInt<16>
+ | output out: UInt<16>
+ | node n = UInt<16>("h0")
+ | wire w : UInt<16>
+ | wire a : UInt<16>
+ | wire b : UInt<16>
+ | a <= UInt<16>("h0")
+ | b <= a
+ | w <= in
+ | out <= w
+ """.stripMargin
+ val groups = Seq(
+ GroupAnnotation(Seq(topComp("w")), "Child", "inst", Some("_OUT"), Some("_IN"))
+ )
+ val check =
+ s"""circuit Top :
+ | module $top :
+ | input in: UInt<16>
+ | output out: UInt<16>
+ | inst inst of Child
+ | node n = UInt<16>("h0")
+ | inst.in_IN <= in
+ | node a = UInt<16>("h0")
+ | node b = a
+ | out <= inst.w_OUT
+ | module Child :
+ | input in_IN : UInt<16>
+ | output w_OUT : UInt<16>
+ | node w = in_IN
+ | w_OUT <= w
+ """.stripMargin
+ execute(input, check, groups)
+ }
"The two sets of instances" should "be grouped" in {
val input =
@@ -288,3 +329,35 @@ class GroupComponentsSpec extends LowTransformSpec {
execute(input, check, groups)
}
}
+
+class GroupComponentsIntegrationSpec extends FirrtlFlatSpec {
+ def topComp(name: String): ComponentName = ComponentName(name, ModuleName("Top", CircuitName("Top")))
+ "Grouping" should "properly set kinds" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input data: UInt<16>
+ | output out: UInt<16>
+ | reg r: UInt<16>, clk
+ | r <= data
+ | out <= r
+ """.stripMargin
+ val groups = Seq(
+ GroupAnnotation(Seq(topComp("r")), "MyModule", "inst", Some("_OUT"), Some("_IN"))
+ )
+ val result = (new VerilogCompiler).compileAndEmit(
+ CircuitState(parse(input), ChirrtlForm, groups),
+ Seq(new GroupComponents)
+ )
+ result should containTree {
+ case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "data_IN", _,_), WRef("data",_,_,_)) => true
+ }
+ result should containTree {
+ case Connect(_, WSubField(WRef("inst",_, InstanceKind,_), "clk_IN", _,_), WRef("clk",_,_,_)) => true
+ }
+ result should containTree {
+ case Connect(_, WRef("out",_,_,_), WSubField(WRef("inst",_, InstanceKind,_), "r_OUT", _,_)) => true
+ }
+ }
+}