aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/Utils.scala6
-rw-r--r--src/main/scala/firrtl/analyses/SymbolTable.scala91
-rw-r--r--src/main/scala/firrtl/ir/IR.scala16
-rw-r--r--src/main/scala/firrtl/passes/CheckHighForm.scala1
-rw-r--r--src/main/scala/firrtl/passes/CheckTypes.scala3
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala6
-rw-r--r--src/main/scala/firrtl/passes/InferBinaryPoints.scala1
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala1
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala701
-rw-r--r--src/main/scala/firrtl/passes/TrimIntervals.scala1
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala4
-rw-r--r--src/main/scala/firrtl/passes/ZeroWidth.scala1
-rw-r--r--src/main/scala/firrtl/stage/Forms.scala1
-rw-r--r--src/main/scala/firrtl/transforms/InferResets.scala1
-rw-r--r--src/main/scala/firrtl/transforms/TopWiring.scala2
-rw-r--r--src/test/scala/firrtl/analysis/SymbolTableSpec.scala95
-rw-r--r--src/test/scala/firrtl/passes/LowerTypesSpec.scala533
-rw-r--r--src/test/scala/firrtlTests/ExpandWhensSpec.scala1
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala387
-rw-r--r--src/test/scala/firrtlTests/LoweringCompilersSpec.scala21
-rw-r--r--src/test/scala/firrtlTests/MemoryInitSpec.scala4
-rw-r--r--src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala6
22 files changed, 1559 insertions, 324 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 71a4d3ef..bb814051 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -509,10 +509,16 @@ object Utils extends LazyLogging {
case Default => Flip
case Flip => Default
}
+ // Input <-> SourceFlow <-> Flip
+ // Output <-> SinkFlow <-> Default
def to_dir(g: Flow): Direction = g match {
case SourceFlow => Input
case SinkFlow => Output
}
+ def to_dir(o: Orientation): Direction = o match {
+ case Flip => Input
+ case Default => Output
+ }
def to_flow(d: Direction): Flow = d match {
case Input => SourceFlow
case Output => SinkFlow
diff --git a/src/main/scala/firrtl/analyses/SymbolTable.scala b/src/main/scala/firrtl/analyses/SymbolTable.scala
new file mode 100644
index 00000000..53ad1614
--- /dev/null
+++ b/src/main/scala/firrtl/analyses/SymbolTable.scala
@@ -0,0 +1,91 @@
+// See LICENSE for license details.
+
+package firrtl.analyses
+
+import firrtl.ir._
+import firrtl.passes.MemPortUtils
+import firrtl.{InstanceKind, Kind, WDefInstance}
+
+import scala.collection.mutable
+
+/** This trait represents a data structure that stores information
+ * on all the symbols available in a single firrtl module.
+ * The module can either be scanned all at once using the
+ * scanModule helper function from the companion object or
+ * the SymbolTable can be updated while traversing the module by
+ * calling the declare method every time a declaration is encountered.
+ * Different implementations of SymbolTable might want to store different
+ * information (e.g., only the names without the types) or build
+ * different indices depending on what information the transform needs.
+ * */
+trait SymbolTable {
+ // methods that need to be implemented by any Symbol table
+ def declare(name: String, tpe: Type, kind: Kind): Unit
+ def declareInstance(name: String, module: String): Unit
+
+ // convenience methods
+ def declare(d: DefInstance): Unit = declareInstance(d.name, d.module)
+ def declare(d: DefMemory): Unit = declare(d.name, MemPortUtils.memType(d), firrtl.MemKind)
+ def declare(d: DefNode): Unit = declare(d.name, d.value.tpe, firrtl.NodeKind)
+ def declare(d: DefWire): Unit = declare(d.name, d.tpe, firrtl.WireKind)
+ def declare(d: DefRegister): Unit = declare(d.name, d.tpe, firrtl.RegKind)
+ def declare(d: Port): Unit = declare(d.name, d.tpe, firrtl.PortKind)
+}
+
+/** Trusts the type annotation on DefInstance nodes instead of re-deriving the type from
+ * the module ports which would require global (cross-module) information. */
+private[firrtl] abstract class LocalSymbolTable extends SymbolTable {
+ def declareInstance(name: String, module: String): Unit = declare(name, UnknownType, InstanceKind)
+ override def declare(d: WDefInstance): Unit = declare(d.name, d.tpe, InstanceKind)
+}
+
+/** Uses a function to derive instance types from module names */
+private[firrtl] abstract class ModuleTypesSymbolTable(moduleTypes: String => Type) extends SymbolTable {
+ def declareInstance(name: String, module: String): Unit = declare(name, moduleTypes(module), InstanceKind)
+}
+
+/** Uses a single buffer. No O(1) access, but deterministic Symbol order. */
+private[firrtl] trait WithSeq extends SymbolTable {
+ private val symbols = mutable.ArrayBuffer[Symbol]()
+ override def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(Sym(name, tpe, kind))
+ def getSymbols: Iterable[Symbol] = symbols
+}
+
+/** Uses a mutable map to provide O(1) access to symbols by name. */
+private[firrtl] trait WithMap extends SymbolTable {
+ private val symbols = mutable.HashMap[String, Symbol]()
+ override def declare(name: String, tpe: Type, kind: Kind): Unit = {
+ assert(!symbols.contains(name), s"Symbol $name already declared: ${symbols(name)}")
+ symbols(name) = Sym(name, tpe, kind)
+ }
+ def apply(name: String): Symbol = symbols(name)
+ def size: Int = symbols.size
+}
+
+private case class Sym(name: String, tpe: Type, kind: Kind) extends Symbol
+private[firrtl] trait Symbol { def name: String; def tpe: Type; def kind: Kind }
+
+/** only remembers the names of symbols */
+private[firrtl] class NamespaceTable extends LocalSymbolTable {
+ private var names = List[String]()
+ override def declare(name: String, tpe: Type, kind: Kind): Unit = names = name :: names
+ def getNames: Seq[String] = names
+}
+
+/** Provides convenience methods to populate SymbolTables. */
+object SymbolTable {
+ def scanModule[T <: SymbolTable](m: DefModule, t: T): T = {
+ implicit val table: T = t
+ m.foreachPort(table.declare)
+ m.foreachStmt(scanStatement)
+ table
+ }
+ private def scanStatement(s: Statement)(implicit table: SymbolTable): Unit = s match {
+ case d: DefInstance => table.declare(d)
+ case d: DefMemory => table.declare(d)
+ case d: DefNode => table.declare(d)
+ case d: DefWire => table.declare(d)
+ case d: DefRegister => table.declare(d)
+ case other => other.foreachStmt(scanStatement)
+ }
+}
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index cd8cd975..5263d9c0 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -206,6 +206,14 @@ abstract class Expression extends FirrtlNode {
def foreachWidth(f: Width => Unit): Unit
}
+/** Represents reference-like expression nodes: SubField, SubIndex, SubAccess and Reference
+ * The following fields can be cast to RefLikeExpression in every well formed firrtl AST:
+ * - SubField.expr, SubIndex.expr, SubAccess.expr
+ * - IsInvalid.expr, Connect.loc, PartialConnect.loc
+ * - Attach.exprs
+ */
+sealed trait RefLikeExpression extends Expression { def flow: Flow }
+
object Reference {
/** Creates a Reference from a Wire */
def apply(wire: DefWire): Reference = Reference(wire.name, wire.tpe, WireKind, UnknownFlow)
@@ -222,7 +230,7 @@ object Reference {
}
case class Reference(name: String, tpe: Type = UnknownType, kind: Kind = UnknownKind, flow: Flow = UnknownFlow)
- extends Expression with HasName with UseSerializer {
+ extends Expression with HasName with UseSerializer with RefLikeExpression {
def mapExpr(f: Expression => Expression): Expression = this
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
@@ -232,7 +240,7 @@ case class Reference(name: String, tpe: Type = UnknownType, kind: Kind = Unknown
}
case class SubField(expr: Expression, name: String, tpe: Type = UnknownType, flow: Flow = UnknownFlow)
- extends Expression with HasName with UseSerializer {
+ extends Expression with HasName with UseSerializer with RefLikeExpression {
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
@@ -242,7 +250,7 @@ case class SubField(expr: Expression, name: String, tpe: Type = UnknownType, flo
}
case class SubIndex(expr: Expression, value: Int, tpe: Type, flow: Flow = UnknownFlow)
- extends Expression with UseSerializer {
+ extends Expression with UseSerializer with RefLikeExpression {
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
@@ -252,7 +260,7 @@ case class SubIndex(expr: Expression, value: Int, tpe: Type, flow: Flow = Unknow
}
case class SubAccess(expr: Expression, index: Expression, tpe: Type, flow: Flow = UnknownFlow)
- extends Expression with UseSerializer {
+ extends Expression with UseSerializer with RefLikeExpression {
def mapExpr(f: Expression => Expression): Expression = this.copy(expr = f(expr), index = f(index))
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
def mapWidth(f: Width => Width): Expression = this
diff --git a/src/main/scala/firrtl/passes/CheckHighForm.scala b/src/main/scala/firrtl/passes/CheckHighForm.scala
index 3ba2a3db..2f706d35 100644
--- a/src/main/scala/firrtl/passes/CheckHighForm.scala
+++ b/src/main/scala/firrtl/passes/CheckHighForm.scala
@@ -344,7 +344,6 @@ object CheckHighForm extends Pass with CheckHighFormLike {
override def optionalPrerequisiteOf =
Seq( Dependency(passes.ResolveKinds),
Dependency(passes.InferTypes),
- Dependency(passes.Uniquify),
Dependency(passes.ResolveFlows),
Dependency[passes.InferWidths],
Dependency[transforms.InferResets] )
diff --git a/src/main/scala/firrtl/passes/CheckTypes.scala b/src/main/scala/firrtl/passes/CheckTypes.scala
index 601ee524..c94928a1 100644
--- a/src/main/scala/firrtl/passes/CheckTypes.scala
+++ b/src/main/scala/firrtl/passes/CheckTypes.scala
@@ -16,8 +16,7 @@ object CheckTypes extends Pass {
override def prerequisites = Dependency(InferTypes) +: firrtl.stage.Forms.WorkingIR
override def optionalPrerequisiteOf =
- Seq( Dependency(passes.Uniquify),
- Dependency(passes.ResolveFlows),
+ Seq( Dependency(passes.ResolveFlows),
Dependency(passes.CheckFlows),
Dependency[passes.InferWidths],
Dependency(passes.CheckWidths) )
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 4384aca7..ab7f02db 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -31,8 +31,7 @@ object ExpandWhens extends Pass {
Seq( Dependency(PullMuxes),
Dependency(ReplaceAccesses),
Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency(Uniquify) ) ++ firrtl.stage.Forms.Resolved
+ Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Resolved
override def invalidates(a: Transform): Boolean = a match {
case CheckInitialization | ResolveKinds | InferTypes => true
@@ -294,8 +293,7 @@ class ExpandWhensAndCheck extends Transform with DependencyAPIMigration {
Seq( Dependency(PullMuxes),
Dependency(ReplaceAccesses),
Dependency(ExpandConnects),
- Dependency(RemoveAccesses),
- Dependency(Uniquify) ) ++ firrtl.stage.Forms.Deduped
+ Dependency(RemoveAccesses) ) ++ firrtl.stage.Forms.Deduped
override def invalidates(a: Transform): Boolean = a match {
case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true
diff --git a/src/main/scala/firrtl/passes/InferBinaryPoints.scala b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
index 4b62d5f7..a16205a7 100644
--- a/src/main/scala/firrtl/passes/InferBinaryPoints.scala
+++ b/src/main/scala/firrtl/passes/InferBinaryPoints.scala
@@ -15,7 +15,6 @@ class InferBinaryPoints extends Pass {
override def prerequisites =
Seq( Dependency(ResolveKinds),
Dependency(InferTypes),
- Dependency(Uniquify),
Dependency(ResolveFlows) )
override def optionalPrerequisiteOf = Seq.empty
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index d481b713..3720523b 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -67,7 +67,6 @@ class InferWidths extends Transform
override def prerequisites =
Seq( Dependency(passes.ResolveKinds),
Dependency(passes.InferTypes),
- Dependency(passes.Uniquify),
Dependency(passes.ResolveFlows),
Dependency[passes.InferBinaryPoints],
Dependency[passes.TrimIntervals] ) ++ firrtl.stage.Forms.WorkingIR
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index 29792d17..ace4f3e8 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -2,35 +2,31 @@
package firrtl.passes
-import scala.collection.mutable
-import firrtl._
+import firrtl.analyses.{InstanceKeyGraph, SymbolTable}
+import firrtl.annotations.{CircuitTarget, MemoryInitAnnotation, MemoryRandomInitAnnotation, ModuleTarget, ReferenceTarget}
+import firrtl.{CircuitForm, CircuitState, DependencyAPIMigration, InstanceKind, Kind, MemKind, PortKind, RenameMap, Transform, UnknownForm, Utils}
import firrtl.ir._
-import firrtl.Utils._
-import MemPortUtils.memType
-import firrtl.Mappers._
-import firrtl.annotations.MemoryInitAnnotation
-
-/** Removes all aggregate types from a [[firrtl.ir.Circuit]]
- *
- * @note Assumes [[firrtl.ir.SubAccess]]es have been removed
- * @note Assumes [[firrtl.ir.Connect]]s and [[firrtl.ir.IsInvalid]]s only operate on [[firrtl.ir.Expression]]s of ground type
- * @example
- * {{{
- * wire foo : { a : UInt<32>, b : UInt<16> }
- * }}} lowers to
- * {{{
- * wire foo_a : UInt<32>
- * wire foo_b : UInt<16>
- * }}}
- */
-object LowerTypes extends Transform with DependencyAPIMigration {
-
- override def prerequisites = firrtl.stage.Forms.MidForm
+import firrtl.options.Dependency
+import firrtl.stage.TransformManager.TransformDependency
- override def optionalPrerequisiteOf = Seq.empty
+import scala.annotation.tailrec
+import scala.collection.mutable
+/** Flattens Bundles and Vecs.
+ * - Some implicit bundle types remain, but with a limited depth:
+ * - the type of a memory is still a bundle with depth 2 (mem -> port -> field), see [[MemPortUtils.memType]]
+ * - the type of a module instance is still a bundle with depth 1 (instance -> port)
+ */
+object LowerTypes extends Transform with DependencyAPIMigration {
+ override def prerequisites: Seq[TransformDependency] = Seq(
+ Dependency(RemoveAccesses), // we require all SubAccess nodes to have been removed
+ Dependency(CheckTypes), // we require all types to be correct
+ Dependency(InferTypes), // we require instance types to be resolved (i.e., DefInstance.tpe != UnknownType)
+ Dependency(ExpandConnects) // we require all PartialConnect nodes to have been expanded
+ )
+ override def optionalPrerequisiteOf: Seq[TransformDependency] = Seq.empty
override def invalidates(a: Transform): Boolean = a match {
- case ResolveKinds | InferTypes | ResolveFlows | _: InferWidths => true
+ case ResolveFlows => true // we generate UnknownFlow for now (could be fixed)
case _ => false
}
@@ -39,266 +35,451 @@ object LowerTypes extends Transform with DependencyAPIMigration {
/** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name
* @param e [[firrtl.ir.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]]
* @return Lowered name of e
+ * @note Please make sure that there will be no name collisions when you use this outside of the context of LowerTypes!
*/
def loweredName(e: Expression): String = e match {
- case e: WRef => e.name
- case e: WSubField => s"${loweredName(e.expr)}$delim${e.name}"
- case e: WSubIndex => s"${loweredName(e.expr)}$delim${e.value}"
+ case e: Reference => e.name
+ case e: SubField => s"${loweredName(e.expr)}$delim${e.name}"
+ case e: SubIndex => s"${loweredName(e.expr)}$delim${e.value}"
}
- def loweredName(s: Seq[String]): String = s mkString delim
- def renameExps(renames: RenameMap, n: String, t: Type, root: String): Seq[String] =
- renameExps(renames, WRef(n, t, ExpKind, UnknownFlow), root)
- def renameExps(renames: RenameMap, n: String, t: Type): Seq[String] =
- renameExps(renames, WRef(n, t, ExpKind, UnknownFlow), "")
- def renameExps(renames: RenameMap, e: Expression, root: String): Seq[String] = e.tpe match {
- case (_: GroundType) =>
- val name = root + loweredName(e)
- renames.rename(root + e.serialize, name)
- Seq(name)
- case (t: BundleType) =>
- val subNames = t.fields.flatMap { f =>
- renameExps(renames, WSubField(e, f.name, f.tpe, times(flow(e), f.flip)), root)
- }
- renames.rename(root + e.serialize, subNames)
- subNames
- case (t: VectorType) =>
- val subNames = (0 until t.size).flatMap { i => renameExps(renames, WSubIndex(e, i, t.tpe,flow(e)), root) }
- renames.rename(root + e.serialize, subNames)
- subNames
+ def loweredName(s: Seq[String]): String = s.mkString(delim)
+
+ override def execute(state: CircuitState): CircuitState = {
+ // When memories are lowered to ground type, we have to fix the init annotation or error on it.
+ val (memInitAnnos, otherAnnos) = state.annotations.partition {
+ case _: MemoryRandomInitAnnotation => false
+ case _: MemoryInitAnnotation => true
+ case _ => false
+ }
+ val memInitByModule = memInitAnnos.map(_.asInstanceOf[MemoryInitAnnotation]).groupBy(_.target.encapsulatingModule)
+
+ val c = CircuitTarget(state.circuit.main)
+ val resultAndRenames = state.circuit.modules.map(m => onModule(c, m, memInitByModule.getOrElse(m.name, Seq())))
+ val result = state.circuit.copy(modules = resultAndRenames.map(_._1))
+
+ // memory init annotations could have been modified
+ val newAnnos = otherAnnos ++ resultAndRenames.flatMap(_._3)
+
+ // chain module renames in topological order
+ val moduleRenames = resultAndRenames.map{ case(m,r, _) => m.name -> r }.toMap
+ val moduleOrderBottomUp = InstanceKeyGraph(result).moduleOrder.reverseIterator
+ val renames = moduleOrderBottomUp.map(m => moduleRenames(m.name)).reduce((a,b) => a.andThen(b))
+
+ state.copy(circuit = result, renames = Some(renames), annotations = newAnnos)
}
- private def renameMemExps(renames: RenameMap, e: Expression, portAndField: Expression): Seq[String] = e.tpe match {
- case (_: GroundType) =>
- val (mem, tail) = splitRef(e)
- val loRef = mergeRef(WRef(loweredName(e)), portAndField)
- val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
- renames.rename(hiRef.serialize, loRef.serialize)
- Seq(loRef.serialize)
- case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) =>
- val subNames = renameMemExps(renames, WSubField(e, f.name, f.tpe, times(flow(e), f.flip)), portAndField)
- val (mem, tail) = splitRef(e)
- val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
- renames.rename(hiRef.serialize, subNames)
- names ++ subNames
+ private def onModule(c: CircuitTarget, m: DefModule, memoryInit: Seq[MemoryInitAnnotation]): (DefModule, RenameMap, Seq[MemoryInitAnnotation]) = {
+ val renameMap = RenameMap()
+ val ref = c.module(m.name)
+
+ // first we lower the ports in order to ensure that their names are independent of the module body
+ val (mLoweredPorts, portRefs) = lowerPorts(ref, m, renameMap)
+
+ // scan modules to find all references
+ val scan = SymbolTable.scanModule(mLoweredPorts, new LoweringSymbolTable)
+ // replace all declarations and references with the destructed types
+ implicit val symbols: LoweringTable = new LoweringTable(scan, renameMap, ref, portRefs)
+ implicit val memInit: Seq[MemoryInitAnnotation] = memoryInit
+ val newMod = mLoweredPorts.mapStmt(onStatement)
+
+ (newMod, renameMap, memInit)
+ }
+
+ // We lower ports in a separate pass in order to ensure that statements inside the module do not influence port names.
+ private def lowerPorts(ref: ModuleTarget, m: DefModule, renameMap: RenameMap):
+ (DefModule, Seq[(String, Seq[Reference])]) = {
+ val namespace = mutable.HashSet[String]() ++ m.ports.map(_.name)
+ val loweredPortsAndRefs = m.ports.flatMap { p =>
+ val fieldsAndRefs = DestructTypes.destruct(ref, Field(p.name, Utils.to_flip(p.direction), p.tpe), namespace, renameMap, Set())
+ fieldsAndRefs.map { case (f, ref) =>
+ (Port(p.info, f.name, Utils.to_dir(f.flip), f.tpe), ref -> Seq(Reference(f.name, f.tpe, PortKind)))
+ }
}
- case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) =>
- val subNames = renameMemExps(renames, WSubIndex(e, i, t.tpe,flow(e)), portAndField)
- val (mem, tail) = splitRef(e)
- val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
- renames.rename(hiRef.serialize, subNames)
- names ++ subNames
+ val newM = m match {
+ case e : ExtModule => e.copy(ports = loweredPortsAndRefs.map(_._1))
+ case mod: Module => mod.copy(ports = loweredPortsAndRefs.map(_._1))
}
+ (newM, loweredPortsAndRefs.map(_._2))
}
- private case class LowerTypesException(msg: String) extends FirrtlInternalException(msg)
- private def error(msg: String)(info: Info, mname: String) =
- throw LowerTypesException(s"$info: [module $mname] $msg")
-
- // TODO Improve? Probably not the best way to do this
- private def splitMemRef(e1: Expression): (WRef, WRef, WRef, Option[Expression]) = {
- val (mem, tail1) = splitRef(e1)
- val (port, tail2) = splitRef(tail1)
- tail2 match {
- case e2: WRef =>
- (mem, port, e2, None)
- case _ =>
- val (field, tail3) = splitRef(tail2)
- (mem, port, field, Some(tail3))
+
+ private def onStatement(s: Statement)(implicit symbols: LoweringTable, memInit: Seq[MemoryInitAnnotation]): Statement = s match {
+ // declarations
+ case d : DefWire =>
+ Block(symbols.lower(d.name, d.tpe, firrtl.WireKind).map { case (name, tpe, _) => d.copy(name=name, tpe=tpe) })
+ case d @ DefRegister(info, _, _, clock, reset, _) =>
+ // clock and reset are always of ground type
+ val loweredClock = onExpression(clock)
+ val loweredReset = onExpression(reset)
+ // It is important to first lower the declaration, because the reset can refer to the register itself!
+ val loweredRegs = symbols.lower(d.name, d.tpe, firrtl.RegKind)
+ val inits = Utils.create_exps(d.init).map(onExpression)
+ Block(
+ loweredRegs.zip(inits).map { case ((name, tpe, _), init) =>
+ DefRegister(info, name, tpe, loweredClock, loweredReset, init)
+ })
+ case d : DefNode =>
+ val values = Utils.create_exps(d.value).map(onExpression)
+ Block(
+ symbols.lower(d.name, d.value.tpe, firrtl.NodeKind).zip(values).map{ case((name, tpe, _), value) =>
+ assert(tpe == value.tpe)
+ DefNode(d.info, name, value)
+ })
+ case d : DefMemory =>
+ // TODO: as an optimization, we could just skip ground type memories here.
+ // This would require that we don't error in getReferences() but instead return the old reference.
+ val mems = symbols.lower(d)
+ if(mems.length > 1 && memInit.exists(_.target.ref == d.name)) {
+ val mod = memInit.find(_.target.ref == d.name).get.target.encapsulatingModule
+ val msg = s"[module $mod] Cannot initialize memory ${d.name} of non ground type ${d.dataType.serialize}"
+ throw new RuntimeException(msg)
+ }
+ Block(mems)
+ case d : DefInstance => symbols.lower(d)
+ // connections
+ case Connect(info, loc, expr) =>
+ if(!expr.tpe.isInstanceOf[GroundType]) {
+ throw new RuntimeException(s"LowerTypes expects Connects to have been expanded! ${expr.tpe.serialize}")
+ }
+ val rhs = onExpression(expr)
+ // We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated.
+ val lhs = symbols.getReferences(loc.asInstanceOf[RefLikeExpression])
+ Block(lhs.map(loc => Connect(info, loc, rhs)))
+ case p : PartialConnect =>
+ throw new RuntimeException(s"LowerTypes expects PartialConnects to be resolved! $p")
+ case IsInvalid(info, expr) =>
+ if(!expr.tpe.isInstanceOf[GroundType]) {
+ throw new RuntimeException(s"LowerTypes expects IsInvalids to have been expanded! ${expr.tpe.serialize}")
+ }
+ // We can get multiple refs on the lhs because of ground-type memory ports like "clk" which can get duplicated.
+ val lhs = symbols.getReferences(expr.asInstanceOf[RefLikeExpression])
+ Block(lhs.map(loc => IsInvalid(info, loc)))
+ // others
+ case other => other.mapExpr(onExpression).mapStmt(onStatement)
+ }
+
+ /** Replaces all Reference, SubIndex and SubField nodes with the updated references */
+ private def onExpression(e: Expression)(implicit symbols: LoweringTable): Expression = e match {
+ case r: RefLikeExpression =>
+ // When reading (and not assigning to) an expression, we can always just pick the first one.
+ // Only very few ground-type references are duplicated and they are all related to lowered memories.
+ // e.g., the `clk` field of a memory port gets duplicated when the memory is split into ground-types.
+ // We ensure that all of these references carry the same value when they are expanded in onStatement.
+ symbols.getReferences(r).head
+ case other => other.mapExpr(onExpression)
+ }
+}
+
+// Holds the first level of the module-level namespace.
+// (i.e. everything that can be addressed directly by a Reference node)
+private class LoweringSymbolTable extends SymbolTable {
+ def declare(name: String, tpe: Type, kind: Kind): Unit = symbols.append(name)
+ def declareInstance(name: String, module: String): Unit = symbols.append(name)
+ private val symbols = mutable.ArrayBuffer[String]()
+ def getSymbolNames: Iterable[String] = symbols
+}
+
+// Lowers types and keeps track of references to lowered types.
+private class LoweringTable(table: LoweringSymbolTable, renameMap: RenameMap, m: ModuleTarget,
+ portNameToExprs: Seq[(String, Seq[Reference])]) {
+ private val portNames: Set[String] = portNameToExprs.map(_._2.head.name).toSet
+ private val namespace = mutable.HashSet[String]() ++ table.getSymbolNames
+ // Serialized old access string to new ground type reference.
+ private val nameToExprs = mutable.HashMap[String, Seq[RefLikeExpression]]() ++ portNameToExprs
+
+ def lower(mem: DefMemory): Seq[DefMemory] = {
+ val (mems, refs) = DestructTypes.destructMemory(m, mem, namespace, renameMap, portNames)
+ nameToExprs ++= refs.groupBy(_._1).mapValues(_.map(_._2))
+ mems
+ }
+ def lower(inst: DefInstance): DefInstance = {
+ val (newInst, refs) = DestructTypes.destructInstance(m, inst, namespace, renameMap, portNames)
+ nameToExprs ++= refs.map { case (name, r) => name -> List(r) }
+ newInst
+ }
+ /** used to lower nodes, registers and wires */
+ def lower(name: String, tpe: Type, kind: Kind, flip: Orientation = Default): Seq[(String, Type, Orientation)] = {
+ val fieldsAndRefs = DestructTypes.destruct(m, Field(name, flip, tpe), namespace, renameMap, portNames)
+ nameToExprs ++= fieldsAndRefs.map{ case (f, ref) => ref -> List(Reference(f.name, f.tpe, kind)) }
+ fieldsAndRefs.map { case (f, _) => (f.name, f.tpe, f.flip) }
+ }
+ def lower(p: Port): Seq[Port] = {
+ val fields = lower(p.name, p.tpe, PortKind, Utils.to_flip(p.direction))
+ fields.map { case (name, tpe, flip) => Port(p.info, name, Utils.to_dir(flip), tpe) }
+ }
+
+ def getReferences(expr: RefLikeExpression): Seq[RefLikeExpression] = nameToExprs(serialize(expr))
+
+ // We could just use FirrtlNode.serialize here, but we want to make sure there are not SubAccess nodes left.
+ private def serialize(expr: RefLikeExpression): String = expr match {
+ case Reference(name, _, _, _) => name
+ case SubField(expr, name, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "." + name
+ case SubIndex(expr, index, _, _) => serialize(expr.asInstanceOf[RefLikeExpression]) + "[" + index.toString + "]"
+ case a : SubAccess =>
+ throw new RuntimeException(s"LowerTypes expects all SubAccesses to have been expanded! ${a.serialize}")
+ }
+}
+
+/** Calculate new type layouts and names. */
+private object DestructTypes {
+ type Namespace = mutable.HashSet[String]
+
+ /** Does the following with a reference:
+ * - rename reference and any bundle fields to avoid name collisions after destruction
+ * - updates rename map with new targets
+ * - generates all ground type fields
+ * - generates a list of all old reference name that now refer to the particular ground type field
+ * - updates namespace with all possibly conflicting names
+ */
+ def destruct(m: ModuleTarget, ref: Field, namespace: Namespace, renameMap: RenameMap, reserved: Set[String]):
+ Seq[(Field, String)] = {
+ // field renames (uniquify) are computed bottom up
+ val (rename, _) = uniquify(ref, namespace, reserved)
+
+ // early exit for ground types that do not need renaming
+ if(ref.tpe.isInstanceOf[GroundType] && rename.isEmpty) {
+ return List((ref, ref.name))
+ }
+
+ // the reference renames are computed top down since they do need the full path
+ val res = destruct(m, ref, rename)
+ recordRenames(res, renameMap, ModuleParentRef(m))
+
+ res.map { case (c, r) => c -> extractGroundTypeRefString(r) }
+ }
+
+ /** instances are special because they remain a 1-deep bundle
+ * @note this relies on the ports of the module having been properly renamed.
+ * @return The potentially renamed instance with newly flattened type.
+ * Note that the list of fields is only of the child fields, and needs a SubField node
+ * instead of a flat Reference when turning them into access expressions.
+ */
+ def destructInstance(m: ModuleTarget, instance: DefInstance, namespace: Namespace, renameMap: RenameMap,
+ reserved: Set[String]): (DefInstance, Seq[(String, SubField)]) = {
+ val (rename, _) = uniquify(Field(instance.name, Default, instance.tpe), namespace, reserved)
+ val newName = rename.map(_.name).getOrElse(instance.name)
+
+ // only destruct the sub-fields (aka ports)
+ val oldParent = RefParentRef(m.ref(instance.name))
+ val children = instance.tpe.asInstanceOf[BundleType].fields.flatMap { f =>
+ val childRename = rename.flatMap(_.children.get(f.name))
+ destruct("", oldParent, f, isVecField = false, rename = childRename)
}
+
+ // rename all references to the instance if necessary
+ if(newName != instance.name) {
+ renameMap.record(m.instOf(instance.name, instance.module), m.instOf(newName, instance.module))
+ }
+ // The ports do not need to be explicitly renamed here. They are renamed when the module ports are lowered.
+
+ val newInstance = instance.copy(name = newName, tpe = BundleType(children.map(_._1)))
+ val instanceRef = Reference(newName, newInstance.tpe, InstanceKind)
+ val refs = children.map{ case(c,r) => extractGroundTypeRefString(r) -> SubField(instanceRef, c.name, c.tpe) }
+
+ (newInstance, refs)
}
- // Lowers an expression of MemKind
- // Since mems with Bundle type must be split into multiple ground type
- // mem, references to fields addr, en, clk, and rmode must be replicated
- // for each resulting memory
- // References to data, mask, rdata, wdata, and wmask have already been split in expand connects
- // and just need to be converted to refer to the correct new memory
- type MemDataTypeMap = collection.mutable.HashMap[String, Type]
- def lowerTypesMemExp(memDataTypeMap: MemDataTypeMap,
- info: Info, mname: String)(e: Expression): Seq[Expression] = {
- val (mem, port, field, tail) = splitMemRef(e)
- field.name match {
- // Fields that need to be replicated for each resulting mem
- case "addr" | "en" | "clk" | "wmode" =>
- require(tail.isEmpty) // there can't be a tail for these
- memDataTypeMap(mem.name) match {
- case _: GroundType => Seq(e)
- case memType => create_exps(mem.name, memType) map { e =>
- val loMemName = loweredName(e)
- val loMem = WRef(loMemName, UnknownType, kind(mem), UnknownFlow)
- mergeRef(loMem, mergeRef(port, field))
+ private val BoolType = UIntType(IntWidth(1))
+
+ /** memories are special because they end up a 2-deep bundle.
+ * @note That a single old ground type reference could be replaced with multiple new ground type reference.
+ * e.g. ("mem_a.r.clk", "mem.r.clk") and ("mem_b.r.clk", "mem.r.clk")
+ * Thus it is appropriate to groupBy old reference string instead of just inserting into a hash table.
+ */
+ def destructMemory(m: ModuleTarget, mem: DefMemory, namespace: Namespace, renameMap: RenameMap,
+ reserved: Set[String]): (Seq[DefMemory], Seq[(String, SubField)]) = {
+ // Uniquify the lowered memory names: When memories get split up into ground types, the access order is changes.
+ // E.g. `mem.r.data.x` becomes `mem_x.r.data`.
+ // This is why we need to create the new bundle structure before we can resolve any name clashes.
+ val bundle = memBundle(mem)
+ val (dataTypeRenames, _) = uniquify(bundle, namespace, reserved)
+ val res = destruct(m, Field(mem.name, Default, mem.dataType), dataTypeRenames)
+
+ // Renames are now of the form `mem.a.b` --> `mem_a_b`.
+ // We want to turn them into `mem.r.data.a.b` --> `mem_a_b.r.data`, etc. (for all readers, writers and for all ports)
+ val oldMemRef = m.ref(mem.name)
+
+ // the "old dummy field" is used as a template for the new memory port types
+ val oldDummyField = Field("dummy", Default, MemPortUtils.memType(mem.copy(dataType = BoolType)))
+
+ val newMemAndSubFields = res.map { case (field, refs) =>
+ val newMem = mem.copy(name = field.name, dataType = field.tpe)
+ val newMemRef = m.ref(field.name)
+ val memWasRenamed = field.name != mem.name // false iff the dataType was a GroundType
+ if(memWasRenamed) { renameMap.record(oldMemRef, newMemRef) }
+
+ val newMemReference = Reference(field.name, MemPortUtils.memType(newMem), MemKind)
+ val refSuffixes = refs.map(_.component).filterNot(_.isEmpty)
+
+ val subFields = oldDummyField.tpe.asInstanceOf[BundleType].fields.flatMap { port =>
+ val oldPortRef = oldMemRef.field(port.name)
+ val newPortRef = newMemRef.field(port.name)
+
+ val newPortType = newMemReference.tpe.asInstanceOf[BundleType].fields.find(_.name == port.name).get.tpe
+ val newPortAccess = SubField(newMemReference, port.name, newPortType)
+
+ port.tpe.asInstanceOf[BundleType].fields.map { portField =>
+ val isDataField = portField.name == "data" || portField.name == "wdata" || portField.name == "rdata"
+ val isMaskField = portField.name == "mask" || portField.name == "wmask"
+ val isDataOrMaskField = isDataField || isMaskField
+ val oldFieldRefs = if(memWasRenamed && isDataOrMaskField) {
+ // there might have been multiple different fields which now alias to the same lowered field.
+ val oldPortFieldBaseRef = oldPortRef.field(portField.name)
+ refSuffixes.map(s => oldPortFieldBaseRef.copy(component = oldPortFieldBaseRef.component ++ s))
+ } else {
+ List(oldPortRef.field(portField.name))
}
+
+ val newPortType = if(isDataField) { newMem.dataType } else { portField.tpe }
+ val newPortFieldAccess = SubField(newPortAccess, portField.name, newPortType)
+
+ // record renames only for the data field which is the only port field of non-ground type
+ val newPortFieldRef = newPortRef.field(portField.name)
+ if(memWasRenamed && isDataOrMaskField) {
+ oldFieldRefs.foreach { o => renameMap.record(o, newPortFieldRef) }
+ }
+
+ val oldFieldStringRef = extractGroundTypeRefString(oldFieldRefs)
+ (oldFieldStringRef, newPortFieldAccess)
}
- // Fields that need not be replicated for each
- // eg. mem.reader.data[0].a
- // (Connect/IsInvalid must already have been split to ground types)
- case "data" | "mask" | "rdata" | "wdata" | "wmask" =>
- val loMem = tail match {
- case Some(ex) =>
- val loMemExp = mergeRef(mem, ex)
- val loMemName = loweredName(loMemExp)
- WRef(loMemName, UnknownType, kind(mem), UnknownFlow)
- case None => mem
- }
- Seq(mergeRef(loMem, mergeRef(port, field)))
- case name => error(s"Error! Unhandled memory field $name")(info, mname)
+ }
+ (newMem, subFields)
}
+
+ (newMemAndSubFields.map(_._1), newMemAndSubFields.flatMap(_._2))
}
- def lowerTypesExp(memDataTypeMap: MemDataTypeMap,
- info: Info, mname: String)(e: Expression): Expression = e match {
- case e: WRef => e
- case (_: WSubField | _: WSubIndex) => kind(e) match {
- case InstanceKind =>
- val (root, tail) = splitRef(e)
- val name = loweredName(tail)
- WSubField(root, name, e.tpe, flow(e))
- case MemKind =>
- val exps = lowerTypesMemExp(memDataTypeMap, info, mname)(e)
- exps.size match {
- case 1 => exps.head
- case _ => error("Error! lowerTypesExp called on MemKind " +
- "SubField that needs to be expanded!")(info, mname)
- }
- case _ => WRef(loweredName(e), e.tpe, kind(e), flow(e))
+ private def memBundle(mem: DefMemory): Field = mem.dataType match {
+ case _: GroundType => Field(mem.name, Default, mem.dataType)
+ case _: BundleType | _: VectorType =>
+ val subMems = getFields(mem.dataType).map(f => mem.copy(name = f.name, dataType = f.tpe))
+ val fields = subMems.map(memBundle)
+ Field(mem.name, Default, BundleType(fields))
+ }
+
+ private def recordRenames(fieldToRefs: Seq[(Field, Seq[ReferenceTarget])], renameMap: RenameMap, parent: ParentRef):
+ Unit = {
+ // TODO: if we group by ReferenceTarget, we could reduce the number of calls to `record`. Is it worth it?
+ fieldToRefs.foreach { case(field, refs) =>
+ val fieldRef = parent.ref(field.name)
+ refs.foreach{ r => renameMap.record(r, fieldRef) }
}
- case e: Mux => e map lowerTypesExp(memDataTypeMap, info, mname)
- case e: ValidIf => e map lowerTypesExp(memDataTypeMap, info, mname)
- case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname)
- case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
- def lowerTypesStmt(memDataTypeMap: MemDataTypeMap,
- minfo: Info, mname: String, renames: RenameMap, initializedMems: Set[(String, String)])(s: Statement): Statement = {
- val info = get_info(s) match {case NoInfo => minfo case x => x}
- s map lowerTypesStmt(memDataTypeMap, info, mname, renames, initializedMems) match {
- case s: DefWire => s.tpe match {
- case _: GroundType => s
- case _ =>
- val exps = create_exps(s.name, s.tpe)
- val names = exps map loweredName
- renameExps(renames, s.name, s.tpe)
- Block((exps zip names) map { case (e, n) =>
- DefWire(s.info, n, e.tpe)
- })
- }
- case sx: DefRegister => sx.tpe match {
- case _: GroundType => sx map lowerTypesExp(memDataTypeMap, info, mname)
- case _ =>
- val es = create_exps(sx.name, sx.tpe)
- val names = es map loweredName
- renameExps(renames, sx.name, sx.tpe)
- val inits = create_exps(sx.init) map lowerTypesExp(memDataTypeMap, info, mname)
- val clock = lowerTypesExp(memDataTypeMap, info, mname)(sx.clock)
- val reset = lowerTypesExp(memDataTypeMap, info, mname)(sx.reset)
- Block((es zip names) zip inits map { case ((e, n), i) =>
- DefRegister(sx.info, n, e.tpe, clock, reset, i)
- })
- }
- // Could instead just save the type of each Module as it gets processed
- case sx: WDefInstance => sx.tpe match {
- case t: BundleType =>
- val fieldsx = t.fields flatMap { f =>
- renameExps(renames, f.name, f.tpe, s"${sx.name}.")
- create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, SourceFlow))) map { e =>
- // Flip because inst flows are reversed from Module type
- Field(loweredName(e), swap(to_flip(flow(e))), e.tpe)
- }
- }
- WDefInstance(sx.info, sx.name, sx.module, BundleType(fieldsx))
- case _ => error("WDefInstance type should be Bundle!")(info, mname)
- }
- case sx: DefMemory =>
- memDataTypeMap(sx.name) = sx.dataType
- sx.dataType match {
- case _: GroundType => sx
- case _ =>
- // right now only ground type memories can be initialized
- if(initializedMems.contains((mname, sx.name))) {
- error(s"Cannot initialize memory of non ground type ${sx.dataType.serialize}")(info, mname)
- }
- // Rename ports
- val seen: mutable.Set[String] = mutable.Set[String]()
- create_exps(sx.name, memType(sx)) foreach { e =>
- val (mem, port, field, tail) = splitMemRef(e)
- if (!seen.contains(field.name)) {
- seen += field.name
- val d = WRef(mem.name, sx.dataType)
- tail match {
- case None =>
- val names = create_exps(mem.name, sx.dataType).map { x =>
- s"${loweredName(x)}.${port.serialize}.${field.serialize}"
- }
- renames.rename(e.serialize, names)
- case Some(_) =>
- renameMemExps(renames, d, mergeRef(port, field))
- }
- }
- }
- Block(create_exps(sx.name, sx.dataType) map {e =>
- val newName = loweredName(e)
- // Rename mems
- renames.rename(sx.name, newName)
- sx copy (name = newName, dataType = e.tpe)
- })
+
+ private def extractGroundTypeRefString(refs: Seq[ReferenceTarget]): String = {
+ if (refs.isEmpty) { "" } else {
+ // Since we depend on ExpandConnects any reference we encounter will be of ground type
+ // and thus the one with the longest access path.
+ refs.reduceLeft((x, y) => if (x.component.length > y.component.length) x else y)
+ // convert references to strings relative to the module
+ .serialize.dropWhile(_ != '>').tail
+ }
+ }
+
+ private def destruct(m: ModuleTarget, field: Field, rename: Option[RenameNode]): Seq[(Field, Seq[ReferenceTarget])] =
+ destruct(prefix = "", oldParent = ModuleParentRef(m), oldField = field, isVecField = false, rename = rename)
+
+ /** Lowers a field into its ground type fields.
+ * @param prefix carries the prefix of the new ground type name
+ * @param isVecField is used to generate an appropriate old (field/index) reference
+ * @param rename The information from the `uniquify` function is consumed to appropriately rename generated fields.
+ * @return a sequence of ground type fields with new names and, for each field,
+ * a sequence of old references that should to be renamed to point to the particular field
+ */
+ private def destruct(prefix: String, oldParent: ParentRef, oldField: Field,
+ isVecField: Boolean, rename: Option[RenameNode]): Seq[(Field, Seq[ReferenceTarget])] = {
+ val newName = rename.map(_.name).getOrElse(oldField.name)
+ val oldRef = oldParent.ref(oldField.name, isVecField)
+
+ oldField.tpe match {
+ case _ : GroundType => List((oldField.copy(name = prefix + newName), List(oldRef)))
+ case _ : BundleType | _ : VectorType =>
+ val newPrefix = prefix + newName + LowerTypes.delim
+ val isVecField = oldField.tpe.isInstanceOf[VectorType]
+ val fields = getFields(oldField.tpe)
+ val fieldsWithCorrectOrientation = fields.map(f => f.copy(flip = Utils.times(f.flip, oldField.flip)))
+ val children = fieldsWithCorrectOrientation.flatMap { f =>
+ destruct(newPrefix, RefParentRef(oldRef), f, isVecField, rename.flatMap(_.children.get(f.name)))
}
- // wire foo : { a , b }
- // node x = foo
- // node y = x.a
- // ->
- // node x_a = foo_a
- // node x_b = foo_b
- // node y = x_a
- case sx: DefNode =>
- val names = create_exps(sx.name, sx.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname)
- val exps = create_exps(sx.value) map lowerTypesExp(memDataTypeMap, info, mname)
- renameExps(renames, sx.name, sx.value.tpe)
- Block(names zip exps map { case (n, e) =>
- DefNode(info, loweredName(n), e)
- })
- case sx: IsInvalid => kind(sx.expr) match {
- case MemKind =>
- Block(lowerTypesMemExp(memDataTypeMap, info, mname)(sx.expr) map (IsInvalid(info, _)))
- case _ => sx map lowerTypesExp(memDataTypeMap, info, mname)
- }
- case sx: Connect => kind(sx.loc) match {
- case MemKind =>
- val exp = lowerTypesExp(memDataTypeMap, info, mname)(sx.expr)
- val locs = lowerTypesMemExp(memDataTypeMap, info, mname)(sx.loc)
- Block(locs map (Connect(info, _, exp)))
- case _ => sx map lowerTypesExp(memDataTypeMap, info, mname)
- }
- case sx => sx map lowerTypesExp(memDataTypeMap, info, mname)
+ // the bundle/vec reference refers to all children
+ children.map{ case(c, r) => (c, r :+ oldRef) }
}
}
- def lowerTypes(renames: RenameMap, initializedMems: Set[(String, String)])(m: DefModule): DefModule = {
- val memDataTypeMap = new MemDataTypeMap
- renames.setModule(m.name)
- // Lower Ports
- val portsx = m.ports flatMap { p =>
- val exps = create_exps(WRef(p.name, p.tpe, PortKind, to_flow(p.direction)))
- val names = exps map loweredName
- renameExps(renames, p.name, p.tpe)
- (exps zip names) map { case (e, n) =>
- Port(p.info, n, to_dir(flow(e)), e.tpe)
- }
+ private case class RenameNode(name: String, children: Map[String, RenameNode])
+
+ /** Implements the core functionality of the old Uniquify pass: rename bundle fields and top-level references
+ * where necessary in order to avoid name clashes when lowering aggregate type with the `_` delimiter.
+ * We don't actually do the rename here but just calculate a rename tree. */
+ private def uniquify(ref: Field, namespace: Namespace, reserved: Set[String]): (Option[RenameNode], Seq[String]) = {
+ // ensure that there are no name clashes with the list of reserved (port) names
+ val newRefName = findValidPrefix(ref.name, reserved.contains)
+ ref.tpe match {
+ case BundleType(fields) =>
+ // we rename bottom-up
+ val localNamespace = new Namespace() ++ fields.map(_.name)
+ val renamedFields = fields.map(f => uniquify(f, localNamespace, Set()))
+
+ // Need leading _ for findValidPrefix, it doesn't add _ for checks
+ val renamedFieldNames = renamedFields.flatMap(_._2)
+ val suffixNames: Seq[String] = renamedFieldNames.map(f => LowerTypes.delim + f)
+ val prefix = findValidPrefix(newRefName, namespace.contains, suffixNames)
+ // We added f.name in previous map, delete if we change it
+ val renamed = prefix != ref.name
+ if (renamed) {
+ if(!reserved.contains(ref.name)) namespace -= ref.name
+ namespace += prefix
+ }
+ val suffixes = renamedFieldNames.map(f => prefix + LowerTypes.delim + f)
+
+ val anyChildRenamed = renamedFields.exists(_._1.isDefined)
+ val rename = if(renamed || anyChildRenamed){
+ val children = renamedFields.map(_._1).zip(fields).collect{ case (Some(r), f) => f.name -> r }.toMap
+ Some(RenameNode(prefix, children))
+ } else { None }
+
+ (rename, suffixes :+ prefix)
+ case v : VectorType=>
+ // if Vecs are to be lowered, we can just treat them like a bundle
+ uniquify(ref.copy(tpe = vecToBundle(v)), namespace, reserved)
+ case _ : GroundType =>
+ if(newRefName == ref.name) {
+ (None, List(ref.name))
+ } else {
+ (Some(RenameNode(newRefName, Map())), List(newRefName))
+ }
+ case UnknownType => throw new RuntimeException(s"Cannot uniquify field of unknown type: $ref")
}
- m match {
- case m: ExtModule =>
- m copy (ports = portsx)
- case m: Module =>
- m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name, renames, initializedMems)
+ }
+
+ /** Appends delim to prefix until no collisions of prefix + elts in names We don't add an _ in the collision check
+ * because elts could be Seq("") In this case, we're just really checking if prefix itself collides */
+ @tailrec
+ private def findValidPrefix(prefix: String, inNamespace: String => Boolean, elts: Seq[String] = List("")): String = {
+ elts.find(elt => inNamespace(prefix + elt)) match {
+ case Some(_) => findValidPrefix(prefix + "_", inNamespace, elts)
+ case None => prefix
}
}
- def execute(state: CircuitState): CircuitState = {
- // remember which memories need to be initialized, for these memories, lowering non-ground types is not supported
- val initializedMems = state.annotations.collect{
- case m : MemoryInitAnnotation if !m.isRandomInit =>
- (m.target.encapsulatingModule, m.target.ref) }.toSet
- val c = state.circuit
- val renames = RenameMap()
- renames.setCircuit(c.main)
- val result = c copy (modules = c.modules map lowerTypes(renames, initializedMems))
- CircuitState(result, outputForm, state.annotations, Some(renames))
+ private def getFields(tpe: Type): Seq[Field] = tpe match {
+ case BundleType(fields) => fields
+ case v : VectorType => vecToBundle(v).fields
+ }
+
+ private def vecToBundle(v: VectorType): BundleType = {
+ BundleType(( 0 until v.size).map(i => Field(i.toString, Default, v.tpe)))
+ }
+
+ /** Used to abstract over module and reference parents.
+ * This helps us simplify the `destruct` method as it does not need to distinguish between
+ * a module (in the initial call) or a bundle/vector (in the recursive call) reference as parent.
+ */
+ private trait ParentRef { def ref(name: String, asVecField: Boolean = false): ReferenceTarget }
+ private case class ModuleParentRef(m: ModuleTarget) extends ParentRef {
+ override def ref(name: String, asVecField: Boolean): ReferenceTarget = m.ref(name)
+ }
+ private case class RefParentRef(r: ReferenceTarget) extends ParentRef {
+ override def ref(name: String, asVecField: Boolean): ReferenceTarget =
+ if(asVecField) { r.index(name.toInt) } else { r.field(name) }
}
}
diff --git a/src/main/scala/firrtl/passes/TrimIntervals.scala b/src/main/scala/firrtl/passes/TrimIntervals.scala
index cb87e10e..822a8125 100644
--- a/src/main/scala/firrtl/passes/TrimIntervals.scala
+++ b/src/main/scala/firrtl/passes/TrimIntervals.scala
@@ -25,7 +25,6 @@ class TrimIntervals extends Pass {
override def prerequisites =
Seq( Dependency(ResolveKinds),
Dependency(InferTypes),
- Dependency(Uniquify),
Dependency(ResolveFlows),
Dependency[InferBinaryPoints] )
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index 89a99780..b9cd32fa 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -12,7 +12,7 @@ import firrtl.options.Dependency
import MemPortUtils.memType
-/** Resolve name collisions that would occur in [[LowerTypes]]
+/** Resolve name collisions that would occur in the old [[LowerTypes]] pass
*
* @note Must be run after [[InferTypes]] because [[ir.DefNode]]s need type
* @example
@@ -244,6 +244,8 @@ object Uniquify extends Transform with DependencyAPIMigration {
}
// Everything wrapped in run so that it's thread safe
+ @deprecated("The functionality of Uniquify is now part of LowerTypes." +
+ "Please file an issue with firrtl if you use Uniquify outside of the context of LowerTypes.", "Firrtl 1.4")
def execute(state: CircuitState): CircuitState = {
val c = state.circuit
val renames = RenameMap()
diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala
index 4f7e2369..56d66ef0 100644
--- a/src/main/scala/firrtl/passes/ZeroWidth.scala
+++ b/src/main/scala/firrtl/passes/ZeroWidth.scala
@@ -15,7 +15,6 @@ object ZeroWidth extends Transform with DependencyAPIMigration {
Dependency(ReplaceAccesses),
Dependency(ExpandConnects),
Dependency(RemoveAccesses),
- Dependency(Uniquify),
Dependency[ExpandWhensAndCheck],
Dependency(ConvertFixedToSInt) ) ++ firrtl.stage.Forms.Deduped
diff --git a/src/main/scala/firrtl/stage/Forms.scala b/src/main/scala/firrtl/stage/Forms.scala
index 933db4f4..55292fc5 100644
--- a/src/main/scala/firrtl/stage/Forms.scala
+++ b/src/main/scala/firrtl/stage/Forms.scala
@@ -33,7 +33,6 @@ object Forms {
val Resolved: Seq[TransformDependency] = WorkingIR ++ Checks ++
Seq( Dependency(passes.ResolveKinds),
Dependency(passes.InferTypes),
- Dependency(passes.Uniquify),
Dependency(passes.ResolveFlows),
Dependency[passes.InferBinaryPoints],
Dependency[passes.TrimIntervals],
diff --git a/src/main/scala/firrtl/transforms/InferResets.scala b/src/main/scala/firrtl/transforms/InferResets.scala
index ebf1d67a..dd073001 100644
--- a/src/main/scala/firrtl/transforms/InferResets.scala
+++ b/src/main/scala/firrtl/transforms/InferResets.scala
@@ -115,7 +115,6 @@ class InferResets extends Transform with DependencyAPIMigration {
override def prerequisites =
Seq( Dependency(passes.ResolveKinds),
Dependency(passes.InferTypes),
- Dependency(passes.Uniquify),
Dependency(passes.ResolveFlows),
Dependency[passes.InferWidths] ) ++ stage.Forms.WorkingIR
diff --git a/src/main/scala/firrtl/transforms/TopWiring.scala b/src/main/scala/firrtl/transforms/TopWiring.scala
index aa046770..f5a5e2a3 100644
--- a/src/main/scala/firrtl/transforms/TopWiring.scala
+++ b/src/main/scala/firrtl/transforms/TopWiring.scala
@@ -4,7 +4,7 @@ package TopWiring
import firrtl._
import firrtl.ir._
-import firrtl.passes.{ExpandConnects, InferTypes, LowerTypes, ResolveFlows, ResolveKinds}
+import firrtl.passes.{InferTypes, LowerTypes, ResolveKinds, ResolveFlows, ExpandConnects}
import firrtl.annotations._
import firrtl.Mappers._
import firrtl.analyses.InstanceKeyGraph
diff --git a/src/test/scala/firrtl/analysis/SymbolTableSpec.scala b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala
new file mode 100644
index 00000000..599b4e52
--- /dev/null
+++ b/src/test/scala/firrtl/analysis/SymbolTableSpec.scala
@@ -0,0 +1,95 @@
+// See LICENSE for license details.
+
+package firrtl.analysis
+
+import firrtl.analyses._
+import firrtl.ir
+import firrtl.options.Dependency
+import org.scalatest.flatspec.AnyFlatSpec
+
+class SymbolTableSpec extends AnyFlatSpec {
+ behavior of "SymbolTable"
+
+ private val src =
+ """circuit m:
+ | module child:
+ | input x : UInt<2>
+ | skip
+ | module m:
+ | input clk : Clock
+ | input x : UInt<1>
+ | output y : UInt<3>
+ | wire z : SInt<1>
+ | node a = cat(asUInt(z), x)
+ | inst i of child
+ | reg r: SInt<4>, clk
+ | mem m:
+ | data-type => UInt<8>
+ | depth => 31
+ | reader => r
+ | read-latency => 1
+ | write-latency => 1
+ | read-under-write => undefined
+ |""".stripMargin
+
+ it should "find all declarations in module m before InferTypes" in {
+ val c = firrtl.Parser.parse(src)
+ val m = c.modules.find(_.name == "m").get
+
+ val syms = SymbolTable.scanModule(m, new LocalSymbolTable with WithMap)
+ assert(syms.size == 8)
+ assert(syms("clk").tpe == ir.ClockType && syms("clk").kind == firrtl.PortKind)
+ assert(syms("x").tpe == ir.UIntType(ir.IntWidth(1)) && syms("x").kind == firrtl.PortKind)
+ assert(syms("y").tpe == ir.UIntType(ir.IntWidth(3)) && syms("y").kind == firrtl.PortKind)
+ assert(syms("z").tpe == ir.SIntType(ir.IntWidth(1)) && syms("z").kind == firrtl.WireKind)
+ // The expression type which determines the node type is only known after InferTypes.
+ assert(syms("a").tpe == ir.UnknownType && syms("a").kind == firrtl.NodeKind)
+ // The type of the instance is unknown because we scanned the module before InferTypes and the table
+ // uses only local information.
+ assert(syms("i").tpe == ir.UnknownType && syms("i").kind == firrtl.InstanceKind)
+ assert(syms("r").tpe == ir.SIntType(ir.IntWidth(4)) && syms("r").kind == firrtl.RegKind)
+ val mType = firrtl.passes.MemPortUtils.memType(
+ // only dataType, depth and reader, writer, readwriter properties affect the data type
+ ir.DefMemory(ir.NoInfo, "???", ir.UIntType(ir.IntWidth(8)), 32, 10, 10, Seq("r"), Seq(), Seq(), ir.ReadUnderWrite.New)
+ )
+ assert(syms("m") .tpe == mType && syms("m").kind == firrtl.MemKind)
+ }
+
+ it should "find all declarations in module m after InferTypes" in {
+ val c = firrtl.Parser.parse(src)
+ val inferTypesCompiler = new firrtl.stage.TransformManager(Seq(Dependency(firrtl.passes.InferTypes)))
+ val inferredC = inferTypesCompiler.execute(firrtl.CircuitState(c, Seq())).circuit
+ val m = inferredC.modules.find(_.name == "m").get
+
+ val syms = SymbolTable.scanModule(m, new LocalSymbolTable with WithMap)
+ // The node type is now known
+ assert(syms("a").tpe == ir.UIntType(ir.IntWidth(2)) && syms("a").kind == firrtl.NodeKind)
+ // The type of the instance is now known because it has been filled in by InferTypes.
+ val iType = ir.BundleType(Seq(ir.Field("x", ir.Flip, ir.UIntType(ir.IntWidth(2)))))
+ assert(syms("i").tpe == iType && syms("i").kind == firrtl.InstanceKind)
+ }
+
+ behavior of "WithSeq"
+
+ it should "preserve declaration order" in {
+ val c = firrtl.Parser.parse(src)
+ val m = c.modules.find(_.name == "m").get
+
+ val syms = SymbolTable.scanModule(m, new LocalSymbolTable with WithSeq)
+ assert(syms.getSymbols.map(_.name) == Seq("clk", "x", "y", "z", "a", "i", "r", "m"))
+ }
+
+ behavior of "ModuleTypesSymbolTable"
+
+ it should "derive the module type from the module types map" in {
+ val c = firrtl.Parser.parse(src)
+ val m = c.modules.find(_.name == "m").get
+
+ val childType = ir.BundleType(Seq(ir.Field("x", ir.Flip, ir.UIntType(ir.IntWidth(2)))))
+ val moduleTypes = Map("child" -> childType)
+
+ val syms = SymbolTable.scanModule(m, new ModuleTypesSymbolTable(moduleTypes) with WithMap)
+ assert(syms.size == 8)
+ assert(syms("i").tpe == childType && syms("i").kind == firrtl.InstanceKind)
+ }
+}
diff --git a/src/test/scala/firrtl/passes/LowerTypesSpec.scala b/src/test/scala/firrtl/passes/LowerTypesSpec.scala
new file mode 100644
index 00000000..884e51b8
--- /dev/null
+++ b/src/test/scala/firrtl/passes/LowerTypesSpec.scala
@@ -0,0 +1,533 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+import firrtl.annotations.{CircuitTarget, IsMember}
+import firrtl.{CircuitState, RenameMap, Utils}
+import firrtl.options.Dependency
+import firrtl.stage.TransformManager
+import firrtl.stage.TransformManager.TransformDependency
+import org.scalatest.flatspec.AnyFlatSpec
+
+
+/** Unit test style tests for [[LowerTypes]].
+ * You can find additional integration style tests in [[firrtlTests.LowerTypesSpec]]
+ */
+class LowerTypesUnitTestSpec extends LowerTypesBaseSpec {
+ import LowerTypesSpecUtils._
+ override protected def lower(n: String, tpe: String, namespace: Set[String]): Seq[String] =
+ destruct(n, tpe, namespace).fields
+}
+
+/** Runs the lowering pass in the context of the compiler instead of directly calling internal functions. */
+class LowerTypesEndToEndSpec extends LowerTypesBaseSpec {
+ private lazy val lowerTypesCompiler = new TransformManager(Seq(Dependency(LowerTypes)))
+ private def legacyLower(n: String, tpe: String, namespace: Set[String]): Seq[String] = {
+ val inputs = namespace.map(n => s" input $n : UInt<1>").mkString("\n")
+ val src =
+ s"""circuit c:
+ | module c:
+ |$inputs
+ | output $n : $tpe
+ | $n is invalid
+ |""".stripMargin
+ val c = CircuitState(firrtl.Parser.parse(src), Seq())
+ val c2 = lowerTypesCompiler.execute(c)
+ val ps = c2.circuit.modules.head.ports.filterNot(p => namespace.contains(p.name))
+ ps.map{p =>
+ val orientation = Utils.to_flip(p.direction)
+ s"${orientation.serialize}${p.name} : ${p.tpe.serialize}"}
+ }
+
+ override protected def lower(n: String, tpe: String, namespace: Set[String]): Seq[String] =
+ legacyLower(n, tpe, namespace)
+}
+
+/** this spec can be tested with either the new or the old LowerTypes pass */
+abstract class LowerTypesBaseSpec extends AnyFlatSpec {
+ protected def lower(n: String, tpe: String, namespace: Set[String] = Set()): Seq[String]
+
+ it should "lower bundles and vectors" in {
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}") == Seq("a_a : UInt<1>", "a_b : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}") == Seq("a_a : UInt<1>", "a_b_c : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}") == Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]") ==
+ Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>"))
+
+ // with conflicts
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}", Set("a_a")) == Seq("a__a : UInt<1>", "a__b : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}", Set("a_b")) == Seq("a__a : UInt<1>", "a__b : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}", Set("a_c")) == Seq("a_a : UInt<1>", "a_b : UInt<1>"))
+
+ assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_a")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>"))
+ // in this case we do not have a "real" conflict, but it could be in a reference and thus a is still changed to a_
+ assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", Set("a_b_c")) == Seq("a__a : UInt<1>", "a__b_c : UInt<1>"))
+
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a")) ==
+ Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_a", "a_b_0")) ==
+ Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", Set("a_b_0")) ==
+ Seq("a__a : UInt<1>", "a__b_0 : UInt<1>", "a__b_1 : UInt<1>"))
+
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0")) ==
+ Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_3")) ==
+ Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_a")) ==
+ Seq("a__0_a : UInt<1>", "a__0_b : UInt<1>", "a__1_a : UInt<1>", "a__1_b : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", Set("a_0_c")) ==
+ Seq("a_0_a : UInt<1>", "a_0_b : UInt<1>", "a_1_a : UInt<1>", "a_1_b : UInt<1>"))
+
+ // collisions inside the bundle
+ assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b__c : UInt<1>", "a_b_c : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b_c : UInt<1>", "a_b_b : UInt<1>"))
+
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b__0 : UInt<1>", "a_b__1 : UInt<1>", "a_b_0 : UInt<1>"))
+ assert(lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_c : UInt<1>}") ==
+ Seq("a_a : UInt<1>", "a_b_0 : UInt<1>", "a_b_1 : UInt<1>", "a_b_c : UInt<1>"))
+ }
+
+ it should "correctly lower the orientation" in {
+ assert(lower("a", "{ flip a : UInt<1>, b : UInt<1>}") == Seq("flip a_a : UInt<1>", "a_b : UInt<1>"))
+ assert(lower("a", "{ flip a : UInt<1>[2], b : UInt<1>}") ==
+ Seq("flip a_a_0 : UInt<1>", "flip a_a_1 : UInt<1>", "a_b : UInt<1>"))
+ assert(lower("a", "{ a : { flip c : UInt<1>, d : UInt<1>}[2], b : UInt<1>}") ==
+ Seq("flip a_a_0_c : UInt<1>", "a_a_0_d : UInt<1>", "flip a_a_1_c : UInt<1>", "a_a_1_d : UInt<1>", "a_b : UInt<1>")
+ )
+ }
+}
+
+/** Test the renaming for "regular" references, i.e. Wires, Nodes and Register.
+ * Memories and Instances are special cases.
+ */
+class LowerTypesRenamingSpec extends AnyFlatSpec {
+ import LowerTypesSpecUtils._
+ protected def lower(n: String, tpe: String, namespace: Set[String] = Set()): RenameMap =
+ destruct(n, tpe, namespace).renameMap
+
+ private val m = CircuitTarget("m").module("m")
+
+ it should "not rename ground types" in {
+ val r = lower("a", "UInt<1>")
+ assert(r.underlying.isEmpty)
+ }
+
+ it should "properly rename lowered bundles and vectors" in {
+ val a = m.ref("a")
+
+ def one(namespace: Set[String], prefix: String): Unit = {
+ val r = lower("a", "{ a : UInt<1>, b : UInt<1>}", namespace)
+ assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b")))
+ assert(get(r,a.field("a")) == Set(m.ref(prefix + "a")))
+ assert(get(r,a.field("b")) == Set(m.ref(prefix + "b")))
+ }
+ one(Set(), "a_")
+ one(Set("a_a"), "a__")
+
+ def two(namespace: Set[String], prefix: String): Unit = {
+ val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}}", namespace)
+ assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_c")))
+ assert(get(r,a.field("a")) == Set(m.ref(prefix + "a")))
+ assert(get(r,a.field("b")) == Set(m.ref(prefix + "b_c")))
+ assert(get(r,a.field("b").field("c")) == Set(m.ref(prefix + "b_c")))
+ }
+ two(Set(), "a_")
+ two(Set("a_a"), "a__")
+
+ def three(namespace: Set[String], prefix: String): Unit = {
+ val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2]}", namespace)
+ assert(get(r,a) == Set(m.ref(prefix + "a"), m.ref(prefix + "b_0"), m.ref(prefix + "b_1")))
+ assert(get(r,a.field("a")) == Set(m.ref(prefix + "a")))
+ assert(get(r,a.field("b")) == Set( m.ref(prefix + "b_0"), m.ref(prefix + "b_1")))
+ assert(get(r,a.field("b").index(0)) == Set(m.ref(prefix + "b_0")))
+ assert(get(r,a.field("b").index(1)) == Set(m.ref(prefix + "b_1")))
+ }
+ three(Set(), "a_")
+ three(Set("a_b_0"), "a__")
+
+ def four(namespace: Set[String], prefix: String): Unit = {
+ val r = lower("a", "{ a : UInt<1>, b : UInt<1>}[2]", namespace)
+ assert(get(r,a) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "1_a"), m.ref(prefix + "0_b"), m.ref(prefix + "1_b")))
+ assert(get(r,a.index(0)) == Set(m.ref(prefix + "0_a"), m.ref(prefix + "0_b")))
+ assert(get(r,a.index(1)) == Set(m.ref(prefix + "1_a"), m.ref(prefix + "1_b")))
+ assert(get(r,a.index(0).field("a")) == Set(m.ref(prefix + "0_a")))
+ assert(get(r,a.index(0).field("b")) == Set(m.ref(prefix + "0_b")))
+ assert(get(r,a.index(1).field("a")) == Set(m.ref(prefix + "1_a")))
+ assert(get(r,a.index(1).field("b")) == Set(m.ref(prefix + "1_b")))
+ }
+ four(Set(), "a_")
+ four(Set("a_0"), "a__")
+ four(Set("a_3"), "a_")
+
+ // collisions inside the bundle
+ {
+ val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_c : UInt<1>}")
+ assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__c"), m.ref("a_b_c")))
+ assert(get(r,a.field("a")) == Set(m.ref("a_a")))
+ assert(get(r,a.field("b")) == Set(m.ref("a_b__c")))
+ assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b__c")))
+ assert(get(r,a.field("b_c")) == Set(m.ref("a_b_c")))
+ }
+ {
+ val r = lower("a", "{ a : UInt<1>, b : { c : UInt<1>}, b_b : UInt<1>}")
+ assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b_c"), m.ref("a_b_b")))
+ assert(get(r,a.field("a")) == Set(m.ref("a_a")))
+ assert(get(r,a.field("b")) == Set(m.ref("a_b_c")))
+ assert(get(r,a.field("b").field("c")) == Set(m.ref("a_b_c")))
+ assert(get(r,a.field("b_b")) == Set(m.ref("a_b_b")))
+ }
+ {
+ val r = lower("a", "{ a : UInt<1>, b : UInt<1>[2], b_0 : UInt<1>}")
+ assert(get(r,a) == Set(m.ref("a_a"), m.ref("a_b__0"), m.ref("a_b__1"), m.ref("a_b_0")))
+ assert(get(r,a.field("a")) == Set(m.ref("a_a")))
+ assert(get(r,a.field("b")) == Set(m.ref("a_b__0"), m.ref("a_b__1")))
+ assert(get(r,a.field("b").index(0)) == Set(m.ref("a_b__0")))
+ assert(get(r,a.field("b").index(1)) == Set(m.ref("a_b__1")))
+ assert(get(r,a.field("b_0")) == Set(m.ref("a_b_0")))
+ }
+ }
+}
+
+/** Instances are a special case since they do not get completely destructed but instead become a 1-deep bundle. */
+class LowerTypesOfInstancesSpec extends AnyFlatSpec {
+ import LowerTypesSpecUtils._
+ private case class Lower(inst: firrtl.ir.DefInstance, fields: Seq[String], renameMap: RenameMap)
+ private val m = CircuitTarget("m").module("m")
+ def resultToFieldSeq(res: Seq[(String, firrtl.ir.SubField)]): Seq[String] =
+ res.map(_._2).map(r => s"${r.name} : ${r.tpe.serialize}")
+ private def lower(n: String, tpe: String, module: String, namespace: Set[String], renames: RenameMap = RenameMap()):
+ Lower = {
+ val ref = firrtl.ir.DefInstance(firrtl.ir.NoInfo, n, module, parseType(tpe))
+ val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace
+ val (newInstance, res) = DestructTypes.destructInstance(m, ref, mutableSet, renames, Set())
+ Lower(newInstance, resultToFieldSeq(res), renames)
+ }
+ private def get(l: Lower, m: IsMember): Set[IsMember] = l.renameMap.get(m).get.toSet
+
+ it should "not rename instances if the instance name does not change" in {
+ val l = lower("i", "{ a : UInt<1>}", "c", Set())
+ assert(l.renameMap.underlying.isEmpty)
+ }
+
+ it should "lower an instance correctly" in {
+ val i = m.instOf("i", "c")
+ val l = lower("i", "{ a : UInt<1>}", "c", Set("i_a"))
+ assert(l.inst.name == "i_")
+ assert(l.inst.tpe.isInstanceOf[firrtl.ir.BundleType])
+ assert(l.inst.tpe.serialize == "{ a : UInt<1>}")
+
+ assert(get(l, i) == Set(m.instOf("i_", "c")))
+ assert(l.fields == Seq("a : UInt<1>"))
+ }
+
+ it should "update the rename map with the changed port names" in {
+ // without lowering ports
+ {
+ val i = m.instOf("i", "c")
+ val l = lower("i", "{ b : { c : UInt<1>}, b_c : UInt<1>}", "c", Set("i_b_c"))
+ // the instance was renamed because of the collision with "i_b_c"
+ assert(get(l, i) == Set(m.instOf("i_", "c")))
+ // the rename of e.g. `instance.b` to `instance_.b__c` was not recorded since we never performed the
+ // port renaming and thus we won't get a result
+ assert(get(l, i.ref("b")) == Set(m.instOf("i_", "c").ref("b")))
+ }
+
+ // same as above but with lowered port
+ {
+ // We need two distinct rename maps: one for the port renaming and one for everything else.
+ // This is to accommodate the use-case where a port as well as an instance needs to be renames
+ // thus requiring a two-stage translation process for reference to the port of the instance.
+ // This two-stage translation is only supported through chaining rename maps.
+ val portRenames = RenameMap()
+ val otherRenames = RenameMap()
+
+ // The child module "c" which we assume has the following ports: b : { c : UInt<1>} and b_c : UInt<1>
+ val c = CircuitTarget("m").module("c")
+ val portB = firrtl.ir.Field("b", firrtl.ir.Default, parseType("{ c : UInt<1>}"))
+ val portB_C = firrtl.ir.Field("b_c", firrtl.ir.Default, parseType("UInt<1>"))
+
+ // lower ports
+ val namespaceC = scala.collection.mutable.HashSet[String]() ++ Seq("b", "b_c")
+ DestructTypes.destruct(c, portB, namespaceC, portRenames, Set())
+ DestructTypes.destruct(c, portB_C, namespaceC, portRenames, Set())
+ // only port b is renamed, port b_c stays the same
+ assert(portRenames.get(c.ref("b")).get == Seq(c.ref("b__c")))
+
+ // in module m we then lower the instance i of c
+ val l = lower("i", "{ b : { c : UInt<1>}, b_c : UInt<1>}", "c", Set("i_b_c"), otherRenames)
+ val i = m.instOf("i", "c")
+ // the instance was renamed because of the collision with "i_b_c"
+ val i_ = m.instOf("i_", "c")
+ assert(get(l, i) == Set(i_))
+
+ // the ports renaming is also noted
+ val r = portRenames.andThen(otherRenames)
+ assert(r.get(i.ref("b")).get == Seq(i_.ref("b__c")))
+ assert(r.get(i.ref("b").field("c")).get == Seq(i_.ref("b__c")))
+ assert(r.get(i.ref("b_c")).get == Seq(i_.ref("b_c")))
+ }
+ }
+}
+
+/** Memories are a special case as they remain 2-deep bundles and fields of the datatype are pulled into the front.
+ * E.g., `mem.r.data.a` becomes `mem_a.r.data`
+ */
+class LowerTypesOfMemorySpec extends AnyFlatSpec {
+ import LowerTypesSpecUtils._
+ private case class Lower(mems: Seq[firrtl.ir.DefMemory], refs: Seq[(String, firrtl.ir.SubField)],
+ renameMap: RenameMap)
+ private val m = CircuitTarget("m").module("m")
+ private val mem = m.ref("mem")
+ private def lower(name: String, tpe: String, namespace: Set[String],
+ r: Seq[String] = List("r"), w: Seq[String] = List("w"), rw: Seq[String] = List(), depth: Int = 2): Lower = {
+ val dataType = parseType(tpe)
+ val mem = firrtl.ir.DefMemory(firrtl.ir.NoInfo, name, dataType, depth = depth, writeLatency = 1, readLatency = 1,
+ readUnderWrite = firrtl.ir.ReadUnderWrite.Undefined, readers = r, writers = w, readwriters = rw)
+ val renames = RenameMap()
+ val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace
+ val(mems, refs) = DestructTypes.destructMemory(m, mem, mutableSet, renames, Set())
+ Lower(mems, refs, renames)
+ }
+ private val UInt1 = firrtl.ir.UIntType(firrtl.ir.IntWidth(1))
+
+ it should "not rename anything for a ground type memory if there was no conflict" in {
+ val l = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w"))
+ assert(l.renameMap.underlying.isEmpty)
+ }
+
+ it should "still produce reference lookups, even for a ground type memory with no conflicts" in {
+ val nameToRef = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("w")).refs
+ .map{case (n,r) => n -> r.serialize}.toSet
+
+ assert(nameToRef == Set(
+ "mem.r.clk" -> "mem.r.clk",
+ "mem.r.en" -> "mem.r.en",
+ "mem.r.addr" -> "mem.r.addr",
+ "mem.r.data" -> "mem.r.data",
+ "mem.w.clk" -> "mem.w.clk",
+ "mem.w.en" -> "mem.w.en",
+ "mem.w.addr" -> "mem.w.addr",
+ "mem.w.data" -> "mem.w.data",
+ "mem.w.mask" -> "mem.w.mask"
+ ))
+ }
+
+ it should "produce references of correct type" in {
+ val nameToType = lower("mem", "UInt<4>", Set("mem_r", "mem_r_data"), w=Seq("w"), depth = 3).refs
+ .map{case (n,r) => n -> r.tpe.serialize}.toSet
+
+ assert(nameToType == Set(
+ "mem.r.clk" -> "Clock",
+ "mem.r.en" -> "UInt<1>",
+ "mem.r.addr" -> "UInt<2>", // depth = 3
+ "mem.r.data" -> "UInt<4>",
+ "mem.w.clk" -> "Clock",
+ "mem.w.en" -> "UInt<1>",
+ "mem.w.addr" -> "UInt<2>",
+ "mem.w.data" -> "UInt<4>",
+ "mem.w.mask" -> "UInt<1>"
+ ))
+ }
+
+ it should "not rename ground type memories even if there are conflicts on the ports" in {
+ // There actually isn't such a thing as conflicting ports, because they do not get flattened by LowerTypes.
+ val r = lower("mem", "UInt<1>", Set("mem_r", "mem_r_data"), w=Seq("r_data")).renameMap
+ assert(r.underlying.isEmpty)
+ }
+
+ it should "rename references to lowered ports" in {
+ val r = lower("mem", "{ a : UInt<1>, b : UInt<1>}", Set("mem_a"), r=Seq("r", "r_data")).renameMap
+
+ // complete memory
+ assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b")))
+
+ // read ports
+ assert(get(r, mem.field("r")) ==
+ Set(m.ref("mem__a").field("r"), m.ref("mem__b").field("r")))
+ assert(get(r, mem.field("r_data")) ==
+ Set(m.ref("mem__a").field("r_data"), m.ref("mem__b").field("r_data")))
+
+ // port fields
+ assert(get(r, mem.field("r").field("data")) ==
+ Set(m.ref("mem__a").field("r").field("data"),
+ m.ref("mem__b").field("r").field("data")))
+ assert(get(r, mem.field("r").field("addr")) ==
+ Set(m.ref("mem__a").field("r").field("addr"),
+ m.ref("mem__b").field("r").field("addr")))
+ assert(get(r, mem.field("r").field("en")) ==
+ Set(m.ref("mem__a").field("r").field("en"),
+ m.ref("mem__b").field("r").field("en")))
+ assert(get(r, mem.field("r").field("clk")) ==
+ Set(m.ref("mem__a").field("r").field("clk"),
+ m.ref("mem__b").field("r").field("clk")))
+ assert(get(r, mem.field("w").field("mask")) ==
+ Set(m.ref("mem__a").field("w").field("mask"),
+ m.ref("mem__b").field("w").field("mask")))
+
+ // port sub-fields
+ assert(get(r, mem.field("r").field("data").field("a")) ==
+ Set(m.ref("mem__a").field("r").field("data")))
+ assert(get(r, mem.field("r").field("data").field("b")) ==
+ Set(m.ref("mem__b").field("r").field("data")))
+
+ // need to rename the following:
+ // mem -> mem__a, mem__b
+ // mem.r.data.{a,b} -> mem__{a,b}.r.data
+ // mem.w.data.{a,b} -> mem__{a,b}.w.data
+ // mem.w.mask.{a,b} -> mem__{a,b}.w.mask
+ // mem.r_data.data.{a,b} -> mem__{a,b}.r_data.data
+ val renameCount = r.underlying.map(_._2.size).sum
+ assert(renameCount == 10, "it is enough to rename *to* 10 different signals")
+ assert(r.underlying.size == 9, "it is enough to rename (from) 9 different signals")
+ }
+
+ it should "rename references for a memory with a nested data type" in {
+ val l = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a"))
+ assert(l.mems.map(_.name) == Seq("mem__a", "mem__b_c"))
+ assert(l.mems.map(_.dataType) == Seq(UInt1, UInt1))
+
+ // complete memory
+ val r = l.renameMap
+ assert(get(r, mem) == Set(m.ref("mem__a"), m.ref("mem__b_c")))
+
+ // read port
+ assert(get(r, mem.field("r")) ==
+ Set(m.ref("mem__a").field("r"), m.ref("mem__b_c").field("r")))
+
+ // port sub-fields
+ assert(get(r, mem.field("r").field("data").field("a")) ==
+ Set(m.ref("mem__a").field("r").field("data")))
+ assert(get(r, mem.field("r").field("data").field("b")) ==
+ Set(m.ref("mem__b_c").field("r").field("data")))
+ assert(get(r, mem.field("r").field("data").field("b").field("c")) ==
+ Set(m.ref("mem__b_c").field("r").field("data")))
+
+ // the mask field needs to be lowered just like the data field
+ assert(get(r, mem.field("w").field("mask").field("a")) ==
+ Set(m.ref("mem__a").field("w").field("mask")))
+ assert(get(r, mem.field("w").field("mask").field("b")) ==
+ Set(m.ref("mem__b_c").field("w").field("mask")))
+ assert(get(r, mem.field("w").field("mask").field("b").field("c")) ==
+ Set(m.ref("mem__b_c").field("w").field("mask")))
+
+ val renameCount = r.underlying.map(_._2.size).sum
+ assert(renameCount == 11, "it is enough to rename *to* 11 different signals")
+ assert(r.underlying.size == 10, "it is enough to rename (from) 10 different signals")
+ }
+
+ it should "return a name to RefLikeExpression map for a memory with a nested data type" in {
+ val nameToRef = lower("mem", "{ a : UInt<1>, b : { c : UInt<1>} }", Set("mem_a")).refs
+ .map{case (n,r) => n -> r.serialize}.toSet
+
+ assert(nameToRef == Set(
+ // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated.
+ // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do.
+ "mem.r.clk" -> "mem__a.r.clk", "mem.r.clk" -> "mem__b_c.r.clk",
+ "mem.r.en" -> "mem__a.r.en", "mem.r.en" -> "mem__b_c.r.en",
+ "mem.r.addr" -> "mem__a.r.addr", "mem.r.addr" -> "mem__b_c.r.addr",
+ "mem.w.clk" -> "mem__a.w.clk", "mem.w.clk" -> "mem__b_c.w.clk",
+ "mem.w.en" -> "mem__a.w.en", "mem.w.en" -> "mem__b_c.w.en",
+ "mem.w.addr" -> "mem__a.w.addr", "mem.w.addr" -> "mem__b_c.w.addr",
+ // Ground type references to the data or mask field are unique.
+ "mem.r.data.a" -> "mem__a.r.data",
+ "mem.w.data.a" -> "mem__a.w.data",
+ "mem.w.mask.a" -> "mem__a.w.mask",
+ "mem.r.data.b.c" -> "mem__b_c.r.data",
+ "mem.w.data.b.c" -> "mem__b_c.w.data",
+ "mem.w.mask.b.c" -> "mem__b_c.w.mask"
+ ))
+ }
+
+ it should "produce references of correct type for memories with a read/write port" in {
+ val refs = lower("mem", "{ a : UInt<3>, b : { c : UInt<4>} }", Set("mem_a"),
+ r=Seq(), w=Seq(), rw=Seq("rw"), depth = 3).refs
+ val nameToRef = refs.map{case (n,r) => n -> r.serialize}.toSet
+ val nameToType = refs.map{case (n,r) => n -> r.tpe.serialize}.toSet
+
+ assert(nameToRef == Set(
+ // The non "data" or "mask" fields of read and write ports are already of ground type but still do get duplicated.
+ // They will all carry the exact same value, so for a RHS use of the old signal, any of the expanded ones will do.
+ "mem.rw.clk" -> "mem__a.rw.clk", "mem.rw.clk" -> "mem__b_c.rw.clk",
+ "mem.rw.en" -> "mem__a.rw.en", "mem.rw.en" -> "mem__b_c.rw.en",
+ "mem.rw.addr" -> "mem__a.rw.addr", "mem.rw.addr" -> "mem__b_c.rw.addr",
+ "mem.rw.wmode" -> "mem__a.rw.wmode", "mem.rw.wmode" -> "mem__b_c.rw.wmode",
+ // Ground type references to the data or mask field are unique.
+ "mem.rw.rdata.a" -> "mem__a.rw.rdata",
+ "mem.rw.wdata.a" -> "mem__a.rw.wdata",
+ "mem.rw.wmask.a" -> "mem__a.rw.wmask",
+ "mem.rw.rdata.b.c" -> "mem__b_c.rw.rdata",
+ "mem.rw.wdata.b.c" -> "mem__b_c.rw.wdata",
+ "mem.rw.wmask.b.c" -> "mem__b_c.rw.wmask"
+ ))
+
+ assert(nameToType == Set(
+ //
+ "mem.rw.clk" -> "Clock",
+ "mem.rw.en" -> "UInt<1>",
+ "mem.rw.addr" -> "UInt<2>",
+ "mem.rw.wmode" -> "UInt<1>",
+ // Ground type references to the data or mask field are unique.
+ "mem.rw.rdata.a" -> "UInt<3>",
+ "mem.rw.wdata.a" -> "UInt<3>",
+ "mem.rw.wmask.a" -> "UInt<1>",
+ "mem.rw.rdata.b.c" -> "UInt<4>",
+ "mem.rw.wdata.b.c" -> "UInt<4>",
+ "mem.rw.wmask.b.c" -> "UInt<1>"
+ ))
+ }
+
+
+ it should "rename references for vector type memories" in {
+ val l = lower("mem", "UInt<1>[2]", Set("mem_0"))
+ assert(l.mems.map(_.name) == Seq("mem__0", "mem__1"))
+ assert(l.mems.map(_.dataType) == Seq(UInt1, UInt1))
+
+ // complete memory
+ val r = l.renameMap
+ assert(get(r, mem) == Set(m.ref("mem__0"), m.ref("mem__1")))
+
+ // read port
+ assert(get(r, mem.field("r")) ==
+ Set(m.ref("mem__0").field("r"), m.ref("mem__1").field("r")))
+
+ // port sub-fields
+ assert(get(r, mem.field("r").field("data").index(0)) ==
+ Set(m.ref("mem__0").field("r").field("data")))
+ assert(get(r, mem.field("r").field("data").index(1)) ==
+ Set(m.ref("mem__1").field("r").field("data")))
+
+ val renameCount = r.underlying.map(_._2.size).sum
+ assert(renameCount == 8, "it is enough to rename *to* 8 different signals")
+ assert(r.underlying.size == 7, "it is enough to rename (from) 7 different signals")
+ }
+
+}
+
+private object LowerTypesSpecUtils {
+ private val typedCompiler = new TransformManager(Seq(Dependency(InferTypes)))
+ def parseType(tpe: String): firrtl.ir.Type = {
+ val src =
+ s"""circuit c:
+ | module c:
+ | input c: $tpe
+ |""".stripMargin
+ val c = CircuitState(firrtl.Parser.parse(src), Seq())
+ typedCompiler.execute(c).circuit.modules.head.ports.head.tpe
+ }
+ case class DestructResult(fields: Seq[String], renameMap: RenameMap)
+ def destruct(n: String, tpe: String, namespace: Set[String]): DestructResult = {
+ val ref = firrtl.ir.Field(n, firrtl.ir.Default, parseType(tpe))
+ val renames = RenameMap()
+ val mutableSet = scala.collection.mutable.HashSet[String]() ++ namespace
+ val res = DestructTypes.destruct(m, ref, mutableSet, renames, Set())
+ DestructResult(resultToFieldSeq(res), renames)
+ }
+ def resultToFieldSeq(res: Seq[(firrtl.ir.Field, String)]): Seq[String] =
+ res.map(_._1).map(r => s"${r.flip.serialize}${r.name} : ${r.tpe.serialize}")
+ def get(r: RenameMap, m: IsMember): Set[IsMember] = r.get(m).get.toSet
+ protected val m = CircuitTarget("m").module("m")
+}
diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
index 250a75d7..3616397f 100644
--- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala
+++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
@@ -13,7 +13,6 @@ class ExpandWhensSpec extends FirrtlFlatSpec {
ResolveKinds,
InferTypes,
CheckTypes,
- Uniquify,
ResolveKinds,
InferTypes,
ResolveFlows,
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 4e8a7fa5..648c6b36 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -6,38 +6,21 @@ import firrtl.Parser
import firrtl.passes._
import firrtl.transforms._
import firrtl._
+import firrtl.annotations._
+import firrtl.options.Dependency
+import firrtl.stage.TransformManager
import firrtl.testutils._
+import firrtl.util.TestOptions
+/** Integration style tests for [[LowerTypes]].
+ * You can find additional unit test style tests in [[passes.LowerTypesUnitTestSpec]]
+ */
class LowerTypesSpec extends FirrtlFlatSpec {
- private def transforms = Seq(
- ToWorkingIR,
- CheckHighForm,
- ResolveKinds,
- InferTypes,
- CheckTypes,
- ResolveFlows,
- CheckFlows,
- new InferWidths,
- CheckWidths,
- PullMuxes,
- ExpandConnects,
- RemoveAccesses,
- ExpandWhens,
- CheckInitialization,
- Legalize,
- new ConstantPropagation,
- ResolveKinds,
- InferTypes,
- ResolveFlows,
- new InferWidths,
- LowerTypes)
+ private val compiler = new TransformManager(Seq(Dependency(LowerTypes)))
private def executeTest(input: String, expected: Seq[String]) = {
- val circuit = Parser.parse(input.split("\n").toIterator)
- val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) {
- (c: CircuitState, p: Transform) => p.runTransform(c)
- }
- val c = result.circuit
+ val fir = Parser.parse(input.split("\n").toIterator)
+ val c = compiler.runTransform(CircuitState(fir, Seq())).circuit
val lines = c.serialize.split("\n") map normalized
expected foreach { e =>
@@ -204,3 +187,353 @@ class LowerTypesSpec extends FirrtlFlatSpec {
executeTest(input, expected)
}
}
+
+/** Uniquify used to be its own pass. We ported the tests to run with the combined LowerTypes pass. */
+class LowerTypesUniquifySpec extends FirrtlFlatSpec {
+ private val compiler = new TransformManager(Seq(Dependency(firrtl.passes.LowerTypes)))
+
+ private def executeTest(input: String, expected: Seq[String]): Unit = executeTest(input, expected, Seq.empty, Seq.empty)
+ private def executeTest(input: String, expected: Seq[String],
+ inputAnnos: Seq[Annotation], expectedAnnos: Seq[Annotation]): Unit = {
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val result = compiler.runTransform(CircuitState(circuit, inputAnnos))
+ val lines = result.circuit.serialize.split("\n") map normalized
+
+ expected.map(normalized).foreach { e =>
+ assert(lines.contains(e), f"Failed to find $e in ${lines.mkString("\n")}")
+ }
+
+ result.annotations.toSeq should equal(expectedAnnos)
+ }
+
+ behavior of "LowerTypes"
+
+ it should "rename colliding ports" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input a : { flip b : UInt<1>, c : { d : UInt<2>, flip e : UInt<3>}[2], c_1_e : UInt<4>}[2]
+ | output a_0_c_ : UInt<5>
+ | output a__0 : UInt<6>
+ """.stripMargin
+ val expected = Seq(
+ "output a___0_b : UInt<1>",
+ "input a___0_c__0_d : UInt<2>",
+ "output a___0_c__0_e : UInt<3>",
+ "output a_0_c_ : UInt<5>",
+ "output a__0 : UInt<6>")
+
+ val m = CircuitTarget("Test").module("Test")
+ val inputAnnos = Seq(
+ DontTouchAnnotation(m.ref("a").index(0).field("b")),
+ DontTouchAnnotation(m.ref("a").index(0).field("c").index(0).field("e")))
+
+ val expectedAnnos = Seq(
+ DontTouchAnnotation(m.ref("a___0_b")),
+ DontTouchAnnotation(m.ref("a___0_c__0_e")))
+
+
+ executeTest(input, expected, inputAnnos, expectedAnnos)
+ }
+
+ it should "rename colliding registers" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input clock : Clock
+ | reg a : { b : UInt<1>, c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2], clock
+ | reg a_0_c_ : UInt<5>, clock
+ | reg a__0 : UInt<6>, clock
+ """.stripMargin
+ val expected = Seq(
+ "reg a___0_b : UInt<1>, clock with :",
+ "reg a___1_c__1_e : UInt<3>, clock with :",
+ "reg a___0_c_1_e : UInt<4>, clock with :",
+ "reg a_0_c_ : UInt<5>, clock with :",
+ "reg a__0 : UInt<6>, clock with :")
+
+ executeTest(input, expected)
+ }
+
+ it should "rename colliding nodes" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input clock : Clock
+ | reg x : { b : UInt<1>, c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2], clock
+ | node a = x
+ | node a_0_c_ = a[0].b
+ | node a__0 = a[1].c[0].d
+ """.stripMargin
+ val expected = Seq(
+ "node a___0_b = x_0_b",
+ "node a___1_c__1_e = x_1_c__1_e",
+ "node a___1_c_1_e = x_1_c_1_e"
+ )
+
+ executeTest(input, expected)
+ }
+
+
+ it should "rename DefRegister expressions: clock, reset, and init" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input clock : Clock[2]
+ | input clock_0 : Clock
+ | input reset : { a : UInt<1>, b : UInt<1>}
+ | input reset_a : UInt<1>
+ | input init : { a : UInt<4>, b : { c : UInt<4>, d : UInt<4>}[2], b_1_c : UInt<4>}[4]
+ | input init_0_a : UInt<4>
+ | reg foo : UInt<4>, clock[1], with :
+ | reset => (reset.a, init[3].b[1].d)
+ """.stripMargin
+ val expected = Seq(
+ "reg foo : UInt<4>, clock__1 with :",
+ "reset => (reset__a, init__3_b__1_d)"
+ )
+
+ executeTest(input, expected)
+ }
+
+ it should "rename ports before statements" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input data : { a : UInt<4>, b : UInt<4>}[2]
+ | node data_0_a = data[0].a
+ """.stripMargin
+ val expected = Seq(
+ "input data_0_a : UInt<4>",
+ "input data_0_b : UInt<4>",
+ "input data_1_a : UInt<4>",
+ "input data_1_b : UInt<4>",
+ "node data_0_a_ = data_0_a"
+ )
+
+ executeTest(input, expected)
+ }
+
+ it should "rename ports before statements (instance)" in {
+ val input =
+ """circuit Test :
+ | module Child:
+ | skip
+ | module Test :
+ | input data : { a : UInt<4>, b : UInt<4>}[2]
+ | inst data_0_a of Child
+ """.stripMargin
+ val expected = Seq(
+ "input data_0_a : UInt<4>",
+ "input data_0_b : UInt<4>",
+ "input data_1_a : UInt<4>",
+ "input data_1_b : UInt<4>",
+ "inst data_0_a_ of Child"
+ )
+
+ executeTest(input, expected)
+ }
+
+ it should "rename ports before statements (mem)" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input data : { a : UInt<4>, b : UInt<4>}[2]
+ | mem data_0_a :
+ | data-type => UInt<1>
+ | depth => 32
+ | read-latency => 0
+ | write-latency => 1
+ | reader => read
+ | writer => write
+ """.stripMargin
+ val expected = Seq(
+ "input data_0_a : UInt<4>",
+ "input data_0_b : UInt<4>",
+ "input data_1_a : UInt<4>",
+ "input data_1_b : UInt<4>",
+ "mem data_0_a_ :"
+ )
+
+ executeTest(input, expected)
+ }
+
+ it should "rename node expressions" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input data : { a : UInt<4>, b : UInt<4>[2]}
+ | input data_a : UInt<4>
+ | input data__b_1 : UInt<4>
+ | node foo = data.a
+ | node bar = data.b[1]
+ """.stripMargin
+ val expected = Seq(
+ "node foo = data___a",
+ "node bar = data___b_1")
+
+ executeTest(input, expected)
+ }
+
+ it should "rename both side of connects" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input a : { b : UInt<1>, flip c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2]
+ | output a_0_b : UInt<1>
+ | input a__0_c_ : { d : UInt<2>, e : UInt<3>}[2]
+ | a_0_b <= a[0].b
+ | a[0].c <- a__0_c_
+ """.stripMargin
+ val expected = Seq(
+ "a_0_b <= a___0_b",
+ "a___0_c__0_d <= a__0_c__0_d",
+ "a___0_c__0_e <= a__0_c__0_e",
+ "a___0_c__1_d <= a__0_c__1_d",
+ "a___0_c__1_e <= a__0_c__1_e"
+ )
+
+ executeTest(input, expected)
+ }
+
+ it should "rename deeply nested expressions" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input a : { b : UInt<1>, flip c : { d : UInt<2>, e : UInt<3>}[2], c_1_e : UInt<4>}[2]
+ | output a_0_b : UInt<1>
+ | input a__0_c_ : { d : UInt<2>, e : UInt<3>}[2]
+ | a_0_b <= mux(a[UInt(0)].c_1_e, or(a[or(a[0].b, a[1].b)].b, xorr(a[0].c_1_e)), orr(cat(a__0_c_[0].e, a[1].c_1_e)))
+ """.stripMargin
+ val expected = Seq(
+ "a_0_b <= mux(a___0_c_1_e, or(_a_or_b, xorr(a___0_c_1_e)), orr(cat(a__0_c__0_e, a___1_c_1_e)))"
+ )
+
+ executeTest(input, expected)
+ }
+
+ it should "rename memories" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input clock : Clock
+ | mem mem :
+ | data-type => { a : UInt<8>, b : UInt<8>[2]}[2]
+ | depth => 32
+ | read-latency => 0
+ | write-latency => 1
+ | reader => read
+ | writer => write
+ | node mem_0_b = mem.read.data[0].b
+ |
+ | mem.read.addr is invalid
+ | mem.read.en <= UInt(1)
+ | mem.read.clk <= clock
+ | mem.write.data is invalid
+ | mem.write.mask is invalid
+ | mem.write.addr is invalid
+ | mem.write.en <= UInt(0)
+ | mem.write.clk <= clock
+ """.stripMargin
+ val expected = Seq(
+ "mem mem__0_b_0 :",
+ "node mem_0_b_0 = mem__0_b_0.read.data",
+ "node mem_0_b_1 = mem__0_b_1.read.data",
+ "mem__0_b_0.read.addr is invalid")
+
+ executeTest(input, expected)
+ }
+
+ it should "rename aggregate typed memories" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input clock : Clock
+ | mem mem :
+ | data-type => { a : UInt<8>, b : UInt<8>[2], b_0 : UInt<8> }
+ | depth => 32
+ | read-latency => 0
+ | write-latency => 1
+ | reader => read
+ | writer => write
+ | node x = mem.read.data.b[0]
+ |
+ | mem.read.addr is invalid
+ | mem.read.en <= UInt(1)
+ | mem.read.clk <= clock
+ | mem.write.data is invalid
+ | mem.write.mask is invalid
+ | mem.write.addr is invalid
+ | mem.write.en <= UInt(0)
+ | mem.write.clk <= clock
+ """.stripMargin
+ val expected = Seq(
+ "mem mem_a :",
+ "mem mem_b__0 :",
+ "mem mem_b__1 :",
+ "mem mem_b_0 :",
+ "node x = mem_b__0.read.data")
+
+ executeTest(input, expected)
+ }
+
+ it should "rename instances and their ports" in {
+ val input =
+ """circuit Test :
+ | module Other :
+ | input a : { b : UInt<4>, c : UInt<4> }
+ | output a_b : UInt<4>
+ | a_b <= a.b
+ |
+ | module Test :
+ | node x = UInt(6)
+ | inst mod of Other
+ | mod.a.b <= x
+ | mod.a.c <= x
+ | node mod_a_b = mod.a_b
+ """.stripMargin
+ val expected = Seq(
+ "inst mod_ of Other",
+ "mod_.a__b <= x",
+ "mod_.a__c <= x",
+ "node mod_a_b = mod_.a_b")
+
+ executeTest(input, expected)
+ }
+
+ it should "quickly rename deep bundles" in {
+ val depth = 500
+ // We previously used a fixed time to determine if this test passed or failed.
+ // This test would pass under normal conditions, but would fail during coverage tests.
+ // Instead of using a fixed time, we run the test once (with a rename depth of 1), and record the time,
+ // then run it again with a depth of 500 and verify that the difference is below a fixed threshold.
+ // Additionally, since executions times vary significantly under coverage testing, we check a global
+ // to see if timing measurements are accurate enough to enforce the timing checks.
+ val threshold = depth * 2.0
+ // As of 20-Feb-2019, this still fails occasionally:
+ // [info] 9038.99351 was not less than 6113.865 (UniquifySpec.scala:317)
+ // Run the "quick" test three times and choose the longest time as the basis.
+ val nCalibrationRuns = 3
+ def mkType(i: Int): String = {
+ if(i == 0) "UInt<8>" else s"{x: ${mkType(i - 1)}}"
+ }
+ val timesMs = (
+ for (depth <- (List.fill(nCalibrationRuns)(1) :+ depth)) yield {
+ val input = s"""circuit Test:
+ | module Test :
+ | input in: ${mkType(depth)}
+ | output out: ${mkType(depth)}
+ | out <= in
+ |""".stripMargin
+ val (ms, _) = Utils.time(compileToVerilog(input))
+ ms
+ }
+ ).toArray
+ // The baseMs will be the maximum of the first calibration runs
+ val baseMs = timesMs.slice(0, nCalibrationRuns - 1).max
+ val renameMs = timesMs(nCalibrationRuns)
+ if (TestOptions.accurateTiming)
+ renameMs shouldBe < (baseMs * threshold)
+ }
+}
+
diff --git a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
index f19d52ae..802596c5 100644
--- a/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
+++ b/src/test/scala/firrtlTests/LoweringCompilersSpec.scala
@@ -147,12 +147,8 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.Resolved, Forms.WorkingIR)
val patches = Seq(
- // ResolveFlows no longer depends in Uniquify (ResolveKinds and InferTypes are fixup passes that get moved as well)
+ // Uniquify is now part of [[firrtl.passes.LowerTypes]]
Del(5), Del(6), Del(7),
- // Uniquify now is run before InferBinary Points which claims to need Uniquify
- Add(9, Seq(Dependency(firrtl.passes.Uniquify),
- Dependency(firrtl.passes.ResolveKinds),
- Dependency(firrtl.passes.InferTypes))),
Add(14, Seq(Dependency.fromTransform(firrtl.passes.CheckTypes)))
)
compare(legacyTransforms(new ResolveAndCheck), tm, patches)
@@ -165,13 +161,12 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
val patches = Seq(
Add(4, Seq(Dependency(firrtl.passes.ResolveFlows))),
Add(5, Seq(Dependency(firrtl.passes.ResolveKinds))),
- Add(6, Seq(Dependency(firrtl.passes.ResolveKinds),
- Dependency(firrtl.passes.InferTypes),
- Dependency(firrtl.passes.ResolveFlows))),
+ // Uniquify is now part of [[firrtl.passes.LowerTypes]]
+ Del(6),
+ Add(6, Seq(Dependency(firrtl.passes.ResolveFlows))),
Del(7),
Del(8),
- Add(7, Seq(Dependency(firrtl.passes.ResolveKinds),
- Dependency[firrtl.passes.ExpandWhensAndCheck])),
+ Add(7, Seq(Dependency[firrtl.passes.ExpandWhensAndCheck])),
Del(11),
Del(12),
Del(13),
@@ -191,6 +186,8 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
it should "replicate the old order" in {
val tm = new TransformManager(Forms.LowForm, Forms.MidForm)
val patches = Seq(
+ // Uniquify is now part of [[firrtl.passes.LowerTypes]]
+ Del(2), Del(3), Del(5),
// RemoveWires now visibly invalidates ResolveKinds
Add(11, Seq(Dependency(firrtl.passes.ResolveKinds)))
)
@@ -298,7 +295,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
compare(expected, tm)
}
- it should "work for Mid -> High" in {
+ it should "work for Mid -> High" ignore {
val expected =
new TransformManager(Forms.MidForm).flattenedTransformOrder ++
Some(new Transforms.MidToHigh) ++
@@ -307,7 +304,7 @@ class LoweringCompilersSpec extends AnyFlatSpec with Matchers {
compare(expected, tm)
}
- it should "work for Mid -> Chirrtl" in {
+ it should "work for Mid -> Chirrtl" ignore {
val expected =
new TransformManager(Forms.MidForm).flattenedTransformOrder ++
Some(new Transforms.MidToChirrtl) ++
diff --git a/src/test/scala/firrtlTests/MemoryInitSpec.scala b/src/test/scala/firrtlTests/MemoryInitSpec.scala
index 0826746b..5598e58b 100644
--- a/src/test/scala/firrtlTests/MemoryInitSpec.scala
+++ b/src/test/scala/firrtlTests/MemoryInitSpec.scala
@@ -129,7 +129,7 @@ class MemInitSpec extends FirrtlFlatSpec {
val annos = Seq(MemoryScalarInitAnnotation(mRef, 0))
compile(annos, "UInt<32>[2]")
}
- assert(caught.getMessage.endsWith("[module MemTest] Cannot initialize memory of non ground type UInt<32>[2]"))
+ assert(caught.getMessage.endsWith("Cannot initialize memory m of non ground type UInt<32>[2]"))
}
"MemoryScalarInitAnnotation on Memory with Bundle type" should "fail" in {
@@ -137,7 +137,7 @@ class MemInitSpec extends FirrtlFlatSpec {
val annos = Seq(MemoryScalarInitAnnotation(mRef, 0))
compile(annos, "{real: SInt<10>, imag: SInt<10>}")
}
- assert(caught.getMessage.endsWith("[module MemTest] Cannot initialize memory of non ground type { real : SInt<10>, imag : SInt<10>}"))
+ assert(caught.getMessage.endsWith("Cannot initialize memory m of non ground type { real : SInt<10>, imag : SInt<10>}"))
}
private def jsonAnno(name: String, suffix: String): String =
diff --git a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
index f847fb6c..fdb129a1 100644
--- a/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
+++ b/src/test/scala/firrtlTests/transforms/GroupComponentsSpec.scala
@@ -364,9 +364,9 @@ class GroupComponentsSpec extends MiddleTransformSpec {
| out <= add(in, wrapper.other_out)
| module Wrapper :
| output other_out: UInt<16>
- | inst other_ of Other
- | other_out <= other_.out
- | other_.in is invalid
+ | inst other of Other
+ | other_out <= other.out
+ | other.in is invalid
| module Other:
| input in: UInt<16>
| output out: UInt<16>