diff options
| author | Adam Izraelevitz | 2017-05-10 11:23:18 -0700 |
|---|---|---|
| committer | GitHub | 2017-05-10 11:23:18 -0700 |
| commit | 8b8eb4eac5b353d4d632065c78faf6a706d6aae8 (patch) | |
| tree | 39e2d9344166b61b376df9d3cd15a4787bcd01f4 /src/main | |
| parent | af222c1737fa72fce964190876346bdb7ff220cd (diff) | |
Update rename2 (#478)
* Added pass name to debug logger
* Addresses #459. Rewords transform annotations API.
Now, any annotation not propagated by a transform is considered deleted.
A new DeletedAnnotation is added in place of it.
* Added more stylized debugging style
* WIP: make pass transform
* WIP: All tests pass, need to pull master
* Cleaned up PR
* Added rename updates to all core transforms
* Added more rename tests, and bugfixes
* Renaming tracks non-leaf subfields
E.g. given:
wire x: {a: UInt<1>, b: UInt<1>[2]}
Annotating x.b will eventually annotate x_b_0 and x_b_1
* Bugfix instance rename lowering broken
* Address review comments
* Remove check for seqTransform, UnknownForm too restrictive check
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/Compiler.scala | 50 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/DeadCodeElimination.scala | 25 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 128 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 51 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Uniquify.scala | 44 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ZeroWidth.scala | 52 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala | 2 |
9 files changed, 297 insertions, 59 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 6e5cadcd..6c3911d6 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -2,9 +2,10 @@ package firrtl -import logger.LazyLogging +import logger._ import java.io.Writer import annotations._ +import scala.collection.mutable import firrtl.ir.Circuit import passes.Pass @@ -14,7 +15,47 @@ import Utils.throwInternalError * RenameMap maps old names to modified names. Generated by transformations * that modify names */ -case class RenameMap(map: Map[Named, Seq[Named]] = Map[Named, Seq[Named]]()) +object RenameMap { + def apply(map: Map[Named, Seq[Named]]) = { + val rm = new RenameMap + rm.addMap(map) + rm + } + def apply() = new RenameMap +} +class RenameMap { + val renameMap = new mutable.HashMap[Named, Seq[Named]]() + private var circuitName: String = "" + private var moduleName: String = "" + def setModule(s: String) = + moduleName = s + def setCircuit(s: String) = + circuitName = s + def rename(from: String, to: String): Unit = rename(from, Seq(to)) + def rename(from: String, tos: Seq[String]): Unit = { + val fromName = ComponentName(from, ModuleName(moduleName, CircuitName(circuitName))) + val tosName = tos map { to => + ComponentName(to, ModuleName(moduleName, CircuitName(circuitName))) + } + rename(fromName, tosName) + } + def rename(from: Named, to: Named): Unit = rename(from, Seq(to)) + def rename(from: Named, tos: Seq[Named]): Unit = (from, tos) match { + case (x, Seq(y)) if x == y => + case _ => + renameMap(from) = renameMap.getOrElse(from, Seq.empty) ++ tos + } + def delete(names: Seq[String]): Unit = names.foreach(delete(_)) + def delete(name: String): Unit = + delete(ComponentName(name, ModuleName(moduleName, CircuitName(circuitName)))) + def delete(name: Named): Unit = + renameMap(name) = Seq.empty + def addMap(map: Map[Named, Seq[Named]]) = + renameMap ++= map + def serialize: String = renameMap.map { case (k, v) => + k.serialize + "=>" + v.map(_.serialize).mkString(", ") + }.mkString("\n") +} /** * Container of all annotations for a Firrtl compiler. @@ -172,7 +213,6 @@ abstract class Transform extends LazyLogging { } logger.trace(s"Circuit:\n${result.circuit.serialize}") logger.info(s"======== Finished Transform $name ========\n") - CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations)), None) } @@ -200,7 +240,7 @@ abstract class Transform extends LazyLogging { } // For each annotation, rename all annotations. - val renames = renameOpt.getOrElse(RenameMap()).map + val renames = renameOpt.getOrElse(RenameMap()).renameMap for { anno <- newAnnotations.toSeq newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target))) @@ -217,8 +257,10 @@ trait SeqTransformBased { /** For transformations that are simply a sequence of transforms */ abstract class SeqTransform extends Transform with SeqTransformBased { def execute(state: CircuitState): CircuitState = { + /* require(state.form <= inputForm, s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + */ val ret = runTransforms(state) CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames) } diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index b9042781..84f237a3 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -6,7 +6,6 @@ sealed abstract class CoreTransform extends SeqTransform /** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting * circuit has only IR nodes, not WIR. - * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL */ class ChirrtlToHighFirrtl extends CoreTransform { def inputForm = ChirrtlForm @@ -75,7 +74,6 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { /** Expands all aggregate types into many ground-typed components. Must * accept a well-formed graph of only middle Firrtl features. * Operates on working IR nodes. - * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL */ class MiddleFirrtlToLowFirrtl extends CoreTransform { def inputForm = MidForm diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 76156d2d..586fe1e7 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -541,7 +541,7 @@ object Utils extends LazyLogging { } /** Adds a root reference to some SubField/SubIndex chain */ - def mergeRef(root: WRef, body: Expression): Expression = body match { + def mergeRef(root: Expression, body: Expression): Expression = body match { case e: WRef => WSubField(root, e.name, e.tpe, e.gender) case e: WSubIndex => diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala index 9f249f35..54ac76fe 100644 --- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala @@ -9,8 +9,10 @@ import firrtl.Mappers._ import annotation.tailrec -object DeadCodeElimination extends Pass { - private def dceOnce(s: Statement): (Statement, Long) = { +object DeadCodeElimination extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm + private def dceOnce(renames: RenameMap)(s: Statement): (Statement, Long) = { val referenced = collection.mutable.HashSet[String]() var nEliminated = 0L @@ -28,6 +30,7 @@ object DeadCodeElimination extends Pass { if (referenced(name)) x else { nEliminated += 1 + renames.delete(name) EmptyStmt } @@ -43,16 +46,22 @@ object DeadCodeElimination extends Pass { } @tailrec - private def dce(s: Statement): Statement = { - val (res, n) = dceOnce(s) - if (n > 0) dce(res) else res + private def dce(renames: RenameMap)(s: Statement): Statement = { + val (res, n) = dceOnce(renames)(s) + if (n > 0) dce(renames)(res) else res } - def run(c: Circuit): Circuit = { + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) val modulesx = c.modules.map { case m: ExtModule => m - case m: Module => Module(m.info, m.name, m.ports, dce(m.body)) + case m: Module => + renames.setModule(m.name) + Module(m.info, m.name, m.ports, dce(renames)(m.body)) } - Circuit(c.info, modulesx, c.main) + val result = Circuit(c.info, modulesx, c.main) + CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 5826f56e..b48ab338 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -2,9 +2,11 @@ package firrtl.passes +import scala.collection.mutable import firrtl._ import firrtl.ir._ import firrtl.Utils._ +import MemPortUtils.memType import firrtl.Mappers._ /** Removes all aggregate types from a [[firrtl.ir.Circuit]] @@ -20,7 +22,10 @@ import firrtl.Mappers._ * wire foo_b : UInt<16> * }}} */ -object LowerTypes extends Pass { +object LowerTypes extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm + /** Delimiter used in lowering names */ val delim = "_" /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name @@ -33,7 +38,49 @@ object LowerTypes extends Pass { case e: WSubIndex => s"${loweredName(e.exp)}$delim${e.value}" } def loweredName(s: Seq[String]): String = s mkString delim + def renameExps(renames: RenameMap, n: String, t: Type, root: String): Seq[String] = + renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), root) + def renameExps(renames: RenameMap, n: String, t: Type): Seq[String] = + renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), "") + def renameExps(renames: RenameMap, e: Expression, root: String): Seq[String] = e.tpe match { + case (_: GroundType) => + val name = root + loweredName(e) + renames.rename(e.serialize, name) + Seq(name) + case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) => + val subNames = renameExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), root) + renames.rename(e.serialize, subNames) + names ++ subNames + } + case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) => + val subNames = renameExps(renames, WSubIndex(e, i, t.tpe,gender(e)), root) + renames.rename(e.serialize, subNames) + names ++ subNames + } + } + private def renameMemExps(renames: RenameMap, e: Expression, portAndField: Expression): Seq[String] = e.tpe match { + case (_: GroundType) => + val (mem, tail) = splitRef(e) + val loRef = mergeRef(WRef(loweredName(e)), portAndField) + val hiRef = mergeRef(mem, mergeRef(portAndField, tail)) + renames.rename(hiRef.serialize, loRef.serialize) + Seq(loRef.serialize) + case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) => + val subNames = renameMemExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), portAndField) + val (mem, tail) = splitRef(e) + val hiRef = mergeRef(mem, mergeRef(portAndField, tail)) + renames.rename(hiRef.serialize, subNames) + names ++ subNames + } + case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) => + val subNames = renameMemExps(renames, WSubIndex(e, i, t.tpe,gender(e)), portAndField) + val (mem, tail) = splitRef(e) + val hiRef = mergeRef(mem, mergeRef(portAndField, tail)) + renames.rename(hiRef.serialize, subNames) + names ++ subNames + } + } private case class LowerTypesException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(info: Info, mname: String) = throw LowerTypesException(s"$info: [module $mname] $msg") @@ -111,34 +158,43 @@ object LowerTypes extends Pass { case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname) case e @ (_: UIntLiteral | _: SIntLiteral) => e } - def lowerTypesStmt(memDataTypeMap: MemDataTypeMap, - minfo: Info, mname: String)(s: Statement): Statement = { + minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = { val info = get_info(s) match {case NoInfo => minfo case x => x} - s map lowerTypesStmt(memDataTypeMap, info, mname) match { + s map lowerTypesStmt(memDataTypeMap, info, mname, renames) match { case s: DefWire => s.tpe match { case _: GroundType => s - case _ => Block(create_exps(s.name, s.tpe) map ( - e => DefWire(s.info, loweredName(e), e.tpe))) + case _ => + val exps = create_exps(s.name, s.tpe) + val names = exps map loweredName + renameExps(renames, s.name, s.tpe) + Block((exps zip names) map { case (e, n) => + DefWire(s.info, n, e.tpe) + }) } case sx: DefRegister => sx.tpe match { case _: GroundType => sx map lowerTypesExp(memDataTypeMap, info, mname) case _ => val es = create_exps(sx.name, sx.tpe) + val names = es map loweredName + renameExps(renames, sx.name, sx.tpe) val inits = create_exps(sx.init) map lowerTypesExp(memDataTypeMap, info, mname) val clock = lowerTypesExp(memDataTypeMap, info, mname)(sx.clock) val reset = lowerTypesExp(memDataTypeMap, info, mname)(sx.reset) - Block(es zip inits map { case (e, i) => - DefRegister(sx.info, loweredName(e), e.tpe, clock, reset, i) + Block((es zip names) zip inits map { case ((e, n), i) => + DefRegister(sx.info, n, e.tpe, clock, reset, i) }) } // Could instead just save the type of each Module as it gets processed case sx: WDefInstance => sx.tpe match { case t: BundleType => - val fieldsx = t.fields flatMap (f => - create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map ( + val fieldsx = t.fields flatMap { f => + renameExps(renames, f.name, sx.tpe, s"${sx.name}.") + create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map { e => // Flip because inst genders are reversed from Module type - e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe))) + Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) + } + } WDefInstance(sx.info, sx.name, sx.module, BundleType(fieldsx)) case _ => error("WDefInstance type should be Bundle!")(info, mname) } @@ -146,8 +202,30 @@ object LowerTypes extends Pass { memDataTypeMap(sx.name) = sx.dataType sx.dataType match { case _: GroundType => sx - case _ => Block(create_exps(sx.name, sx.dataType) map (e => - sx copy (name = loweredName(e), dataType = e.tpe))) + case _ => + // Rename ports + val seen: mutable.Set[String] = mutable.Set[String]() + create_exps(sx.name, memType(sx)) foreach { e => + val (mem, port, field, tail) = splitMemRef(e) + if (!seen.contains(field.name)) { + seen += field.name + val d = WRef(mem.name, sx.dataType) + tail match { + case None => + create_exps(mem.name, sx.dataType) foreach { x => + renames.rename(e.serialize, s"${loweredName(x)}.${port.serialize}.${field.serialize}") + } + case Some(_) => + renameMemExps(renames, d, mergeRef(port, field)) + } + } + } + Block(create_exps(sx.name, sx.dataType) map {e => + val newName = loweredName(e) + // Rename mems + renames.rename(sx.name, newName) + sx copy (name = newName, dataType = e.tpe) + }) } // wire foo : { a , b } // node x = foo @@ -159,7 +237,10 @@ object LowerTypes extends Pass { case sx: DefNode => val names = create_exps(sx.name, sx.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname) val exps = create_exps(sx.value) map lowerTypesExp(memDataTypeMap, info, mname) - Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) }) + renameExps(renames, sx.name, sx.value.tpe) + Block(names zip exps map { case (n, e) => + DefNode(info, loweredName(n), e) + }) case sx: IsInvalid => kind(sx.expr) match { case MemKind => Block(lowerTypesMemExp(memDataTypeMap, info, mname)(sx.expr) map (IsInvalid(info, _))) @@ -176,21 +257,32 @@ object LowerTypes extends Pass { } } - def lowerTypes(m: DefModule): DefModule = { + def lowerTypes(renames: RenameMap)(m: DefModule): DefModule = { val memDataTypeMap = new MemDataTypeMap + renames.setModule(m.name) // Lower Ports val portsx = m.ports flatMap { p => val exps = create_exps(WRef(p.name, p.tpe, PortKind, to_gender(p.direction))) - exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe)) + val names = exps map loweredName + renameExps(renames, p.name, p.tpe) + (exps zip names) map { case (e, n) => + Port(p.info, n, to_dir(gender(e)), e.tpe) + } } m match { case m: ExtModule => m copy (ports = portsx) case m: Module => - m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name) + m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name, renames) } } - def run(c: Circuit): Circuit = c copy (modules = c.modules map lowerTypes) + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) + val result = c copy (modules = c.modules map lowerTypes(renames)) + CircuitState(result, outputForm, state.annotations, Some(renames)) + } } diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index b072dfa0..c841dc32 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -14,7 +14,9 @@ case class MPort(name: String, clk: Expression) case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort]) case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean) -object RemoveCHIRRTL extends Pass { +object RemoveCHIRRTL extends Transform { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm val ut = UnknownType type MPortMap = collection.mutable.LinkedHashMap[String, MPorts] type SeqMemSet = collection.mutable.HashSet[String] @@ -22,6 +24,15 @@ object RemoveCHIRRTL extends Pass { type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef] type AddrMap = collection.mutable.HashMap[String, Expression] + def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match { + case _: GroundType => Seq(ex) + case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex) + case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex) + case UnknownType => Seq(ex) + } + def create_exps(e: Expression): Seq[Expression] = e match { case ex: Mux => val e1s = create_exps(ex.tval) @@ -59,7 +70,7 @@ object RemoveCHIRRTL extends Pass { } def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap, - refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match { + refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match { case sx: CDefMemory => types(sx.name) = sx.tpe val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size))) @@ -104,11 +115,25 @@ object RemoveCHIRRTL extends Pass { addrs += "addr" clks += "clk" ens += "en" + renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata") + renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata") + val es = create_all_exps(WRef(sx.name, sx.tpe)) + val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe)) + val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe)) + ((es zip rs) zip ws) map { + case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize)) + } case MWrite => refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false) addrs += "addr" clks += "clk" ens += "en" + renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") + val es = create_all_exps(WRef(sx.name, sx.tpe)) + val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) + (es zip ws) map { + case (e, w) => renames.rename(e.serialize, w.serialize) + } case MRead => refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false) addrs += "addr" @@ -118,13 +143,19 @@ object RemoveCHIRRTL extends Pass { raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut) case _ => ens += "en" } + renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") + val es = create_all_exps(WRef(sx.name, sx.tpe)) + val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) + (es zip rs) map { + case (e, r) => renames.rename(e.serialize, r.serialize) + } case MInfer => // do nothing if it's not being used } Block( (addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++ (clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++ (ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one)))) - case sx => sx map collect_refs(mports, smems, types, refs, raddrs) + case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames) } def get_mask(refs: DataRefMap)(e: Expression): Expression = @@ -213,17 +244,23 @@ object RemoveCHIRRTL extends Pass { } } - def remove_chirrtl_m(m: DefModule): DefModule = { + def remove_chirrtl_m(renames: RenameMap)(m: DefModule): DefModule = { val mports = new MPortMap val smems = new SeqMemSet val types = new MPortTypeMap val refs = new DataRefMap val raddrs = new AddrMap + renames.setModule(m.name) (m map collect_smems_and_mports(mports, smems) - map collect_refs(mports, smems, types, refs, raddrs) + map collect_refs(mports, smems, types, refs, raddrs, renames) map remove_chirrtl_s(refs, raddrs)) } - def run(c: Circuit): Circuit = - c copy (modules = c.modules map remove_chirrtl_m) + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) + val result = c copy (modules = c.modules map remove_chirrtl_m(renames)) + CircuitState(result, outputForm, state.annotations, Some(renames)) + } } diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index deddb93e..61bd68d0 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -31,7 +31,9 @@ import MemPortUtils.memType * there WOULD be collisions in references a[0] and a_0 so we still have * to rename a */ -object Uniquify extends Pass { +object Uniquify extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm private case class UniquifyException(msg: String) extends FIRRTLException(msg) private def error(msg: String)(implicit sinfo: Info, mname: String) = throw new UniquifyException(s"$sinfo: [module $mname] $msg") @@ -224,7 +226,10 @@ object Uniquify extends Pass { } // Everything wrapped in run so that it's thread safe - def run(c: Circuit): Circuit = { + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) // Debug state implicit var mname: String = "" implicit var sinfo: Info = NoInfo @@ -232,7 +237,8 @@ object Uniquify extends Pass { val portNameMap = collection.mutable.HashMap[String, Map[String, NameMapNode]]() val portTypeMap = collection.mutable.HashMap[String, Type]() - def uniquifyModule(m: DefModule): DefModule = { + def uniquifyModule(renames: RenameMap)(m: DefModule): DefModule = { + renames.setModule(m.name) val namespace = collection.mutable.HashSet[String]() val nameMap = collection.mutable.HashMap[String, NameMapNode]() @@ -251,7 +257,11 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - DefWire(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts)) + val newType = uniquifyNamesType(sx.tpe, node.elts) + (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + DefWire(sx.info, node.name, newType) } else { sx } @@ -259,8 +269,11 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - DefRegister(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts), - sx.clock, sx.reset, sx.init) + val newType = uniquifyNamesType(sx.tpe, node.elts) + (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init) } else { sx } @@ -268,7 +281,11 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) - WDefInstance(sx.info, node.name, sx.module, sx.tpe) + val newType = portTypeMap(m.name) + (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } + WDefInstance(sx.info, node.name, sx.module, newType) } else { sx } @@ -280,6 +297,9 @@ object Uniquify extends Pass { val mem = sx.copy(name = node.name, dataType = dataType) // Create new mapping to handle references to memory data fields val uniqueMemMap = createNameMapping(memType(sx), memType(mem)) + (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap) mem } else { @@ -289,6 +309,9 @@ object Uniquify extends Pass { sinfo = sx.info if (nameMap.contains(sx.name)) { val node = nameMap(sx.name) + (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach { + case (from, to) => renames.rename(from.serialize, to.serialize) + } DefNode(sx.info, node.name, sx.value) } else { sx @@ -320,7 +343,8 @@ object Uniquify extends Pass { } } - def uniquifyPorts(m: DefModule): DefModule = { + def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = { + renames.setModule(m.name) def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { val portsType = BundleType(ports map { case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) @@ -331,6 +355,7 @@ object Uniquify extends Pass { portTypeMap += (m.name -> uniquePortsType) ports zip uniquePortsType.fields map { case (p, f) => + renames.rename(p.name, f.name) Port(p.info, f.name, p.direction, f.tpe) } } @@ -344,7 +369,8 @@ object Uniquify extends Pass { } sinfo = c.info - Circuit(c.info, c.modules map uniquifyPorts map uniquifyModule, c.main) + val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main) + CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala index a2ec9935..2472a0e5 100644 --- a/src/main/scala/firrtl/passes/ZeroWidth.scala +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -10,8 +10,24 @@ import firrtl.Mappers._ import firrtl.Utils.throwInternalError -object ZeroWidth extends Pass { +object ZeroWidth extends Transform { + def inputForm = UnknownForm + def outputForm = UnknownForm private val ZERO = BigInt(0) + private def getRemoved(x: IsDeclaration): Seq[String] = { + var removedNames: Seq[String] = Seq.empty + def onType(name: String)(t: Type): Type = { + removedNames = Utils.create_exps(name, t) map {e => (e, e.tpe)} collect { + case (e, GroundType(IntWidth(ZERO))) => e.serialize + } + t + } + x match { + case s: Statement => s map onType(s.name) + case Port(_, name, _, t) => onType(name)(t) + } + removedNames + } private def removeZero(t: Type): Option[Type] = t match { case GroundType(IntWidth(ZERO)) => None case BundleType(fields) => @@ -34,35 +50,53 @@ object ZeroWidth extends Pass { def replaceType(x: Type): Type = t (e map replaceType) map onExp } - private def onStmt(s: Statement): Statement = s match { + private def onStmt(renames: RenameMap)(s: Statement): Statement = s match { case (_: DefWire| _: DefRegister| _: DefMemory) => + // List all removed expression names, and delete them from renames + renames.delete(getRemoved(s.asInstanceOf[IsDeclaration])) + // Create new types without zero-width wires var removed = false def applyRemoveZero(t: Type): Type = removeZero(t) match { case None => removed = true; t case Some(tx) => tx } val sxx = (s map onExp) map applyRemoveZero + // Return new declaration if(removed) EmptyStmt else sxx case Connect(info, loc, exp) => removeZero(loc.tpe) match { case None => EmptyStmt case Some(t) => Connect(info, loc, onExp(exp)) } + case IsInvalid(info, exp) => removeZero(exp.tpe) match { + case None => EmptyStmt + case Some(t) => IsInvalid(info, onExp(exp)) + } case DefNode(info, name, value) => removeZero(value.tpe) match { case None => EmptyStmt case Some(t) => DefNode(info, name, onExp(value)) } - case sx => sx map onStmt + case sx => sx map onStmt(renames) } - private def onModule(m: DefModule): DefModule = { - val ports = m.ports map (p => (p, removeZero(p.tpe))) collect { - case (Port(info, name, dir, _), Some(t)) => Port(info, name, dir, t) + private def onModule(renames: RenameMap)(m: DefModule): DefModule = { + renames.setModule(m.name) + // For each port, record deleted subcomponents + m.ports.foreach{p => renames.delete(getRemoved(p))} + val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap { + case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t)) + case (Port(_, name, _, _), None) => + renames.delete(name) + Nil } m match { case ext: ExtModule => ext.copy(ports = ports) - case in: Module => in.copy(ports = ports, body = onStmt(in.body)) + case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body)) } } - def run(c: Circuit): Circuit = { - InferTypes.run(c.copy(modules = c.modules map onModule)) + def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + renames.setCircuit(c.main) + val result = InferTypes.run(c.copy(modules = c.modules map onModule(renames))) + CircuitState(result, outputForm, state.annotations, Some(renames)) } } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index caaf430b..8cbf9da7 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -131,7 +131,7 @@ class ReplSeqMem extends Transform { new SimpleMidTransform(RemoveEmpty), new SimpleMidTransform(CheckInitialization), new SimpleMidTransform(InferTypes), - new SimpleMidTransform(Uniquify), + Uniquify, new SimpleMidTransform(ResolveKinds), new SimpleMidTransform(ResolveGenders)) |
