diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 50 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/DeadCodeElimination.scala | 25 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 128 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 51 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Uniquify.scala | 44 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ZeroWidth.scala | 52 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/AnnotationTests.scala | 293 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ChirrtlSpec.scala | 13 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ExpandWhensSpec.scala | 16 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/LowerTypesSpec.scala | 9 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 1 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UniquifySpec.scala | 9 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ZeroWidthTests.scala | 21 |
16 files changed, 619 insertions, 99 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 6e5cadcd..6c3911d6 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -2,9 +2,10 @@ package firrtl -import logger.LazyLogging +import logger._ import java.io.Writer import annotations._ +import scala.collection.mutable import firrtl.ir.Circuit import passes.Pass @@ -14,7 +15,47 @@ import Utils.throwInternalError * RenameMap maps old names to modified names. Generated by transformations * that modify names */ -case class RenameMap(map: Map[Named, Seq[Named]] = Map[Named, Seq[Named]]()) +object RenameMap { + def apply(map: Map[Named, Seq[Named]]) = { + val rm = new RenameMap + rm.addMap(map) + rm + } + def apply() = new RenameMap +} +class RenameMap { + val renameMap = new mutable.HashMap[Named, Seq[Named]]() + private var circuitName: String = "" + private var moduleName: String = "" + def setModule(s: String) = + moduleName = s + def setCircuit(s: String) = + circuitName = s + def rename(from: String, to: String): Unit = rename(from, Seq(to)) + def rename(from: String, tos: Seq[String]): Unit = { + val fromName = ComponentName(from, ModuleName(moduleName, CircuitName(circuitName))) + val tosName = tos map { to => + ComponentName(to, ModuleName(moduleName, CircuitName(circuitName))) + } + rename(fromName, tosName) + } + def rename(from: Named, to: Named): Unit = rename(from, Seq(to)) + def rename(from: Named, tos: Seq[Named]): Unit = (from, tos) match { + case (x, Seq(y)) if x == y => + case _ => + renameMap(from) = renameMap.getOrElse(from, Seq.empty) ++ tos + } + def delete(names: Seq[String]): Unit = names.foreach(delete(_)) + def delete(name: String): Unit = + delete(ComponentName(name, ModuleName(moduleName, CircuitName(circuitName)))) + def delete(name: Named): Unit = + renameMap(name) = Seq.empty + def addMap(map: Map[Named, Seq[Named]]) = + renameMap ++= map + def serialize: String = renameMap.map { case (k, v) => + k.serialize + "=>" + v.map(_.serialize).mkString(", ") + }.mkString("\n") +} /** * Container of all annotations for a Firrtl compiler. @@ -172,7 +213,6 @@ abstract class Transform extends LazyLogging { } logger.trace(s"Circuit:\n${result.circuit.serialize}") logger.info(s"======== Finished Transform $name ========\n") - CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations)), None) } @@ -200,7 +240,7 @@ abstract class Transform extends LazyLogging { } // For each annotation, rename all annotations. - val renames = renameOpt.getOrElse(RenameMap()).map + val renames = renameOpt.getOrElse(RenameMap()).renameMap for { anno <- newAnnotations.toSeq newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) @@ -217,8 +257,10 @@ trait SeqTransformBased { /** For transformations that are simply a sequence of transforms */ abstract class SeqTransform extends Transform with SeqTransformBased { def execute(state: CircuitState): CircuitState = { + /* require(state.form <= inputForm, s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + */ val ret = runTransforms(state) CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index b9042781..84f237a3 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -6,7 +6,6 @@ sealed abstract class CoreTransform extends SeqTransform /** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting * circuit has only IR nodes, not WIR. - * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL */ class ChirrtlToHighFirrtl extends CoreTransform { def inputForm = ChirrtlForm @@ -75,7 +74,6 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { /** Expands all aggregate types into many ground-typed components. Must * accept a well-formed graph of only middle Firrtl features. * Operates on working IR nodes. - * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL */ class MiddleFirrtlToLowFirrtl extends CoreTransform { def inputForm = MidForm diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 76156d2d..586fe1e7 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -541,7 +541,7 @@ object Utils extends LazyLogging { } /** Adds a root reference to some SubField/SubIndex chain */ - def mergeRef(root: WRef, body: Expression): Expression = body match { + def mergeRef(root: Expression, body: Expression): Expression = body match { case e: WRef => WSubField(root, e.name, e.tpe, e.gender) case e: WSubIndex => diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala index 9f249f35..54ac76fe 100644 --- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala @@ -9,8 +9,10 @@ import firrtl.Mappers._ import annotation.tailrec -object DeadCodeElimination extends Pass { - private def dceOnce(s: Statement): (Statement, Long) = { +object DeadCodeElimination extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm + private def dceOnce(renames: RenameMap)(s: Statement): (Statement, Long) = { val referenced = collection.mutable.HashSet[String]() var nEliminated = 0L @@ -28,6 +30,7 @@ object DeadCodeElimination extends Pass { if (referenced(name)) x else { nEliminated += 1 + renames.delete(name) EmptyStmt } @@ -43,16 +46,22 @@ object DeadCodeElimination extends Pass { } @tailrec - private def dce(s: Statement): Statement = { - val (res, n) = dceOnce(s) - if (n > 0) dce(res) else res + private def dce(renames: RenameMap)(s: Statement): Statement = { + val (res, n) = dceOnce(renames)(s) + if (n > 0) dce(renames)(res) else res } - def run(c: Circuit): Circuit = { + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => Module(m.info, m.name, m.ports, dce(m.body)) + case m: Module => + renames.setModule(m.name) + Module(m.info, m.name, m.ports, dce(renames)(m.body)) } - Circuit(c.info, modulesx, c.main) + val result = Circuit(c.info, modulesx, c.main) + CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 5826f56e..b48ab338 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -2,9 +2,11 @@ package firrtl.passes +import scala.collection.mutable import firrtl._ import firrtl.ir._ import firrtl.Utils._ +import MemPortUtils.memType import firrtl.Mappers._ /** Removes all aggregate types from a [[firrtl.ir.Circuit]] @@ -20,7 +22,10 @@ import firrtl.Mappers._ * wire foo_b : UInt<16> * }}} */ -object LowerTypes extends Pass { +object LowerTypes extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm + /** Delimiter used in lowering names */ val delim = "_" /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name @@ -33,7 +38,49 @@ object LowerTypes extends Pass { case e: WSubIndex => s"${loweredName(e.exp)}$delim${e.value}" } def loweredName(s: Seq[String]): String = s mkString delim + def renameExps(renames: RenameMap, n: String, t: Type, root: String): Seq[String] = + renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), root) + def renameExps(renames: RenameMap, n: String, t: Type): Seq[String] = + renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), "") + def renameExps(renames: RenameMap, e: Expression, root: String): Seq[String] = e.tpe match { + case (_: GroundType) => + val name = root + loweredName(e) + renames.rename(e.serialize, name) + Seq(name) + case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) => + val subNames = renameExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), root) + renames.rename(e.serialize, subNames) + names ++ subNames + } + case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) => + val subNames = renameExps(renames, WSubIndex(e, i, t.tpe,gender(e)), root) + renames.rename(e.serialize, subNames) + names ++ subNames + } + } + private def renameMemExps(renames: RenameMap, e: Expression, portAndField: Expression): Seq[String] = e.tpe match { + case (_: GroundType) => + val (mem, tail) = splitRef(e) + val loRef = mergeRef(WRef(loweredName(e)), portAndField) + val hiRef = mergeRef(mem, mergeRef(portAndField, tail)) + renames.rename(hiRef.serialize, loRef.serialize) + Seq(loRef.serialize) + case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) => + val subNames = renameMemExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), portAndField) + val (mem, tail) = splitRef(e) + val hiRef = mergeRef(mem, mergeRef(portAndField, tail)) + renames.rename(hiRef.serialize, subNames) + names ++ subNames + } + case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) => + val subNames = renameMemExps(renames, WSubIndex(e, i, t.tpe,gender(e)), portAndField) + val (mem, tail) = splitRef(e) + val hiRef = mergeRef(mem, mergeRef(portAndField, tail)) + renames.rename(hiRef.serialize, subNames) + names ++ subNames + } + } private case class LowerTypesException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(info: Info, mname: String) = throw LowerTypesException(s"$info: [module $mname] $msg") @@ -111,34 +158,43 @@ object LowerTypes extends Pass { case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname) case e @ (_: UIntLiteral | _: SIntLiteral) => e } - def lowerTypesStmt(memDataTypeMap: MemDataTypeMap, - minfo: Info, mname: String)(s: Statement): Statement = { + minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = { val info = get_info(s) match {case NoInfo => minfo case x => x} - s map lowerTypesStmt(memDataTypeMap, info, mname) match { + s map lowerTypesStmt(memDataTypeMap, info, mname, renames) match { case s: DefWire => s.tpe match { case _: GroundType => s - case _ => Block(create_exps(s.name, s.tpe) map ( - e => DefWire(s.info, loweredName(e), e.tpe))) + case _ => + val exps = create_exps(s.name, s.tpe) + val names = exps map loweredName + renameExps(renames, s.name, s.tpe) + Block((exps zip names) map { case (e, n) => + DefWire(s.info, n, e.tpe) + }) } case sx: DefRegister => sx.tpe match { case _: GroundType => sx map lowerTypesExp(memDataTypeMap, info, mname) case _ => val es = create_exps(sx.name, sx.tpe) + val names = es map loweredName + renameExps(renames, sx.name, sx.tpe) val inits = create_exps(sx.init) map lowerTypesExp(memDataTypeMap, info, mname) val clock = lowerTypesExp(memDataTypeMap, info, mname)(sx.clock) val reset = lowerTypesExp(memDataTypeMap, info, mname)(sx.reset) - Block(es zip inits map { case (e, i) => - DefRegister(sx.info, loweredName(e), e.tpe, clock, reset, i) + Block((es zip names) zip inits map { case ((e, n), i) => + DefRegister(sx.info, n, e.tpe, clock, reset, i) }) } // Could instead just save the type of each Module as it gets processed case sx: WDefInstance => sx.tpe match { case t: BundleType => - val fieldsx = t.fields flatMap (f => - create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map ( + val fieldsx = t.fields flatMap { f => + renameExps(renames, f.name, sx.tpe, s"${sx.name}.") + create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map { e => // Flip because inst genders are reversed from Module type - e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe))) + Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) + } + } WDefInstance(sx.info, sx.name, sx.module, BundleType(fieldsx)) case _ => error("WDefInstance type should be Bundle!")(info, mname) } @@ -146,8 +202,30 @@ object LowerTypes extends Pass { memDataTypeMap(sx.name) = sx.dataType sx.dataType match { case _: GroundType => sx - case _ => Block(create_exps(sx.name, sx.dataType) map (e => - sx copy (name = loweredName(e), dataType = e.tpe))) + case _ => + // Rename ports + val seen: mutable.Set[String] = mutable.Set[String]() + create_exps(sx.name, memType(sx)) foreach { e => + val (mem, port, field, tail) = splitMemRef(e) + if (!seen.contains(field.name)) { + seen += field.name + val d = WRef(mem.name, sx.dataType) + tail match { + case None => + create_exps(mem.name, sx.dataType) foreach { x => + renames.rename(e.serialize, s"${loweredName(x)}.${port.serialize}.${field.serialize}") + } + case Some(_) => + renameMemExps(renames, d, mergeRef(port, field)) + } + } + } + Block(create_exps(sx.name, sx.dataType) map {e => + val newName = loweredName(e) + // Rename mems + renames.rename(sx.name, newName) + sx copy (name = newName, dataType = e.tpe) + }) } // wire foo : { a , b } // node x = foo @@ -159,7 +237,10 @@ object LowerTypes extends Pass { case sx: DefNode => val names = create_exps(sx.name, sx.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname) val exps = create_exps(sx.value) map lowerTypesExp(memDataTypeMap, info, mname) - Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) }) + renameExps(renames, sx.name, sx.value.tpe) + Block(names zip exps map { case (n, e) => + DefNode(info, loweredName(n), e) + }) case sx: IsInvalid => kind(sx.expr) match { case MemKind => Block(lowerTypesMemExp(memDataTypeMap, info, mname)(sx.expr) map (IsInvalid(info, _))) @@ -176,21 +257,32 @@ object LowerTypes extends Pass { } } - def lowerTypes(m: DefModule): DefModule = { + def lowerTypes(renames: RenameMap)(m: DefModule): DefModule = { val memDataTypeMap = new MemDataTypeMap + renames.setModule(m.name) // Lower Ports val portsx = m.ports flatMap { p => val exps = create_exps(WRef(p.name, p.tpe, PortKind, to_gender(p.direction))) - exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe)) + val names = exps map loweredName + renameExps(renames, p.name, p.tpe) + (exps zip names) map { case (e, n) => + Port(p.info, n, to_dir(gender(e)), e.tpe) + } } m match { case m: ExtModule => m copy (ports = portsx) case m: Module => - m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name) + m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name, renames) } } - def run(c: Circuit): Circuit = c copy (modules = c.modules map lowerTypes) + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) + val result = c copy (modules = c.modules map lowerTypes(renames)) + CircuitState(result, outputForm, state.annotations, Some(renames)) + } } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index b072dfa0..c841dc32 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -14,7 +14,9 @@ case class MPort(name: String, clk: Expression) case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort]) case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) -object RemoveCHIRRTL extends Pass { +object RemoveCHIRRTL extends Transform { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm val ut = UnknownType type MPortMap = collection.mutable.LinkedHashMap[String, MPorts] type SeqMemSet = collection.mutable.HashSet[String] @@ -22,6 +24,15 @@ object RemoveCHIRRTL extends Pass { type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef] type AddrMap = collection.mutable.HashMap[String, Expression] + def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match { + case _: GroundType => Seq(ex) + case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex) + case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex) + case UnknownType => Seq(ex) + } + def create_exps(e: Expression): Seq[Expression] = e match { case ex: Mux => val e1s = create_exps(ex.tval) @@ -59,7 +70,7 @@ object RemoveCHIRRTL extends Pass { } def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap, - refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match { + refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match { case sx: CDefMemory => types(sx.name) = sx.tpe val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size))) @@ -104,11 +115,25 @@ object RemoveCHIRRTL extends Pass { addrs += "addr" clks += "clk" ens += "en" + renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata") + renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata") + val es = create_all_exps(WRef(sx.name, sx.tpe)) + val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe)) + val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe)) + ((es zip rs) zip ws) map { + case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize)) + } case MWrite => refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false) addrs += "addr" clks += "clk" ens += "en" + renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") + val es = create_all_exps(WRef(sx.name, sx.tpe)) + val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) + (es zip ws) map { + case (e, w) => renames.rename(e.serialize, w.serialize) + } case MRead => refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false) addrs += "addr" @@ -118,13 +143,19 @@ object RemoveCHIRRTL extends Pass { raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut) case _ => ens += "en" } + renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") + val es = create_all_exps(WRef(sx.name, sx.tpe)) + val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) + (es zip rs) map { + case (e, r) => renames.rename(e.serialize, r.serialize) + } case MInfer => // do nothing if it's not being used } Block( (addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++ (clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++ (ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one)))) - case sx => sx map collect_refs(mports, smems, types, refs, raddrs) + case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames) } def get_mask(refs: DataRefMap)(e: Expression): Expression = @@ -213,17 +244,23 @@ object RemoveCHIRRTL extends Pass { } } - def remove_chirrtl_m(m: DefModule): DefModule = { + def remove_chirrtl_m(renames: RenameMap)(m: DefModule): DefModule = { val mports = new MPortMap val smems = new SeqMemSet val types = new MPortTypeMap val refs = new DataRefMap val raddrs = new AddrMap + renames.setModule(m.name) (m map collect_smems_and_mports(mports, smems) - map collect_refs(mports, smems, types, refs, raddrs) + map collect_refs(mports, smems, types, refs, raddrs, renames) map remove_chirrtl_s(refs, raddrs)) } - def run(c: Circuit): Circuit = - c copy (modules = c.modules map remove_chirrtl_m) + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) + val result = c copy (modules = c.modules map remove_chirrtl_m(renames)) + CircuitState(result, outputForm, state.annotations, Some(renames)) + } } diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index deddb93e..61bd68d0 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -31,7 +31,9 @@ import MemPortUtils.memType * there WOULD be collisions in references a[0] and a_0 so we still have * to rename a */ -object Uniquify extends Pass { +object Uniquify extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm private case class UniquifyException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(implicit sinfo: Info, mname: String) = throw new UniquifyException(s"$sinfo: [module $mname] $msg") @@ -224,7 +226,10 @@ object Uniquify extends Pass { } // Everything wrapped in run so that it's thread safe - def run(c: Circuit): Circuit = { + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) // Debug state implicit var mname: String = "" implicit var sinfo: Info = NoInfo @@ -232,7 +237,8 @@ object Uniquify extends Pass { val portNameMap = collection.mutable.HashMap[String, Map[String, NameMapNode]]() val portTypeMap = collection.mutable.HashMap[String, Type]() - def uniquifyModule(m: DefModule): DefModule = { + def uniquifyModule(renames: RenameMap)(m: DefModule): DefModule = { + renames.setModule(m.name) val namespace = collection.mutable.HashSet[String]() val nameMap = collection.mutable.HashMap[String, NameMapNode]() @@ -251,7 +257,11 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - DefWire(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts)) + val newType = uniquifyNamesType(sx.tpe, node.elts) + (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + DefWire(sx.info, node.name, newType) } else { sx } @@ -259,8 +269,11 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - DefRegister(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts), - sx.clock, sx.reset, sx.init) + val newType = uniquifyNamesType(sx.tpe, node.elts) + (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init) } else { sx } @@ -268,7 +281,11 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - WDefInstance(sx.info, node.name, sx.module, sx.tpe) + val newType = portTypeMap(m.name) + (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + WDefInstance(sx.info, node.name, sx.module, newType) } else { sx } @@ -280,6 +297,9 @@ object Uniquify extends Pass { val mem = sx.copy(name = node.name, dataType = dataType) // Create new mapping to handle references to memory data fields val uniqueMemMap = createNameMapping(memType(sx), memType(mem)) + (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap) mem } else { @@ -289,6 +309,9 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) + (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } DefNode(sx.info, node.name, sx.value) } else { sx @@ -320,7 +343,8 @@ object Uniquify extends Pass { } } - def uniquifyPorts(m: DefModule): DefModule = { + def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = { + renames.setModule(m.name) def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { val portsType = BundleType(ports map { case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) @@ -331,6 +355,7 @@ object Uniquify extends Pass { portTypeMap += (m.name -> uniquePortsType) ports zip uniquePortsType.fields map { case (p, f) => + renames.rename(p.name, f.name) Port(p.info, f.name, p.direction, f.tpe) } } @@ -344,7 +369,8 @@ object Uniquify extends Pass { } sinfo = c.info - Circuit(c.info, c.modules map uniquifyPorts map uniquifyModule, c.main) + val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main) + CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index a2ec9935..2472a0e5 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -10,8 +10,24 @@ import firrtl.Mappers._ import firrtl.Utils.throwInternalError -object ZeroWidth extends Pass { +object ZeroWidth extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm private val ZERO = BigInt(0) + private def getRemoved(x: IsDeclaration): Seq[String] = { + var removedNames: Seq[String] = Seq.empty + def onType(name: String)(t: Type): Type = { + removedNames = Utils.create_exps(name, t) map {e => (e, e.tpe)} collect { + case (e, GroundType(IntWidth(ZERO))) => e.serialize + } + t + } + x match { + case s: Statement => s map onType(s.name) + case Port(_, name, _, t) => onType(name)(t) + } + removedNames + } private def removeZero(t: Type): Option[Type] = t match { case GroundType(IntWidth(ZERO)) => None case BundleType(fields) => @@ -34,35 +50,53 @@ object ZeroWidth extends Pass { def replaceType(x: Type): Type = t (e map replaceType) map onExp } - private def onStmt(s: Statement): Statement = s match { + private def onStmt(renames: RenameMap)(s: Statement): Statement = s match { case (_: DefWire| _: DefRegister| _: DefMemory) => + // List all removed expression names, and delete them from renames + renames.delete(getRemoved(s.asInstanceOf[IsDeclaration])) + // Create new types without zero-width wires var removed = false def applyRemoveZero(t: Type): Type = removeZero(t) match { case None => removed = true; t case Some(tx) => tx } val sxx = (s map onExp) map applyRemoveZero + // Return new declaration if(removed) EmptyStmt else sxx case Connect(info, loc, exp) => removeZero(loc.tpe) match { case None => EmptyStmt case Some(t) => Connect(info, loc, onExp(exp)) } + case IsInvalid(info, exp) => removeZero(exp.tpe) match { + case None => EmptyStmt + case Some(t) => IsInvalid(info, onExp(exp)) + } case DefNode(info, name, value) => removeZero(value.tpe) match { case None => EmptyStmt case Some(t) => DefNode(info, name, onExp(value)) } - case sx => sx map onStmt + case sx => sx map onStmt(renames) } - private def onModule(m: DefModule): DefModule = { - val ports = m.ports map (p => (p, removeZero(p.tpe))) collect { - case (Port(info, name, dir, _), Some(t)) => Port(info, name, dir, t) + private def onModule(renames: RenameMap)(m: DefModule): DefModule = { + renames.setModule(m.name) + // For each port, record deleted subcomponents + m.ports.foreach{p => renames.delete(getRemoved(p))} + val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap { + case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t)) + case (Port(_, name, _, _), None) => + renames.delete(name) + Nil } m match { case ext: ExtModule => ext.copy(ports = ports) - case in: Module => in.copy(ports = ports, body = onStmt(in.body)) + case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body)) } } - def run(c: Circuit): Circuit = { - InferTypes.run(c.copy(modules = c.modules map onModule)) + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) + val result = InferTypes.run(c.copy(modules = c.modules map onModule(renames))) + CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index caaf430b..8cbf9da7 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -131,7 +131,7 @@ class ReplSeqMem extends Transform { new SimpleMidTransform(RemoveEmpty), new SimpleMidTransform(CheckInitialization), new SimpleMidTransform(InferTypes), - new SimpleMidTransform(Uniquify), + Uniquify, new SimpleMidTransform(ResolveKinds), new SimpleMidTransform(ResolveGenders)) diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 81d982e4..81c394c1 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -11,6 +11,7 @@ import firrtl.passes.InlineAnnotation import firrtl.passes.memlib.PinAnnotation import net.jcazevedo.moultingyaml._ import org.scalatest.Matchers +import logger._ /** * An example methodology for testing Firrtl annotations. @@ -25,8 +26,8 @@ trait AnnotationSpec extends LowTransformSpec { compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty) } } - def execute(annotations: AnnotationMap, input: String, check: Annotation): Unit = { - val cr = compile(CircuitState(parse(input), ChirrtlForm, Some(annotations)), Seq.empty) + def execute(aMap: Option[AnnotationMap], input: String, check: Annotation): Unit = { + val cr = compile(CircuitState(parse(input), ChirrtlForm, aMap), Seq.empty) cr.annotations.get.annotations should contain (check) } } @@ -40,20 +41,19 @@ trait AnnotationSpec extends LowTransformSpec { * Unstable, Fickle, and Insistent can be tested. */ class AnnotationTests extends AnnotationSpec with Matchers { - def getAMap (a: Annotation): AnnotationMap = AnnotationMap(Seq(a)) - val input: String = - """circuit Top : - | module Top : - | input a : UInt<1>[2] - | input b : UInt<1> - | node c = b""".stripMargin - val mName = ModuleName("Top", CircuitName("Top")) - val aName = ComponentName("a", mName) - val bName = ComponentName("b", mName) - val cName = ComponentName("c", mName) + def getAMap(a: Annotation): Option[AnnotationMap] = Some(AnnotationMap(Seq(a))) + def getAMap(as: Seq[Annotation]): Option[AnnotationMap] = Some(AnnotationMap(as)) + def anno(s: String, value: String ="this is a value"): Annotation = + Annotation(ComponentName(s, ModuleName("Top", CircuitName("Top"))), classOf[Transform], value) "Loose and Sticky annotation on a node" should "pass through" in { - val ta = Annotation(cName, classOf[Transform], "") + val input: String = + """circuit Top : + | module Top : + | input a : UInt<1>[2] + | input b : UInt<1> + | node c = b""".stripMargin + val ta = anno("c", "") execute(getAMap(ta), input, ta) } @@ -134,11 +134,10 @@ class AnnotationTests extends AnnotationSpec with Matchers { val outputForm = LowForm def execute(state: CircuitState) = state.copy(annotations = None) } - val anno = InlineAnnotation(CircuitName("Top")) - val annoOpt = Some(AnnotationMap(Seq(anno))) - val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, annoOpt), Seq(new DeletingTransform)) + val inlineAnn = InlineAnnotation(CircuitName("Top")) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(inlineAnn)), Seq(new DeletingTransform)) result.annotations.get.annotations.head should matchPattern { - case DeletedAnnotation(x, anno) => + case DeletedAnnotation(x, inlineAnn) => } val exception = (intercept[FIRRTLException] { result.getEmittedCircuit @@ -146,4 +145,262 @@ class AnnotationTests extends AnnotationSpec with Matchers { val deleted = result.deletedAnnotations exception.str should be (s"No EmittedCircuit found! Did you delete any annotations?\n$deleted") } + + "Renaming" should "propagate in Lowering of memories" in { + val compiler = new VerilogCompiler + // Uncomment to help debugging failing tests + // Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug)) + val input = + """circuit Top : + | module Top : + | input clk: Clock + | input in: UInt<3> + | mem m: + | data-type => {a: UInt<4>, b: UInt<4>[2]} + | depth => 8 + | write-latency => 1 + | read-latency => 0 + | reader => r + | m.r.clk <= clk + | m.r.en <= UInt(1) + | m.r.addr <= in + |""".stripMargin + val annos = Seq(anno("m.r.data.b", "sub"), anno("m.r.data", "all"), anno("m", "mem")) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should contain (anno("m_a", "mem")) + resultAnno should contain (anno("m_b_0", "mem")) + resultAnno should contain (anno("m_b_1", "mem")) + resultAnno should contain (anno("m_a.r.data", "all")) + resultAnno should contain (anno("m_b_0.r.data", "all")) + resultAnno should contain (anno("m_b_1.r.data", "all")) + resultAnno should contain (anno("m_b_0.r.data", "sub")) + resultAnno should contain (anno("m_b_1.r.data", "sub")) + resultAnno should not contain (anno("m")) + resultAnno should not contain (anno("r")) + } + + "Renaming" should "propagate in RemoveChirrtl and Lowering of memories" in { + val compiler = new VerilogCompiler + Logger.setClassLogLevels(Map(compiler.getClass.getName -> LogLevel.Debug)) + val input = + """circuit Top : + | module Top : + | input clk: Clock + | input in: UInt<3> + | cmem m: {a: UInt<4>, b: UInt<4>[2]}[8] + | read mport r = m[in], clk + |""".stripMargin + val annos = Seq(anno("r.b", "sub"), anno("r", "all"), anno("m", "mem")) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should contain (anno("m_a", "mem")) + resultAnno should contain (anno("m_b_0", "mem")) + resultAnno should contain (anno("m_b_1", "mem")) + resultAnno should contain (anno("m_a.r.data", "all")) + resultAnno should contain (anno("m_b_0.r.data", "all")) + resultAnno should contain (anno("m_b_1.r.data", "all")) + resultAnno should contain (anno("m_b_0.r.data", "sub")) + resultAnno should contain (anno("m_b_1.r.data", "sub")) + resultAnno should not contain (anno("m")) + resultAnno should not contain (anno("r")) + } + + "Renaming" should "propagate in ZeroWidth" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Top : + | input zero: UInt<0> + | wire x: {a: UInt<3>, b: UInt<0>} + | wire y: UInt<0>[3] + | y[0] <= zero + | y[1] <= zero + | y[2] <= zero + | x.a <= zero + | x.b <= zero + |""".stripMargin + val annos = Seq(anno("zero"), anno("x.a"), anno("x.b"), anno("y[0]"), anno("y[1]"), anno("y[2]")) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should contain (anno("x_a")) + resultAnno should not contain (anno("zero")) + resultAnno should not contain (anno("x.a")) + resultAnno should not contain (anno("x.b")) + resultAnno should not contain (anno("x_b")) + resultAnno should not contain (anno("y[0]")) + resultAnno should not contain (anno("y[1]")) + resultAnno should not contain (anno("y[2]")) + resultAnno should not contain (anno("y_0")) + resultAnno should not contain (anno("y_1")) + resultAnno should not contain (anno("y_2")) + } + + "Renaming subcomponents" should "propagate in Lowering" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Top : + | input clk: Clock + | input pred: UInt<1> + | input in: {a: UInt<3>, b: UInt<3>[2]} + | output out: {a: UInt<3>, b: UInt<3>[2]} + | wire w: {a: UInt<3>, b: UInt<3>[2]} + | w is invalid + | node n = mux(pred, in, w) + | out <= n + | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk + | cmem mem: {a: UInt<3>, b: UInt<3>[2]}[8] + | write mport write = mem[pred], clk + | write <= in + |""".stripMargin + val annos = Seq( + anno("in.a"), anno("in.b[0]"), anno("in.b[1]"), + anno("out.a"), anno("out.b[0]"), anno("out.b[1]"), + anno("w.a"), anno("w.b[0]"), anno("w.b[1]"), + anno("n.a"), anno("n.b[0]"), anno("n.b[1]"), + anno("r.a"), anno("r.b[0]"), anno("r.b[1]"), + anno("write.a"), anno("write.b[0]"), anno("write.b[1]") + ) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should not contain (anno("in.a")) + resultAnno should not contain (anno("in.b[0]")) + resultAnno should not contain (anno("in.b[1]")) + resultAnno should not contain (anno("out.a")) + resultAnno should not contain (anno("out.b[0]")) + resultAnno should not contain (anno("out.b[1]")) + resultAnno should not contain (anno("w.a")) + resultAnno should not contain (anno("w.b[0]")) + resultAnno should not contain (anno("w.b[1]")) + resultAnno should not contain (anno("n.a")) + resultAnno should not contain (anno("n.b[0]")) + resultAnno should not contain (anno("n.b[1]")) + resultAnno should not contain (anno("r.a")) + resultAnno should not contain (anno("r.b[0]")) + resultAnno should not contain (anno("r.b[1]")) + resultAnno should contain (anno("in_a")) + resultAnno should contain (anno("in_b_0")) + resultAnno should contain (anno("in_b_1")) + resultAnno should contain (anno("out_a")) + resultAnno should contain (anno("out_b_0")) + resultAnno should contain (anno("out_b_1")) + resultAnno should contain (anno("w_a")) + resultAnno should contain (anno("w_b_0")) + resultAnno should contain (anno("w_b_1")) + resultAnno should contain (anno("n_a")) + resultAnno should contain (anno("n_b_0")) + resultAnno should contain (anno("n_b_1")) + resultAnno should contain (anno("r_a")) + resultAnno should contain (anno("r_b_0")) + resultAnno should contain (anno("r_b_1")) + resultAnno should contain (anno("mem_a.write.data")) + resultAnno should contain (anno("mem_b_0.write.data")) + resultAnno should contain (anno("mem_b_1.write.data")) + } + + "Renaming components" should "expand in Lowering" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Top : + | input clk: Clock + | input pred: UInt<1> + | input in: {a: UInt<3>, b: UInt<3>[2]} + | output out: {a: UInt<3>, b: UInt<3>[2]} + | wire w: {a: UInt<3>, b: UInt<3>[2]} + | w is invalid + | node n = mux(pred, in, w) + | out <= n + | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk + |""".stripMargin + val annos = Seq(anno("in"), anno("out"), anno("w"), anno("n"), anno("r")) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should contain (anno("in_a")) + resultAnno should contain (anno("in_b_0")) + resultAnno should contain (anno("in_b_1")) + resultAnno should contain (anno("out_a")) + resultAnno should contain (anno("out_b_0")) + resultAnno should contain (anno("out_b_1")) + resultAnno should contain (anno("w_a")) + resultAnno should contain (anno("w_b_0")) + resultAnno should contain (anno("w_b_1")) + resultAnno should contain (anno("n_a")) + resultAnno should contain (anno("n_b_0")) + resultAnno should contain (anno("n_b_1")) + resultAnno should contain (anno("r_a")) + resultAnno should contain (anno("r_b_0")) + resultAnno should contain (anno("r_b_1")) + } + + "Renaming subcomponents that aren't leaves" should "expand in Lowering" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Top : + | input clk: Clock + | input pred: UInt<1> + | input in: {a: UInt<3>, b: UInt<3>[2]} + | output out: {a: UInt<3>, b: UInt<3>[2]} + | wire w: {a: UInt<3>, b: UInt<3>[2]} + | w is invalid + | node n = mux(pred, in, w) + | out <= n + | reg r: {a: UInt<3>, b: UInt<3>[2]}, clk + |""".stripMargin + val annos = Seq(anno("in.b"), anno("out.b"), anno("w.b"), anno("n.b"), anno("r.b")) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should contain (anno("in_b_0")) + resultAnno should contain (anno("in_b_1")) + resultAnno should contain (anno("out_b_0")) + resultAnno should contain (anno("out_b_1")) + resultAnno should contain (anno("w_b_0")) + resultAnno should contain (anno("w_b_1")) + resultAnno should contain (anno("n_b_0")) + resultAnno should contain (anno("n_b_1")) + resultAnno should contain (anno("r_b_0")) + resultAnno should contain (anno("r_b_1")) + } + + + "Renaming" should "track dce" in { + val compiler = new VerilogCompiler + val input = + """circuit Top : + | module Top : + | input clk: Clock + | input pred: UInt<1> + | input in: {a: UInt<3>, b: UInt<3>[2]} + | output out: {a: UInt<3>, b: UInt<3>[2]} + | node n = in + | out <= n + |""".stripMargin + val annos = Seq( + anno("in.a"), anno("in.b[0]"), anno("in.b[1]"), + anno("out.a"), anno("out.b[0]"), anno("out.b[1]"), + anno("n.a"), anno("n.b[0]"), anno("n.b[1]") + ) + val result = compiler.compile(CircuitState(parse(input), ChirrtlForm, getAMap(annos)), Nil) + val resultAnno = result.annotations.get.annotations + resultAnno should not contain (anno("in.a")) + resultAnno should not contain (anno("in.b[0]")) + resultAnno should not contain (anno("in.b[1]")) + resultAnno should not contain (anno("out.a")) + resultAnno should not contain (anno("out.b[0]")) + resultAnno should not contain (anno("out.b[1]")) + resultAnno should not contain (anno("n.a")) + resultAnno should not contain (anno("n.b[0]")) + resultAnno should not contain (anno("n.b[1]")) + resultAnno should not contain (anno("n_a")) + resultAnno should not contain (anno("n_b_0")) + resultAnno should not contain (anno("n_b_1")) + resultAnno should contain (anno("in_a")) + resultAnno should contain (anno("in_b_0")) + resultAnno should contain (anno("in_b_1")) + resultAnno should contain (anno("out_a")) + resultAnno should contain (anno("out_b_0")) + resultAnno should contain (anno("out_b_1")) + } } diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index 0ae112f0..fd4374f0 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -8,9 +8,10 @@ import org.scalatest.prop._ import firrtl.Parser import firrtl.ir.Circuit import firrtl.passes._ +import firrtl._ class ChirrtlSpec extends FirrtlFlatSpec { - def passes = Seq( + def transforms = Seq( CheckChirrtl, CInferTypes, CInferMDir, @@ -44,8 +45,9 @@ class ChirrtlSpec extends FirrtlFlatSpec { | infer mport y = ram[UInt(4)], newClock | y <= UInt(5) """.stripMargin - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val circuit = Parser.parse(input.split("\n").toIterator) + transforms.foldLeft(CircuitState(circuit, UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } @@ -63,8 +65,9 @@ class ChirrtlSpec extends FirrtlFlatSpec { | y <= z """.stripMargin intercept[PassException] { - passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val circuit = Parser.parse(input.split("\n").toIterator) + transforms.foldLeft(CircuitState(circuit, UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } } } diff --git a/src/test/scala/firrtlTests/ExpandWhensSpec.scala b/src/test/scala/firrtlTests/ExpandWhensSpec.scala index a7824087..dcaf52e3 100644 --- a/src/test/scala/firrtlTests/ExpandWhensSpec.scala +++ b/src/test/scala/firrtlTests/ExpandWhensSpec.scala @@ -11,10 +11,12 @@ import firrtl.ir._ import firrtl.Parser.IgnoreInfo class ExpandWhensSpec extends FirrtlFlatSpec { - private def executeTest(input: String, check: String, passes: Seq[Pass], expected: Boolean) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + private def executeTest(input: String, check: String, transforms: Seq[Transform], expected: Boolean) = { + val circuit = Parser.parse(input.split("\n").toIterator) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } + val c = result.circuit val lines = c.serialize.split("\n") map normalized println(c.serialize) @@ -25,7 +27,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec { } } "Expand Whens" should "not emit INVALID" in { - val passes = Seq( + val transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, @@ -51,10 +53,10 @@ class ExpandWhensSpec extends FirrtlFlatSpec { | a is invalid | a.b <= UInt<64>("h04000000000000000")""".stripMargin val check = "INVALID" - executeTest(input, check, passes, false) + executeTest(input, check, transforms, false) } "Expand Whens" should "void unwritten memory fields" in { - val passes = Seq( + val transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, @@ -92,7 +94,7 @@ class ExpandWhensSpec extends FirrtlFlatSpec { | memory.w0.clk <= clk | """.stripMargin val check = "VOID" - executeTest(input, check, passes, true) + executeTest(input, check, transforms, true) } } diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 89461f12..b43df713 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -8,9 +8,10 @@ import org.scalatest.prop._ import firrtl.Parser import firrtl.ir.Circuit import firrtl.passes._ +import firrtl._ class LowerTypesSpec extends FirrtlFlatSpec { - private val passes = Seq( + private val transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, @@ -34,9 +35,11 @@ class LowerTypesSpec extends FirrtlFlatSpec { LowerTypes) private def executeTest(input: String, expected: Seq[String]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val circuit = Parser.parse(input.split("\n").toIterator) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } + val c = result.circuit val lines = c.serialize.split("\n") map normalized expected foreach { e => diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 0831bb31..8367f152 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -65,6 +65,7 @@ circuit Top : val aMap = AnnotationMap(Seq(ReplSeqMemAnnotation("-c:Top:-o:"+confLoc))) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(aMap))) // Check correctness of firrtl + println(res.annotations) parse(res.getEmittedCircuit.value) (new java.io.File(confLoc)).delete() } diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala index 14c0f652..27918cc5 100644 --- a/src/test/scala/firrtlTests/UniquifySpec.scala +++ b/src/test/scala/firrtlTests/UniquifySpec.scala @@ -8,10 +8,11 @@ import org.scalatest.prop._ import firrtl.Parser import firrtl.ir.Circuit import firrtl.passes._ +import firrtl._ class UniquifySpec extends FirrtlFlatSpec { - private val passes = Seq( + private val transforms = Seq( ToWorkingIR, CheckHighForm, ResolveKinds, @@ -20,9 +21,11 @@ class UniquifySpec extends FirrtlFlatSpec { ) private def executeTest(input: String, expected: Seq[String]) = { - val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { - (c: Circuit, p: Pass) => p.run(c) + val circuit = Parser.parse(input.split("\n").toIterator) + val result = transforms.foldLeft(CircuitState(circuit, UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) } + val c = result.circuit val lines = c.serialize.split("\n") map normalized expected foreach { e => diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala index 90926bc1..8c39dc1e 100644 --- a/src/test/scala/firrtlTests/ZeroWidthTests.scala +++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala @@ -11,7 +11,7 @@ import firrtl.Parser import firrtl.passes._ class ZeroWidthTests extends FirrtlFlatSpec { - val passes = Seq( + val transforms = Seq( ToWorkingIR, ResolveKinds, InferTypes, @@ -19,9 +19,10 @@ class ZeroWidthTests extends FirrtlFlatSpec { InferWidths, ZeroWidth) private def exec (input: String) = { - passes.foldLeft(parse(input)) { - (c: Circuit, p: Pass) => p.run(c) - }.serialize + val circuit = parse(input) + transforms.foldLeft(CircuitState(circuit, UnknownForm)) { + (c: CircuitState, p: Transform) => p.runTransform(c) + }.circuit.serialize } // ============================= "Zero width port" should " be deleted" in { @@ -105,6 +106,18 @@ class ZeroWidthTests extends FirrtlFlatSpec { | skip""".stripMargin (parse(exec(input)).serialize) should be (parse(check).serialize) } + "IsInvalid on <0>" should "be deleted" in { + val input = + """circuit Top : + | module Top : + | output y: UInt<0> + | y is invalid""".stripMargin + val check = + """circuit Top : + | module Top : + | skip""".stripMargin + (parse(exec(input)).serialize) should be (parse(check).serialize) + } "Expression in node with type <0>" should "be replaced by UInt<1>(0)" in { val input = """circuit Top : |
