aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Izraelevitz2017-05-10 11:23:18 -0700
committerGitHub2017-05-10 11:23:18 -0700
commit8b8eb4eac5b353d4d632065c78faf6a706d6aae8 (patch)
tree39e2d9344166b61b376df9d3cd15a4787bcd01f4
parentaf222c1737fa72fce964190876346bdb7ff220cd (diff)
Update rename2 (#478)
* Added pass name to debug logger * Addresses #459. Rewords transform annotations API. Now, any annotation not propagated by a transform is considered deleted. A new DeletedAnnotation is added in place of it. * Added more stylized debugging style * WIP: make pass transform * WIP: All tests pass, need to pull master * Cleaned up PR * Added rename updates to all core transforms * Added more rename tests, and bugfixes * Renaming tracks non-leaf subfields E.g. given: wire x: {a: UInt<1>, b: UInt<1>[2]} Annotating x.b will eventually annotate x_b_0 and x_b_1 * Bugfix instance rename lowering broken * Address review comments * Remove check for seqTransform, UnknownForm too restrictive check
-rw-r--r--src/main/scala/firrtl/Compiler.scala50
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala2
-rw-r--r--src/main/scala/firrtl/Utils.scala2
-rw-r--r--src/main/scala/firrtl/passes/DeadCodeElimination.scala25
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala128
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala51
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala44
-rw-r--r--src/main/scala/firrtl/passes/ZeroWidth.scala52
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala2
-rw-r--r--src/test/scala/firrtlTests/AnnotationTests.scala293
-rw-r--r--src/test/scala/firrtlTests/ChirrtlSpec.scala13
-rw-r--r--src/test/scala/firrtlTests/ExpandWhensSpec.scala16
-rw-r--r--src/test/scala/firrtlTests/LowerTypesSpec.scala9
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala1
-rw-r--r--src/test/scala/firrtlTests/UniquifySpec.scala9
-rw-r--r--src/test/scala/firrtlTests/ZeroWidthTests.scala21
16 files changed, 619 insertions, 99 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index 6e5cadcd..6c3911d6 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -2,9 +2,10 @@
package firrtl
-import logger.LazyLogging
+import logger._
import java.io.Writer
import annotations._
+import scala.collection.mutable
import firrtl.ir.Circuit
import passes.Pass
@@ -14,7 +15,47 @@ import Utils.throwInternalError
* RenameMap maps old names to modified names. Generated by transformations
* that modify names
*/
-case class RenameMap(map: Map[Named, Seq[Named]] = Map[Named, Seq[Named]]())
+object RenameMap {
+ def apply(map: Map[Named, Seq[Named]]) = {
+ val rm = new RenameMap
+ rm.addMap(map)
+ rm
+ }
+ def apply() = new RenameMap
+}
+class RenameMap {
+ val renameMap = new mutable.HashMap[Named, Seq[Named]]()
+ private var circuitName: String = ""
+ private var moduleName: String = ""
+ def setModule(s: String) =
+ moduleName = s
+ def setCircuit(s: String) =
+ circuitName = s
+ def rename(from: String, to: String): Unit = rename(from, Seq(to))
+ def rename(from: String, tos: Seq[String]): Unit = {
+ val fromName = ComponentName(from, ModuleName(moduleName, CircuitName(circuitName)))
+ val tosName = tos map { to =>
+ ComponentName(to, ModuleName(moduleName, CircuitName(circuitName)))
+ }
+ rename(fromName, tosName)
+ }
+ def rename(from: Named, to: Named): Unit = rename(from, Seq(to))
+ def rename(from: Named, tos: Seq[Named]): Unit = (from, tos) match {
+ case (x, Seq(y)) if x == y =>
+ case _ =>
+ renameMap(from) = renameMap.getOrElse(from, Seq.empty) ++ tos
+ }
+ def delete(names: Seq[String]): Unit = names.foreach(delete(_))
+ def delete(name: String): Unit =
+ delete(ComponentName(name, ModuleName(moduleName, CircuitName(circuitName))))
+ def delete(name: Named): Unit =
+ renameMap(name) = Seq.empty
+ def addMap(map: Map[Named, Seq[Named]]) =
+ renameMap ++= map
+ def serialize: String = renameMap.map { case (k, v) =>
+ k.serialize + "=>" + v.map(_.serialize).mkString(", ")
+ }.mkString("\n")
+}
/**
* Container of all annotations for a Firrtl compiler.
@@ -172,7 +213,6 @@ abstract class Transform extends LazyLogging {
}
logger.trace(s"Circuit:\n${result.circuit.serialize}")
logger.info(s"======== Finished Transform $name ========\n")
-
CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations)), None)
}
@@ -200,7 +240,7 @@ abstract class Transform extends LazyLogging {
}
// For each annotation, rename all annotations.
- val renames = renameOpt.getOrElse(RenameMap()).map
+ val renames = renameOpt.getOrElse(RenameMap()).renameMap
for {
anno <- newAnnotations.toSeq
newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target)))
@@ -217,8 +257,10 @@ trait SeqTransformBased {
/** For transformations that are simply a sequence of transforms */
abstract class SeqTransform extends Transform with SeqTransformBased {
def execute(state: CircuitState): CircuitState = {
+ /*
require(state.form <= inputForm,
s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
+ */
val ret = runTransforms(state)
CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames)
}
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index b9042781..84f237a3 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -6,7 +6,6 @@ sealed abstract class CoreTransform extends SeqTransform
/** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting
* circuit has only IR nodes, not WIR.
- * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
*/
class ChirrtlToHighFirrtl extends CoreTransform {
def inputForm = ChirrtlForm
@@ -75,7 +74,6 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
/** Expands all aggregate types into many ground-typed components. Must
* accept a well-formed graph of only middle Firrtl features.
* Operates on working IR nodes.
- * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
*/
class MiddleFirrtlToLowFirrtl extends CoreTransform {
def inputForm = MidForm
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 76156d2d..586fe1e7 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -541,7 +541,7 @@ object Utils extends LazyLogging {
}
/** Adds a root reference to some SubField/SubIndex chain */
- def mergeRef(root: WRef, body: Expression): Expression = body match {
+ def mergeRef(root: Expression, body: Expression): Expression = body match {
case e: WRef =>
WSubField(root, e.name, e.tpe, e.gender)
case e: WSubIndex =>
diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala
index 9f249f35..54ac76fe 100644
--- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala
@@ -9,8 +9,10 @@ import firrtl.Mappers._
import annotation.tailrec
-object DeadCodeElimination extends Pass {
- private def dceOnce(s: Statement): (Statement, Long) = {
+object DeadCodeElimination extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
+ private def dceOnce(renames: RenameMap)(s: Statement): (Statement, Long) = {
val referenced = collection.mutable.HashSet[String]()
var nEliminated = 0L
@@ -28,6 +30,7 @@ object DeadCodeElimination extends Pass {
if (referenced(name)) x
else {
nEliminated += 1
+ renames.delete(name)
EmptyStmt
}
@@ -43,16 +46,22 @@ object DeadCodeElimination extends Pass {
}
@tailrec
- private def dce(s: Statement): Statement = {
- val (res, n) = dceOnce(s)
- if (n > 0) dce(res) else res
+ private def dce(renames: RenameMap)(s: Statement): Statement = {
+ val (res, n) = dceOnce(renames)(s)
+ if (n > 0) dce(renames)(res) else res
}
- def run(c: Circuit): Circuit = {
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => Module(m.info, m.name, m.ports, dce(m.body))
+ case m: Module =>
+ renames.setModule(m.name)
+ Module(m.info, m.name, m.ports, dce(renames)(m.body))
}
- Circuit(c.info, modulesx, c.main)
+ val result = Circuit(c.info, modulesx, c.main)
+ CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index 5826f56e..b48ab338 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -2,9 +2,11 @@
package firrtl.passes
+import scala.collection.mutable
import firrtl._
import firrtl.ir._
import firrtl.Utils._
+import MemPortUtils.memType
import firrtl.Mappers._
/** Removes all aggregate types from a [[firrtl.ir.Circuit]]
@@ -20,7 +22,10 @@ import firrtl.Mappers._
* wire foo_b : UInt<16>
* }}}
*/
-object LowerTypes extends Pass {
+object LowerTypes extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
+
/** Delimiter used in lowering names */
val delim = "_"
/** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name
@@ -33,7 +38,49 @@ object LowerTypes extends Pass {
case e: WSubIndex => s"${loweredName(e.exp)}$delim${e.value}"
}
def loweredName(s: Seq[String]): String = s mkString delim
+ def renameExps(renames: RenameMap, n: String, t: Type, root: String): Seq[String] =
+ renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), root)
+ def renameExps(renames: RenameMap, n: String, t: Type): Seq[String] =
+ renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), "")
+ def renameExps(renames: RenameMap, e: Expression, root: String): Seq[String] = e.tpe match {
+ case (_: GroundType) =>
+ val name = root + loweredName(e)
+ renames.rename(e.serialize, name)
+ Seq(name)
+ case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) =>
+ val subNames = renameExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), root)
+ renames.rename(e.serialize, subNames)
+ names ++ subNames
+ }
+ case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) =>
+ val subNames = renameExps(renames, WSubIndex(e, i, t.tpe,gender(e)), root)
+ renames.rename(e.serialize, subNames)
+ names ++ subNames
+ }
+ }
+ private def renameMemExps(renames: RenameMap, e: Expression, portAndField: Expression): Seq[String] = e.tpe match {
+ case (_: GroundType) =>
+ val (mem, tail) = splitRef(e)
+ val loRef = mergeRef(WRef(loweredName(e)), portAndField)
+ val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
+ renames.rename(hiRef.serialize, loRef.serialize)
+ Seq(loRef.serialize)
+ case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) =>
+ val subNames = renameMemExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), portAndField)
+ val (mem, tail) = splitRef(e)
+ val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
+ renames.rename(hiRef.serialize, subNames)
+ names ++ subNames
+ }
+ case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) =>
+ val subNames = renameMemExps(renames, WSubIndex(e, i, t.tpe,gender(e)), portAndField)
+ val (mem, tail) = splitRef(e)
+ val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
+ renames.rename(hiRef.serialize, subNames)
+ names ++ subNames
+ }
+ }
private case class LowerTypesException(msg: String) extends FIRRTLException(msg)
private def error(msg: String)(info: Info, mname: String) =
throw LowerTypesException(s"$info: [module $mname] $msg")
@@ -111,34 +158,43 @@ object LowerTypes extends Pass {
case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname)
case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
-
def lowerTypesStmt(memDataTypeMap: MemDataTypeMap,
- minfo: Info, mname: String)(s: Statement): Statement = {
+ minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = {
val info = get_info(s) match {case NoInfo => minfo case x => x}
- s map lowerTypesStmt(memDataTypeMap, info, mname) match {
+ s map lowerTypesStmt(memDataTypeMap, info, mname, renames) match {
case s: DefWire => s.tpe match {
case _: GroundType => s
- case _ => Block(create_exps(s.name, s.tpe) map (
- e => DefWire(s.info, loweredName(e), e.tpe)))
+ case _ =>
+ val exps = create_exps(s.name, s.tpe)
+ val names = exps map loweredName
+ renameExps(renames, s.name, s.tpe)
+ Block((exps zip names) map { case (e, n) =>
+ DefWire(s.info, n, e.tpe)
+ })
}
case sx: DefRegister => sx.tpe match {
case _: GroundType => sx map lowerTypesExp(memDataTypeMap, info, mname)
case _ =>
val es = create_exps(sx.name, sx.tpe)
+ val names = es map loweredName
+ renameExps(renames, sx.name, sx.tpe)
val inits = create_exps(sx.init) map lowerTypesExp(memDataTypeMap, info, mname)
val clock = lowerTypesExp(memDataTypeMap, info, mname)(sx.clock)
val reset = lowerTypesExp(memDataTypeMap, info, mname)(sx.reset)
- Block(es zip inits map { case (e, i) =>
- DefRegister(sx.info, loweredName(e), e.tpe, clock, reset, i)
+ Block((es zip names) zip inits map { case ((e, n), i) =>
+ DefRegister(sx.info, n, e.tpe, clock, reset, i)
})
}
// Could instead just save the type of each Module as it gets processed
case sx: WDefInstance => sx.tpe match {
case t: BundleType =>
- val fieldsx = t.fields flatMap (f =>
- create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map (
+ val fieldsx = t.fields flatMap { f =>
+ renameExps(renames, f.name, sx.tpe, s"${sx.name}.")
+ create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map { e =>
// Flip because inst genders are reversed from Module type
- e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)))
+ Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
+ }
+ }
WDefInstance(sx.info, sx.name, sx.module, BundleType(fieldsx))
case _ => error("WDefInstance type should be Bundle!")(info, mname)
}
@@ -146,8 +202,30 @@ object LowerTypes extends Pass {
memDataTypeMap(sx.name) = sx.dataType
sx.dataType match {
case _: GroundType => sx
- case _ => Block(create_exps(sx.name, sx.dataType) map (e =>
- sx copy (name = loweredName(e), dataType = e.tpe)))
+ case _ =>
+ // Rename ports
+ val seen: mutable.Set[String] = mutable.Set[String]()
+ create_exps(sx.name, memType(sx)) foreach { e =>
+ val (mem, port, field, tail) = splitMemRef(e)
+ if (!seen.contains(field.name)) {
+ seen += field.name
+ val d = WRef(mem.name, sx.dataType)
+ tail match {
+ case None =>
+ create_exps(mem.name, sx.dataType) foreach { x =>
+ renames.rename(e.serialize, s"${loweredName(x)}.${port.serialize}.${field.serialize}")
+ }
+ case Some(_) =>
+ renameMemExps(renames, d, mergeRef(port, field))
+ }
+ }
+ }
+ Block(create_exps(sx.name, sx.dataType) map {e =>
+ val newName = loweredName(e)
+ // Rename mems
+ renames.rename(sx.name, newName)
+ sx copy (name = newName, dataType = e.tpe)
+ })
}
// wire foo : { a , b }
// node x = foo
@@ -159,7 +237,10 @@ object LowerTypes extends Pass {
case sx: DefNode =>
val names = create_exps(sx.name, sx.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname)
val exps = create_exps(sx.value) map lowerTypesExp(memDataTypeMap, info, mname)
- Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) })
+ renameExps(renames, sx.name, sx.value.tpe)
+ Block(names zip exps map { case (n, e) =>
+ DefNode(info, loweredName(n), e)
+ })
case sx: IsInvalid => kind(sx.expr) match {
case MemKind =>
Block(lowerTypesMemExp(memDataTypeMap, info, mname)(sx.expr) map (IsInvalid(info, _)))
@@ -176,21 +257,32 @@ object LowerTypes extends Pass {
}
}
- def lowerTypes(m: DefModule): DefModule = {
+ def lowerTypes(renames: RenameMap)(m: DefModule): DefModule = {
val memDataTypeMap = new MemDataTypeMap
+ renames.setModule(m.name)
// Lower Ports
val portsx = m.ports flatMap { p =>
val exps = create_exps(WRef(p.name, p.tpe, PortKind, to_gender(p.direction)))
- exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe))
+ val names = exps map loweredName
+ renameExps(renames, p.name, p.tpe)
+ (exps zip names) map { case (e, n) =>
+ Port(p.info, n, to_dir(gender(e)), e.tpe)
+ }
}
m match {
case m: ExtModule =>
m copy (ports = portsx)
case m: Module =>
- m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name)
+ m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name, renames)
}
}
- def run(c: Circuit): Circuit = c copy (modules = c.modules map lowerTypes)
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+ val result = c copy (modules = c.modules map lowerTypes(renames))
+ CircuitState(result, outputForm, state.annotations, Some(renames))
+ }
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index b072dfa0..c841dc32 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -14,7 +14,9 @@ case class MPort(name: String, clk: Expression)
case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort])
case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean)
-object RemoveCHIRRTL extends Pass {
+object RemoveCHIRRTL extends Transform {
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
val ut = UnknownType
type MPortMap = collection.mutable.LinkedHashMap[String, MPorts]
type SeqMemSet = collection.mutable.HashSet[String]
@@ -22,6 +24,15 @@ object RemoveCHIRRTL extends Pass {
type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef]
type AddrMap = collection.mutable.HashMap[String, Expression]
+ def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match {
+ case _: GroundType => Seq(ex)
+ case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex)
+ case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex)
+ case UnknownType => Seq(ex)
+ }
+
def create_exps(e: Expression): Seq[Expression] = e match {
case ex: Mux =>
val e1s = create_exps(ex.tval)
@@ -59,7 +70,7 @@ object RemoveCHIRRTL extends Pass {
}
def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
- refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match {
+ refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size)))
@@ -104,11 +115,25 @@ object RemoveCHIRRTL extends Pass {
addrs += "addr"
clks += "clk"
ens += "en"
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata")
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe))
+ val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe))
+ ((es zip rs) zip ws) map {
+ case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
+ }
case MWrite =>
refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false)
addrs += "addr"
clks += "clk"
ens += "en"
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
+ (es zip ws) map {
+ case (e, w) => renames.rename(e.serialize, w.serialize)
+ }
case MRead =>
refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false)
addrs += "addr"
@@ -118,13 +143,19 @@ object RemoveCHIRRTL extends Pass {
raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut)
case _ => ens += "en"
}
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
+ (es zip rs) map {
+ case (e, r) => renames.rename(e.serialize, r.serialize)
+ }
case MInfer => // do nothing if it's not being used
}
Block(
(addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++
(clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++
(ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one))))
- case sx => sx map collect_refs(mports, smems, types, refs, raddrs)
+ case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
}
def get_mask(refs: DataRefMap)(e: Expression): Expression =
@@ -213,17 +244,23 @@ object RemoveCHIRRTL extends Pass {
}
}
- def remove_chirrtl_m(m: DefModule): DefModule = {
+ def remove_chirrtl_m(renames: RenameMap)(m: DefModule): DefModule = {
val mports = new MPortMap
val smems = new SeqMemSet
val types = new MPortTypeMap
val refs = new DataRefMap
val raddrs = new AddrMap
+ renames.setModule(m.name)
(m map collect_smems_and_mports(mports, smems)
- map collect_refs(mports, smems, types, refs, raddrs)
+ map collect_refs(mports, smems, types, refs, raddrs, renames)
map remove_chirrtl_s(refs, raddrs))
}
- def run(c: Circuit): Circuit =
- c copy (modules = c.modules map remove_chirrtl_m)
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+ val result = c copy (modules = c.modules map remove_chirrtl_m(renames))
+ CircuitState(result, outputForm, state.annotations, Some(renames))
+ }
}
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index deddb93e..61bd68d0 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -31,7 +31,9 @@ import MemPortUtils.memType
* there WOULD be collisions in references a[0] and a_0 so we still have
* to rename a
*/
-object Uniquify extends Pass {
+object Uniquify extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
private case class UniquifyException(msg: String) extends FIRRTLException(msg)
private def error(msg: String)(implicit sinfo: Info, mname: String) =
throw new UniquifyException(s"$sinfo: [module $mname] $msg")
@@ -224,7 +226,10 @@ object Uniquify extends Pass {
}
// Everything wrapped in run so that it's thread safe
- def run(c: Circuit): Circuit = {
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
// Debug state
implicit var mname: String = ""
implicit var sinfo: Info = NoInfo
@@ -232,7 +237,8 @@ object Uniquify extends Pass {
val portNameMap = collection.mutable.HashMap[String, Map[String, NameMapNode]]()
val portTypeMap = collection.mutable.HashMap[String, Type]()
- def uniquifyModule(m: DefModule): DefModule = {
+ def uniquifyModule(renames: RenameMap)(m: DefModule): DefModule = {
+ renames.setModule(m.name)
val namespace = collection.mutable.HashSet[String]()
val nameMap = collection.mutable.HashMap[String, NameMapNode]()
@@ -251,7 +257,11 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- DefWire(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts))
+ val newType = uniquifyNamesType(sx.tpe, node.elts)
+ (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ DefWire(sx.info, node.name, newType)
} else {
sx
}
@@ -259,8 +269,11 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- DefRegister(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts),
- sx.clock, sx.reset, sx.init)
+ val newType = uniquifyNamesType(sx.tpe, node.elts)
+ (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init)
} else {
sx
}
@@ -268,7 +281,11 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- WDefInstance(sx.info, node.name, sx.module, sx.tpe)
+ val newType = portTypeMap(m.name)
+ (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ WDefInstance(sx.info, node.name, sx.module, newType)
} else {
sx
}
@@ -280,6 +297,9 @@ object Uniquify extends Pass {
val mem = sx.copy(name = node.name, dataType = dataType)
// Create new mapping to handle references to memory data fields
val uniqueMemMap = createNameMapping(memType(sx), memType(mem))
+ (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap)
mem
} else {
@@ -289,6 +309,9 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
+ (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
DefNode(sx.info, node.name, sx.value)
} else {
sx
@@ -320,7 +343,8 @@ object Uniquify extends Pass {
}
}
- def uniquifyPorts(m: DefModule): DefModule = {
+ def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = {
+ renames.setModule(m.name)
def uniquifyPorts(ports: Seq[Port]): Seq[Port] = {
val portsType = BundleType(ports map {
case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
@@ -331,6 +355,7 @@ object Uniquify extends Pass {
portTypeMap += (m.name -> uniquePortsType)
ports zip uniquePortsType.fields map { case (p, f) =>
+ renames.rename(p.name, f.name)
Port(p.info, f.name, p.direction, f.tpe)
}
}
@@ -344,7 +369,8 @@ object Uniquify extends Pass {
}
sinfo = c.info
- Circuit(c.info, c.modules map uniquifyPorts map uniquifyModule, c.main)
+ val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main)
+ CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala
index a2ec9935..2472a0e5 100644
--- a/src/main/scala/firrtl/passes/ZeroWidth.scala
+++ b/src/main/scala/firrtl/passes/ZeroWidth.scala
@@ -10,8 +10,24 @@ import firrtl.Mappers._
import firrtl.Utils.throwInternalError
-object ZeroWidth extends Pass {
+object ZeroWidth extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
private val ZERO = BigInt(0)
+ private def getRemoved(x: IsDeclaration): Seq[String] = {
+ var removedNames: Seq[String] = Seq.empty
+ def onType(name: String)(t: Type): Type = {
+ removedNames = Utils.create_exps(name, t) map {e => (e, e.tpe)} collect {
+ case (e, GroundType(IntWidth(ZERO))) => e.serialize
+ }
+ t
+ }
+ x match {
+ case s: Statement => s map onType(s.name)
+ case Port(_, name, _, t) => onType(name)(t)
+ }
+ removedNames
+ }
private def removeZero(t: Type): Option[Type] = t match {
case GroundType(IntWidth(ZERO)) => None
case BundleType(fields) =>
@@ -34,35 +50,53 @@ object ZeroWidth extends Pass {
def replaceType(x: Type): Type = t
(e map replaceType) map onExp
}
- private def onStmt(s: Statement): Statement = s match {
+ private def onStmt(renames: RenameMap)(s: Statement): Statement = s match {
case (_: DefWire| _: DefRegister| _: DefMemory) =>
+ // List all removed expression names, and delete them from renames
+ renames.delete(getRemoved(s.asInstanceOf[IsDeclaration]))
+ // Create new types without zero-width wires
var removed = false
def applyRemoveZero(t: Type): Type = removeZero(t) match {
case None => removed = true; t
case Some(tx) => tx
}
val sxx = (s map onExp) map applyRemoveZero
+ // Return new declaration
if(removed) EmptyStmt else sxx
case Connect(info, loc, exp) => removeZero(loc.tpe) match {
case None => EmptyStmt
case Some(t) => Connect(info, loc, onExp(exp))
}
+ case IsInvalid(info, exp) => removeZero(exp.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => IsInvalid(info, onExp(exp))
+ }
case DefNode(info, name, value) => removeZero(value.tpe) match {
case None => EmptyStmt
case Some(t) => DefNode(info, name, onExp(value))
}
- case sx => sx map onStmt
+ case sx => sx map onStmt(renames)
}
- private def onModule(m: DefModule): DefModule = {
- val ports = m.ports map (p => (p, removeZero(p.tpe))) collect {
- case (Port(info, name, dir, _), Some(t)) => Port(info, name, dir, t)
+ private def onModule(renames: RenameMap)(m: DefModule): DefModule = {
+ renames.setModule(m.name)
+ // For each port, record deleted subcomponents
+ m.ports.foreach{p => renames.delete(getRemoved(p))}
+ val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap {
+ case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t))
+ case (Port(_, name, _, _), None) =>
+ renames.delete(name)
+ Nil
}
m match {
case ext: ExtModule => ext.copy(ports = ports)
- case in: Module => in.copy(ports = ports, body = onStmt(in.body))
+ case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body))
}
}
- def run(c: Circuit): Circuit = {
- InferTypes.run(c.copy(modules = c.modules map onModule))
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+ val result = InferTypes.run(c.copy(modules = c.modules map onModule(renames)))
+ CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index caaf430b..8cbf9da7 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -131,7 +131,7 @@ class ReplSeqMem extends Transform {
new SimpleMidTransform(RemoveEmpty),
new SimpleMidTransform(CheckInitialization),
new SimpleMidTransform(InferTypes),
- new SimpleMidTransform(Uniquify),
+ Uniquify,
new SimpleMidTransform(ResolveKinds),
new SimpleMidTransform(ResolveGenders))
diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala
index 81d982e4..81c394c1 100644
--- a/src/test/scala/firrtlTests/AnnotationTests.scala
+++ b/src/test/scala/firrtlTests/AnnotationTests.scala
@@ -11,6 +11,7 @@ import firrtl.passes.InlineAnnotation
import firrtl.passes.memlib.PinAnnotation
import net.jcazevedo.moultingyaml._
import org.scalatest.Matchers
+import logger._
/**
* An example methodology for testing Firrtl annotations.
@@ -25,8 +26,8 @@ trait AnnotationSpec extends LowTransformSpec {
compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty)
}
}
- def execute(annotations: AnnotationMap, input: String, check: Annotation): Unit = {
- val cr = compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty)
+ def execute(aMap: Option[AnnotationMap], input: String, check: Annotation): Unit = {
+ val cr = compile(CircuitState(parse(input), ChirrtlForm, aMap), Seq.empty)
cr.annotations.get.annotations should contain (check)
}
}
@@ -40,20 +41,19 @@ trait AnnotationSpec extends LowTransformSpec {
* Unstable, Fickle, and Insistent can be tested.
*/
class AnnotationTests extends AnnotationSpec with Matchers {
- def getAMap (a: Annotation): AnnotationMap = AnnotationMap(Seq(a))
- val input: String =
- """circuit Top :
- | module Top :
- | input a : UInt<1>[2]
- | input b : UInt<1>
- | node c = b""".stripMargin
- val mName = ModuleName("Top", CircuitName("Top"))
- val aName = ComponentName("a", mName)
- val bName = ComponentName("b", mName)
- val cName = ComponentName("c", mName)
+ def getAMap(a: Annotation): Option[AnnotationMap] = Some(AnnotationMap(Seq(a)))
+ def getAMap(as: Seq[Annotation]): Option[AnnotationMap] = Some(AnnotationMap(as))
+ def anno(s: String, value: String ="this is a value"): Annotation =
+ Annotation(ComponentName(s, ModuleName("Top", CircuitName("Top"))), classOf[Transform], value)
"Loose and Sticky annotation on a node" should "pass through" in {
- val ta = Annotation(cName, classOf[Transform], "")
+ val input: String =
+ """circuit Top :
+ | module Top :
+ | input a : UInt<1>[2]
+ | input b : UInt<1>
+ | node c = b""".stripMargin
+ val ta = anno("c", "")
execute(getAMap(ta), input, ta)
}
@@ -134,11 +134,10 @@ class AnnotationTests extends AnnotationSpec with Matchers {
val outputForm = LowForm
def execute(state: CircuitState) = state.copy(annotations = None)
}
- val anno = InlineAnnotation(CircuitName("Top"))
- val annoOpt = Some(AnnotationMap(Seq(anno)))
- val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annoOpt), Seq(new DeletingTransform))
+ val inlineAnn = InlineAnnotation(CircuitName("Top"))
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(inlineAnn)), Seq(new DeletingTransform))
result.annotations.get.annotations.head should matchPattern {
- case DeletedAnnotation(x, anno) =>
+ case DeletedAnnotation(x, inlineAnn) =>
}
val exception = (intercept[FIRRTLException] {
result.getEmittedCircuit
@@ -146,4 +145,262 @@ class AnnotationTests extends AnnotationSpec with Matchers {
val deleted = result.deletedAnnotations
exception.str should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted")
}
+
+ "Renaming" should "propagate in Lowering of memories" in {
+ val compiler = new VerilogCompiler
+ // Uncomment to help debugging failing tests
+ // Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug))
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input in: UInt<3>
+ | mem m:
+ | data-type => {a: UInt<4>, b: UInt<4>[2]}
+ | depth => 8
+ | write-latency => 1
+ | read-latency => 0
+ | reader => r
+ | m.r.clk <= clk
+ | m.r.en <= UInt(1)
+ | m.r.addr <= in
+ |""".stripMargin
+ val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem"))
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should contain (anno("m_a", "mem"))
+ resultAnno should contain (anno("m_b_0", "mem"))
+ resultAnno should contain (anno("m_b_1", "mem"))
+ resultAnno should contain (anno("m_a.r.data", "all"))
+ resultAnno should contain (anno("m_b_0.r.data", "all"))
+ resultAnno should contain (anno("m_b_1.r.data", "all"))
+ resultAnno should contain (anno("m_b_0.r.data", "sub"))
+ resultAnno should contain (anno("m_b_1.r.data", "sub"))
+ resultAnno should not contain (anno("m"))
+ resultAnno should not contain (anno("r"))
+ }
+
+ "Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in {
+ val compiler = new VerilogCompiler
+ Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug))
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input in: UInt<3>
+ | cmem m: {a: UInt<4>, b: UInt<4>[2]}[8]
+ | read mport r = m[in], clk
+ |""".stripMargin
+ val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem"))
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should contain (anno("m_a", "mem"))
+ resultAnno should contain (anno("m_b_0", "mem"))
+ resultAnno should contain (anno("m_b_1", "mem"))
+ resultAnno should contain (anno("m_a.r.data", "all"))
+ resultAnno should contain (anno("m_b_0.r.data", "all"))
+ resultAnno should contain (anno("m_b_1.r.data", "all"))
+ resultAnno should contain (anno("m_b_0.r.data", "sub"))
+ resultAnno should contain (anno("m_b_1.r.data", "sub"))
+ resultAnno should not contain (anno("m"))
+ resultAnno should not contain (anno("r"))
+ }
+
+ "Renaming" should "propagate in ZeroWidth" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Top :
+ | input zero: UInt<0>
+ | wire x: {a: UInt<3>, b: UInt<0>}
+ | wire y: UInt<0>[3]
+ | y[0] <= zero
+ | y[1] <= zero
+ | y[2] <= zero
+ | x.a <= zero
+ | x.b <= zero
+ |""".stripMargin
+ val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]"))
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should contain (anno("x_a"))
+ resultAnno should not contain (anno("zero"))
+ resultAnno should not contain (anno("x.a"))
+ resultAnno should not contain (anno("x.b"))
+ resultAnno should not contain (anno("x_b"))
+ resultAnno should not contain (anno("y[0]"))
+ resultAnno should not contain (anno("y[1]"))
+ resultAnno should not contain (anno("y[2]"))
+ resultAnno should not contain (anno("y_0"))
+ resultAnno should not contain (anno("y_1"))
+ resultAnno should not contain (anno("y_2"))
+ }
+
+ "Renaming subcomponents" should "propagate in Lowering" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input pred: UInt<1>
+ | input in: {a: UInt<3>, b: UInt<3>[2]}
+ | output out: {a: UInt<3>, b: UInt<3>[2]}
+ | wire w: {a: UInt<3>, b: UInt<3>[2]}
+ | w is invalid
+ | node n = mux(pred, in, w)
+ | out <= n
+ | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
+ | cmem mem: {a: UInt<3>, b: UInt<3>[2]}[8]
+ | write mport write = mem[pred], clk
+ | write <= in
+ |""".stripMargin
+ val annos = Seq(
+ anno("in.a"), anno("in.b[0]"), anno("in.b[1]"),
+ anno("out.a"), anno("out.b[0]"), anno("out.b[1]"),
+ anno("w.a"), anno("w.b[0]"), anno("w.b[1]"),
+ anno("n.a"), anno("n.b[0]"), anno("n.b[1]"),
+ anno("r.a"), anno("r.b[0]"), anno("r.b[1]"),
+ anno("write.a"), anno("write.b[0]"), anno("write.b[1]")
+ )
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should not contain (anno("in.a"))
+ resultAnno should not contain (anno("in.b[0]"))
+ resultAnno should not contain (anno("in.b[1]"))
+ resultAnno should not contain (anno("out.a"))
+ resultAnno should not contain (anno("out.b[0]"))
+ resultAnno should not contain (anno("out.b[1]"))
+ resultAnno should not contain (anno("w.a"))
+ resultAnno should not contain (anno("w.b[0]"))
+ resultAnno should not contain (anno("w.b[1]"))
+ resultAnno should not contain (anno("n.a"))
+ resultAnno should not contain (anno("n.b[0]"))
+ resultAnno should not contain (anno("n.b[1]"))
+ resultAnno should not contain (anno("r.a"))
+ resultAnno should not contain (anno("r.b[0]"))
+ resultAnno should not contain (anno("r.b[1]"))
+ resultAnno should contain (anno("in_a"))
+ resultAnno should contain (anno("in_b_0"))
+ resultAnno should contain (anno("in_b_1"))
+ resultAnno should contain (anno("out_a"))
+ resultAnno should contain (anno("out_b_0"))
+ resultAnno should contain (anno("out_b_1"))
+ resultAnno should contain (anno("w_a"))
+ resultAnno should contain (anno("w_b_0"))
+ resultAnno should contain (anno("w_b_1"))
+ resultAnno should contain (anno("n_a"))
+ resultAnno should contain (anno("n_b_0"))
+ resultAnno should contain (anno("n_b_1"))
+ resultAnno should contain (anno("r_a"))
+ resultAnno should contain (anno("r_b_0"))
+ resultAnno should contain (anno("r_b_1"))
+ resultAnno should contain (anno("mem_a.write.data"))
+ resultAnno should contain (anno("mem_b_0.write.data"))
+ resultAnno should contain (anno("mem_b_1.write.data"))
+ }
+
+ "Renaming components" should "expand in Lowering" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input pred: UInt<1>
+ | input in: {a: UInt<3>, b: UInt<3>[2]}
+ | output out: {a: UInt<3>, b: UInt<3>[2]}
+ | wire w: {a: UInt<3>, b: UInt<3>[2]}
+ | w is invalid
+ | node n = mux(pred, in, w)
+ | out <= n
+ | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
+ |""".stripMargin
+ val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r"))
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should contain (anno("in_a"))
+ resultAnno should contain (anno("in_b_0"))
+ resultAnno should contain (anno("in_b_1"))
+ resultAnno should contain (anno("out_a"))
+ resultAnno should contain (anno("out_b_0"))
+ resultAnno should contain (anno("out_b_1"))
+ resultAnno should contain (anno("w_a"))
+ resultAnno should contain (anno("w_b_0"))
+ resultAnno should contain (anno("w_b_1"))
+ resultAnno should contain (anno("n_a"))
+ resultAnno should contain (anno("n_b_0"))
+ resultAnno should contain (anno("n_b_1"))
+ resultAnno should contain (anno("r_a"))
+ resultAnno should contain (anno("r_b_0"))
+ resultAnno should contain (anno("r_b_1"))
+ }
+
+ "Renaming subcomponents that aren't leaves" should "expand in Lowering" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input pred: UInt<1>
+ | input in: {a: UInt<3>, b: UInt<3>[2]}
+ | output out: {a: UInt<3>, b: UInt<3>[2]}
+ | wire w: {a: UInt<3>, b: UInt<3>[2]}
+ | w is invalid
+ | node n = mux(pred, in, w)
+ | out <= n
+ | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk
+ |""".stripMargin
+ val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b"))
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should contain (anno("in_b_0"))
+ resultAnno should contain (anno("in_b_1"))
+ resultAnno should contain (anno("out_b_0"))
+ resultAnno should contain (anno("out_b_1"))
+ resultAnno should contain (anno("w_b_0"))
+ resultAnno should contain (anno("w_b_1"))
+ resultAnno should contain (anno("n_b_0"))
+ resultAnno should contain (anno("n_b_1"))
+ resultAnno should contain (anno("r_b_0"))
+ resultAnno should contain (anno("r_b_1"))
+ }
+
+
+ "Renaming" should "track dce" in {
+ val compiler = new VerilogCompiler
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input pred: UInt<1>
+ | input in: {a: UInt<3>, b: UInt<3>[2]}
+ | output out: {a: UInt<3>, b: UInt<3>[2]}
+ | node n = in
+ | out <= n
+ |""".stripMargin
+ val annos = Seq(
+ anno("in.a"), anno("in.b[0]"), anno("in.b[1]"),
+ anno("out.a"), anno("out.b[0]"), anno("out.b[1]"),
+ anno("n.a"), anno("n.b[0]"), anno("n.b[1]")
+ )
+ val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil)
+ val resultAnno = result.annotations.get.annotations
+ resultAnno should not contain (anno("in.a"))
+ resultAnno should not contain (anno("in.b[0]"))
+ resultAnno should not contain (anno("in.b[1]"))
+ resultAnno should not contain (anno("out.a"))
+ resultAnno should not contain (anno("out.b[0]"))
+ resultAnno should not contain (anno("out.b[1]"))
+ resultAnno should not contain (anno("n.a"))
+ resultAnno should not contain (anno("n.b[0]"))
+ resultAnno should not contain (anno("n.b[1]"))
+ resultAnno should not contain (anno("n_a"))
+ resultAnno should not contain (anno("n_b_0"))
+ resultAnno should not contain (anno("n_b_1"))
+ resultAnno should contain (anno("in_a"))
+ resultAnno should contain (anno("in_b_0"))
+ resultAnno should contain (anno("in_b_1"))
+ resultAnno should contain (anno("out_a"))
+ resultAnno should contain (anno("out_b_0"))
+ resultAnno should contain (anno("out_b_1"))
+ }
}
diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala
index 0ae112f0..fd4374f0 100644
--- a/src/test/scala/firrtlTests/ChirrtlSpec.scala
+++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala
@@ -8,9 +8,10 @@ import org.scalatest.prop._
import firrtl.Parser
import firrtl.ir.Circuit
import firrtl.passes._
+import firrtl._
class ChirrtlSpec extends FirrtlFlatSpec {
- def passes = Seq(
+ def transforms = Seq(
CheckChirrtl,
CInferTypes,
CInferMDir,
@@ -44,8 +45,9 @@ class ChirrtlSpec extends FirrtlFlatSpec {
| infer mport y = ram[UInt(4)], newClock
| y <= UInt(5)
""".stripMargin
- passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ transforms.foldLeft(CircuitState(circuit, UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
@@ -63,8 +65,9 @@ class ChirrtlSpec extends FirrtlFlatSpec {
| y <= z
""".stripMargin
intercept[PassException] {
- passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ transforms.foldLeft(CircuitState(circuit, UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
}
}
diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
index a7824087..dcaf52e3 100644
--- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala
+++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala
@@ -11,10 +11,12 @@ import firrtl.ir._
import firrtl.Parser.IgnoreInfo
class ExpandWhensSpec extends FirrtlFlatSpec {
- private def executeTest(input: String, check: String, passes: Seq[Pass], expected: Boolean) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ private def executeTest(input: String, check: String, transforms: Seq[Transform], expected: Boolean) = {
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
+ val c = result.circuit
val lines = c.serialize.split("\n") map normalized
println(c.serialize)
@@ -25,7 +27,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec {
}
}
"Expand Whens" should "not emit INVALID" in {
- val passes = Seq(
+ val transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
@@ -51,10 +53,10 @@ class ExpandWhensSpec extends FirrtlFlatSpec {
| a is invalid
| a.b <= UInt<64>("h04000000000000000")""".stripMargin
val check = "INVALID"
- executeTest(input, check, passes, false)
+ executeTest(input, check, transforms, false)
}
"Expand Whens" should "void unwritten memory fields" in {
- val passes = Seq(
+ val transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
@@ -92,7 +94,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec {
| memory.w0.clk <= clk
| """.stripMargin
val check = "VOID"
- executeTest(input, check, passes, true)
+ executeTest(input, check, transforms, true)
}
}
diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala
index 89461f12..b43df713 100644
--- a/src/test/scala/firrtlTests/LowerTypesSpec.scala
+++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala
@@ -8,9 +8,10 @@ import org.scalatest.prop._
import firrtl.Parser
import firrtl.ir.Circuit
import firrtl.passes._
+import firrtl._
class LowerTypesSpec extends FirrtlFlatSpec {
- private val passes = Seq(
+ private val transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
@@ -34,9 +35,11 @@ class LowerTypesSpec extends FirrtlFlatSpec {
LowerTypes)
private def executeTest(input: String, expected: Seq[String]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
+ val c = result.circuit
val lines = c.serialize.split("\n") map normalized
expected foreach { e =>
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 0831bb31..8367f152 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -65,6 +65,7 @@ circuit Top :
val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc)))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(aMap)))
// Check correctness of firrtl
+ println(res.annotations)
parse(res.getEmittedCircuit.value)
(new java.io.File(confLoc)).delete()
}
diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala
index 14c0f652..27918cc5 100644
--- a/src/test/scala/firrtlTests/UniquifySpec.scala
+++ b/src/test/scala/firrtlTests/UniquifySpec.scala
@@ -8,10 +8,11 @@ import org.scalatest.prop._
import firrtl.Parser
import firrtl.ir.Circuit
import firrtl.passes._
+import firrtl._
class UniquifySpec extends FirrtlFlatSpec {
- private val passes = Seq(
+ private val transforms = Seq(
ToWorkingIR,
CheckHighForm,
ResolveKinds,
@@ -20,9 +21,11 @@ class UniquifySpec extends FirrtlFlatSpec {
)
private def executeTest(input: String, expected: Seq[String]) = {
- val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
- (c: Circuit, p: Pass) => p.run(c)
+ val circuit = Parser.parse(input.split("\n").toIterator)
+ val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
}
+ val c = result.circuit
val lines = c.serialize.split("\n") map normalized
expected foreach { e =>
diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala
index 90926bc1..8c39dc1e 100644
--- a/src/test/scala/firrtlTests/ZeroWidthTests.scala
+++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala
@@ -11,7 +11,7 @@ import firrtl.Parser
import firrtl.passes._
class ZeroWidthTests extends FirrtlFlatSpec {
- val passes = Seq(
+ val transforms = Seq(
ToWorkingIR,
ResolveKinds,
InferTypes,
@@ -19,9 +19,10 @@ class ZeroWidthTests extends FirrtlFlatSpec {
InferWidths,
ZeroWidth)
private def exec (input: String) = {
- passes.foldLeft(parse(input)) {
- (c: Circuit, p: Pass) => p.run(c)
- }.serialize
+ val circuit = parse(input)
+ transforms.foldLeft(CircuitState(circuit, UnknownForm)) {
+ (c: CircuitState, p: Transform) => p.runTransform(c)
+ }.circuit.serialize
}
// =============================
"Zero width port" should " be deleted" in {
@@ -105,6 +106,18 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| skip""".stripMargin
(parse(exec(input)).serialize) should be (parse(check).serialize)
}
+ "IsInvalid on <0>" should "be deleted" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | output y: UInt<0>
+ | y is invalid""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | skip""".stripMargin
+ (parse(exec(input)).serialize) should be (parse(check).serialize)
+ }
"Expression in node with type <0>" should "be replaced by UInt<1>(0)" in {
val input =
"""circuit Top :