aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorAdam Izraelevitz2017-05-10 11:23:18 -0700
committerGitHub2017-05-10 11:23:18 -0700
commit8b8eb4eac5b353d4d632065c78faf6a706d6aae8 (patch)
tree39e2d9344166b61b376df9d3cd15a4787bcd01f4 /src/main
parentaf222c1737fa72fce964190876346bdb7ff220cd (diff)
Update rename2 (#478)
* Added pass name to debug logger * Addresses #459. Rewords transform annotations API. Now, any annotation not propagated by a transform is considered deleted. A new DeletedAnnotation is added in place of it. * Added more stylized debugging style * WIP: make pass transform * WIP: All tests pass, need to pull master * Cleaned up PR * Added rename updates to all core transforms * Added more rename tests, and bugfixes * Renaming tracks non-leaf subfields E.g. given: wire x: {a: UInt<1>, b: UInt<1>[2]} Annotating x.b will eventually annotate x_b_0 and x_b_1 * Bugfix instance rename lowering broken * Address review comments * Remove check for seqTransform, UnknownForm too restrictive check
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/Compiler.scala50
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala2
-rw-r--r--src/main/scala/firrtl/Utils.scala2
-rw-r--r--src/main/scala/firrtl/passes/DeadCodeElimination.scala25
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala128
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala51
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala44
-rw-r--r--src/main/scala/firrtl/passes/ZeroWidth.scala52
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala2
9 files changed, 297 insertions, 59 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala
index 6e5cadcd..6c3911d6 100644
--- a/src/main/scala/firrtl/Compiler.scala
+++ b/src/main/scala/firrtl/Compiler.scala
@@ -2,9 +2,10 @@
package firrtl
-import logger.LazyLogging
+import logger._
import java.io.Writer
import annotations._
+import scala.collection.mutable
import firrtl.ir.Circuit
import passes.Pass
@@ -14,7 +15,47 @@ import Utils.throwInternalError
* RenameMap maps old names to modified names. Generated by transformations
* that modify names
*/
-case class RenameMap(map: Map[Named, Seq[Named]] = Map[Named, Seq[Named]]())
+object RenameMap {
+ def apply(map: Map[Named, Seq[Named]]) = {
+ val rm = new RenameMap
+ rm.addMap(map)
+ rm
+ }
+ def apply() = new RenameMap
+}
+class RenameMap {
+ val renameMap = new mutable.HashMap[Named, Seq[Named]]()
+ private var circuitName: String = ""
+ private var moduleName: String = ""
+ def setModule(s: String) =
+ moduleName = s
+ def setCircuit(s: String) =
+ circuitName = s
+ def rename(from: String, to: String): Unit = rename(from, Seq(to))
+ def rename(from: String, tos: Seq[String]): Unit = {
+ val fromName = ComponentName(from, ModuleName(moduleName, CircuitName(circuitName)))
+ val tosName = tos map { to =>
+ ComponentName(to, ModuleName(moduleName, CircuitName(circuitName)))
+ }
+ rename(fromName, tosName)
+ }
+ def rename(from: Named, to: Named): Unit = rename(from, Seq(to))
+ def rename(from: Named, tos: Seq[Named]): Unit = (from, tos) match {
+ case (x, Seq(y)) if x == y =>
+ case _ =>
+ renameMap(from) = renameMap.getOrElse(from, Seq.empty) ++ tos
+ }
+ def delete(names: Seq[String]): Unit = names.foreach(delete(_))
+ def delete(name: String): Unit =
+ delete(ComponentName(name, ModuleName(moduleName, CircuitName(circuitName))))
+ def delete(name: Named): Unit =
+ renameMap(name) = Seq.empty
+ def addMap(map: Map[Named, Seq[Named]]) =
+ renameMap ++= map
+ def serialize: String = renameMap.map { case (k, v) =>
+ k.serialize + "=>" + v.map(_.serialize).mkString(", ")
+ }.mkString("\n")
+}
/**
* Container of all annotations for a Firrtl compiler.
@@ -172,7 +213,6 @@ abstract class Transform extends LazyLogging {
}
logger.trace(s"Circuit:\n${result.circuit.serialize}")
logger.info(s"======== Finished Transform $name ========\n")
-
CircuitState(result.circuit, result.form, Some(AnnotationMap(remappedAnnotations)), None)
}
@@ -200,7 +240,7 @@ abstract class Transform extends LazyLogging {
}
// For each annotation, rename all annotations.
- val renames = renameOpt.getOrElse(RenameMap()).map
+ val renames = renameOpt.getOrElse(RenameMap()).renameMap
for {
anno <- newAnnotations.toSeq
newAnno <- anno.update(renames.getOrElse(anno.target, Seq(anno.target)))
@@ -217,8 +257,10 @@ trait SeqTransformBased {
/** For transformations that are simply a sequence of transforms */
abstract class SeqTransform extends Transform with SeqTransformBased {
def execute(state: CircuitState): CircuitState = {
+ /*
require(state.form <= inputForm,
s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
+ */
val ret = runTransforms(state)
CircuitState(ret.circuit, outputForm, ret.annotations, ret.renames)
}
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index b9042781..84f237a3 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -6,7 +6,6 @@ sealed abstract class CoreTransform extends SeqTransform
/** This transforms "CHIRRTL", the chisel3 IR, to "Firrtl". Note the resulting
* circuit has only IR nodes, not WIR.
- * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
*/
class ChirrtlToHighFirrtl extends CoreTransform {
def inputForm = ChirrtlForm
@@ -75,7 +74,6 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
/** Expands all aggregate types into many ground-typed components. Must
* accept a well-formed graph of only middle Firrtl features.
* Operates on working IR nodes.
- * TODO(izraelevitz): Create RenameMap from RemoveCHIRRTL
*/
class MiddleFirrtlToLowFirrtl extends CoreTransform {
def inputForm = MidForm
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 76156d2d..586fe1e7 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -541,7 +541,7 @@ object Utils extends LazyLogging {
}
/** Adds a root reference to some SubField/SubIndex chain */
- def mergeRef(root: WRef, body: Expression): Expression = body match {
+ def mergeRef(root: Expression, body: Expression): Expression = body match {
case e: WRef =>
WSubField(root, e.name, e.tpe, e.gender)
case e: WSubIndex =>
diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala
index 9f249f35..54ac76fe 100644
--- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala
@@ -9,8 +9,10 @@ import firrtl.Mappers._
import annotation.tailrec
-object DeadCodeElimination extends Pass {
- private def dceOnce(s: Statement): (Statement, Long) = {
+object DeadCodeElimination extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
+ private def dceOnce(renames: RenameMap)(s: Statement): (Statement, Long) = {
val referenced = collection.mutable.HashSet[String]()
var nEliminated = 0L
@@ -28,6 +30,7 @@ object DeadCodeElimination extends Pass {
if (referenced(name)) x
else {
nEliminated += 1
+ renames.delete(name)
EmptyStmt
}
@@ -43,16 +46,22 @@ object DeadCodeElimination extends Pass {
}
@tailrec
- private def dce(s: Statement): Statement = {
- val (res, n) = dceOnce(s)
- if (n > 0) dce(res) else res
+ private def dce(renames: RenameMap)(s: Statement): Statement = {
+ val (res, n) = dceOnce(renames)(s)
+ if (n > 0) dce(renames)(res) else res
}
- def run(c: Circuit): Circuit = {
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
val modulesx = c.modules.map {
case m: ExtModule => m
- case m: Module => Module(m.info, m.name, m.ports, dce(m.body))
+ case m: Module =>
+ renames.setModule(m.name)
+ Module(m.info, m.name, m.ports, dce(renames)(m.body))
}
- Circuit(c.info, modulesx, c.main)
+ val result = Circuit(c.info, modulesx, c.main)
+ CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index 5826f56e..b48ab338 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -2,9 +2,11 @@
package firrtl.passes
+import scala.collection.mutable
import firrtl._
import firrtl.ir._
import firrtl.Utils._
+import MemPortUtils.memType
import firrtl.Mappers._
/** Removes all aggregate types from a [[firrtl.ir.Circuit]]
@@ -20,7 +22,10 @@ import firrtl.Mappers._
* wire foo_b : UInt<16>
* }}}
*/
-object LowerTypes extends Pass {
+object LowerTypes extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
+
/** Delimiter used in lowering names */
val delim = "_"
/** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name
@@ -33,7 +38,49 @@ object LowerTypes extends Pass {
case e: WSubIndex => s"${loweredName(e.exp)}$delim${e.value}"
}
def loweredName(s: Seq[String]): String = s mkString delim
+ def renameExps(renames: RenameMap, n: String, t: Type, root: String): Seq[String] =
+ renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), root)
+ def renameExps(renames: RenameMap, n: String, t: Type): Seq[String] =
+ renameExps(renames, WRef(n, t, ExpKind, UNKNOWNGENDER), "")
+ def renameExps(renames: RenameMap, e: Expression, root: String): Seq[String] = e.tpe match {
+ case (_: GroundType) =>
+ val name = root + loweredName(e)
+ renames.rename(e.serialize, name)
+ Seq(name)
+ case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) =>
+ val subNames = renameExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), root)
+ renames.rename(e.serialize, subNames)
+ names ++ subNames
+ }
+ case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) =>
+ val subNames = renameExps(renames, WSubIndex(e, i, t.tpe,gender(e)), root)
+ renames.rename(e.serialize, subNames)
+ names ++ subNames
+ }
+ }
+ private def renameMemExps(renames: RenameMap, e: Expression, portAndField: Expression): Seq[String] = e.tpe match {
+ case (_: GroundType) =>
+ val (mem, tail) = splitRef(e)
+ val loRef = mergeRef(WRef(loweredName(e)), portAndField)
+ val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
+ renames.rename(hiRef.serialize, loRef.serialize)
+ Seq(loRef.serialize)
+ case (t: BundleType) => t.fields.foldLeft(Seq[String]()){(names, f) =>
+ val subNames = renameMemExps(renames, WSubField(e, f.name, f.tpe, times(gender(e), f.flip)), portAndField)
+ val (mem, tail) = splitRef(e)
+ val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
+ renames.rename(hiRef.serialize, subNames)
+ names ++ subNames
+ }
+ case (t: VectorType) => (0 until t.size).foldLeft(Seq[String]()){(names, i) =>
+ val subNames = renameMemExps(renames, WSubIndex(e, i, t.tpe,gender(e)), portAndField)
+ val (mem, tail) = splitRef(e)
+ val hiRef = mergeRef(mem, mergeRef(portAndField, tail))
+ renames.rename(hiRef.serialize, subNames)
+ names ++ subNames
+ }
+ }
private case class LowerTypesException(msg: String) extends FIRRTLException(msg)
private def error(msg: String)(info: Info, mname: String) =
throw LowerTypesException(s"$info: [module $mname] $msg")
@@ -111,34 +158,43 @@ object LowerTypes extends Pass {
case e: DoPrim => e map lowerTypesExp(memDataTypeMap, info, mname)
case e @ (_: UIntLiteral | _: SIntLiteral) => e
}
-
def lowerTypesStmt(memDataTypeMap: MemDataTypeMap,
- minfo: Info, mname: String)(s: Statement): Statement = {
+ minfo: Info, mname: String, renames: RenameMap)(s: Statement): Statement = {
val info = get_info(s) match {case NoInfo => minfo case x => x}
- s map lowerTypesStmt(memDataTypeMap, info, mname) match {
+ s map lowerTypesStmt(memDataTypeMap, info, mname, renames) match {
case s: DefWire => s.tpe match {
case _: GroundType => s
- case _ => Block(create_exps(s.name, s.tpe) map (
- e => DefWire(s.info, loweredName(e), e.tpe)))
+ case _ =>
+ val exps = create_exps(s.name, s.tpe)
+ val names = exps map loweredName
+ renameExps(renames, s.name, s.tpe)
+ Block((exps zip names) map { case (e, n) =>
+ DefWire(s.info, n, e.tpe)
+ })
}
case sx: DefRegister => sx.tpe match {
case _: GroundType => sx map lowerTypesExp(memDataTypeMap, info, mname)
case _ =>
val es = create_exps(sx.name, sx.tpe)
+ val names = es map loweredName
+ renameExps(renames, sx.name, sx.tpe)
val inits = create_exps(sx.init) map lowerTypesExp(memDataTypeMap, info, mname)
val clock = lowerTypesExp(memDataTypeMap, info, mname)(sx.clock)
val reset = lowerTypesExp(memDataTypeMap, info, mname)(sx.reset)
- Block(es zip inits map { case (e, i) =>
- DefRegister(sx.info, loweredName(e), e.tpe, clock, reset, i)
+ Block((es zip names) zip inits map { case ((e, n), i) =>
+ DefRegister(sx.info, n, e.tpe, clock, reset, i)
})
}
// Could instead just save the type of each Module as it gets processed
case sx: WDefInstance => sx.tpe match {
case t: BundleType =>
- val fieldsx = t.fields flatMap (f =>
- create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map (
+ val fieldsx = t.fields flatMap { f =>
+ renameExps(renames, f.name, sx.tpe, s"${sx.name}.")
+ create_exps(WRef(f.name, f.tpe, ExpKind, times(f.flip, MALE))) map { e =>
// Flip because inst genders are reversed from Module type
- e => Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)))
+ Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
+ }
+ }
WDefInstance(sx.info, sx.name, sx.module, BundleType(fieldsx))
case _ => error("WDefInstance type should be Bundle!")(info, mname)
}
@@ -146,8 +202,30 @@ object LowerTypes extends Pass {
memDataTypeMap(sx.name) = sx.dataType
sx.dataType match {
case _: GroundType => sx
- case _ => Block(create_exps(sx.name, sx.dataType) map (e =>
- sx copy (name = loweredName(e), dataType = e.tpe)))
+ case _ =>
+ // Rename ports
+ val seen: mutable.Set[String] = mutable.Set[String]()
+ create_exps(sx.name, memType(sx)) foreach { e =>
+ val (mem, port, field, tail) = splitMemRef(e)
+ if (!seen.contains(field.name)) {
+ seen += field.name
+ val d = WRef(mem.name, sx.dataType)
+ tail match {
+ case None =>
+ create_exps(mem.name, sx.dataType) foreach { x =>
+ renames.rename(e.serialize, s"${loweredName(x)}.${port.serialize}.${field.serialize}")
+ }
+ case Some(_) =>
+ renameMemExps(renames, d, mergeRef(port, field))
+ }
+ }
+ }
+ Block(create_exps(sx.name, sx.dataType) map {e =>
+ val newName = loweredName(e)
+ // Rename mems
+ renames.rename(sx.name, newName)
+ sx copy (name = newName, dataType = e.tpe)
+ })
}
// wire foo : { a , b }
// node x = foo
@@ -159,7 +237,10 @@ object LowerTypes extends Pass {
case sx: DefNode =>
val names = create_exps(sx.name, sx.value.tpe) map lowerTypesExp(memDataTypeMap, info, mname)
val exps = create_exps(sx.value) map lowerTypesExp(memDataTypeMap, info, mname)
- Block(names zip exps map { case (n, e) => DefNode(info, loweredName(n), e) })
+ renameExps(renames, sx.name, sx.value.tpe)
+ Block(names zip exps map { case (n, e) =>
+ DefNode(info, loweredName(n), e)
+ })
case sx: IsInvalid => kind(sx.expr) match {
case MemKind =>
Block(lowerTypesMemExp(memDataTypeMap, info, mname)(sx.expr) map (IsInvalid(info, _)))
@@ -176,21 +257,32 @@ object LowerTypes extends Pass {
}
}
- def lowerTypes(m: DefModule): DefModule = {
+ def lowerTypes(renames: RenameMap)(m: DefModule): DefModule = {
val memDataTypeMap = new MemDataTypeMap
+ renames.setModule(m.name)
// Lower Ports
val portsx = m.ports flatMap { p =>
val exps = create_exps(WRef(p.name, p.tpe, PortKind, to_gender(p.direction)))
- exps map (e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe))
+ val names = exps map loweredName
+ renameExps(renames, p.name, p.tpe)
+ (exps zip names) map { case (e, n) =>
+ Port(p.info, n, to_dir(gender(e)), e.tpe)
+ }
}
m match {
case m: ExtModule =>
m copy (ports = portsx)
case m: Module =>
- m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name)
+ m copy (ports = portsx) map lowerTypesStmt(memDataTypeMap, m.info, m.name, renames)
}
}
- def run(c: Circuit): Circuit = c copy (modules = c.modules map lowerTypes)
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+ val result = c copy (modules = c.modules map lowerTypes(renames))
+ CircuitState(result, outputForm, state.annotations, Some(renames))
+ }
}
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index b072dfa0..c841dc32 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -14,7 +14,9 @@ case class MPort(name: String, clk: Expression)
case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort])
case class DataRef(exp: Expression, male: String, female: String, mask: String, rdwrite: Boolean)
-object RemoveCHIRRTL extends Pass {
+object RemoveCHIRRTL extends Transform {
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
val ut = UnknownType
type MPortMap = collection.mutable.LinkedHashMap[String, MPorts]
type SeqMemSet = collection.mutable.HashSet[String]
@@ -22,6 +24,15 @@ object RemoveCHIRRTL extends Pass {
type DataRefMap = collection.mutable.LinkedHashMap[String, DataRef]
type AddrMap = collection.mutable.HashMap[String, Expression]
+ def create_all_exps(ex: Expression): Seq[Expression] = ex.tpe match {
+ case _: GroundType => Seq(ex)
+ case t: BundleType => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_all_exps(SubField(ex, f.name, f.tpe))) ++ Seq(ex)
+ case t: VectorType => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_all_exps(SubIndex(ex, i, t.tpe))) ++ Seq(ex)
+ case UnknownType => Seq(ex)
+ }
+
def create_exps(e: Expression): Seq[Expression] = e match {
case ex: Mux =>
val e1s = create_exps(ex.tval)
@@ -59,7 +70,7 @@ object RemoveCHIRRTL extends Pass {
}
def collect_refs(mports: MPortMap, smems: SeqMemSet, types: MPortTypeMap,
- refs: DataRefMap, raddrs: AddrMap)(s: Statement): Statement = s match {
+ refs: DataRefMap, raddrs: AddrMap, renames: RenameMap)(s: Statement): Statement = s match {
case sx: CDefMemory =>
types(sx.name) = sx.tpe
val taddr = UIntType(IntWidth(1 max ceilLog2(sx.size)))
@@ -104,11 +115,25 @@ object RemoveCHIRRTL extends Pass {
addrs += "addr"
clks += "clk"
ens += "en"
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata")
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.rdata", sx.tpe))
+ val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.wdata", sx.tpe))
+ ((es zip rs) zip ws) map {
+ case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
+ }
case MWrite =>
refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false)
addrs += "addr"
clks += "clk"
ens += "en"
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
+ (es zip ws) map {
+ case (e, w) => renames.rename(e.serialize, w.serialize)
+ }
case MRead =>
refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false)
addrs += "addr"
@@ -118,13 +143,19 @@ object RemoveCHIRRTL extends Pass {
raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut)
case _ => ens += "en"
}
+ renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
+ val es = create_all_exps(WRef(sx.name, sx.tpe))
+ val rs = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
+ (es zip rs) map {
+ case (e, r) => renames.rename(e.serialize, r.serialize)
+ }
case MInfer => // do nothing if it's not being used
}
Block(
(addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++
(clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++
(ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one))))
- case sx => sx map collect_refs(mports, smems, types, refs, raddrs)
+ case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
}
def get_mask(refs: DataRefMap)(e: Expression): Expression =
@@ -213,17 +244,23 @@ object RemoveCHIRRTL extends Pass {
}
}
- def remove_chirrtl_m(m: DefModule): DefModule = {
+ def remove_chirrtl_m(renames: RenameMap)(m: DefModule): DefModule = {
val mports = new MPortMap
val smems = new SeqMemSet
val types = new MPortTypeMap
val refs = new DataRefMap
val raddrs = new AddrMap
+ renames.setModule(m.name)
(m map collect_smems_and_mports(mports, smems)
- map collect_refs(mports, smems, types, refs, raddrs)
+ map collect_refs(mports, smems, types, refs, raddrs, renames)
map remove_chirrtl_s(refs, raddrs))
}
- def run(c: Circuit): Circuit =
- c copy (modules = c.modules map remove_chirrtl_m)
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+ val result = c copy (modules = c.modules map remove_chirrtl_m(renames))
+ CircuitState(result, outputForm, state.annotations, Some(renames))
+ }
}
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index deddb93e..61bd68d0 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -31,7 +31,9 @@ import MemPortUtils.memType
* there WOULD be collisions in references a[0] and a_0 so we still have
* to rename a
*/
-object Uniquify extends Pass {
+object Uniquify extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
private case class UniquifyException(msg: String) extends FIRRTLException(msg)
private def error(msg: String)(implicit sinfo: Info, mname: String) =
throw new UniquifyException(s"$sinfo: [module $mname] $msg")
@@ -224,7 +226,10 @@ object Uniquify extends Pass {
}
// Everything wrapped in run so that it's thread safe
- def run(c: Circuit): Circuit = {
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
// Debug state
implicit var mname: String = ""
implicit var sinfo: Info = NoInfo
@@ -232,7 +237,8 @@ object Uniquify extends Pass {
val portNameMap = collection.mutable.HashMap[String, Map[String, NameMapNode]]()
val portTypeMap = collection.mutable.HashMap[String, Type]()
- def uniquifyModule(m: DefModule): DefModule = {
+ def uniquifyModule(renames: RenameMap)(m: DefModule): DefModule = {
+ renames.setModule(m.name)
val namespace = collection.mutable.HashSet[String]()
val nameMap = collection.mutable.HashMap[String, NameMapNode]()
@@ -251,7 +257,11 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- DefWire(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts))
+ val newType = uniquifyNamesType(sx.tpe, node.elts)
+ (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ DefWire(sx.info, node.name, newType)
} else {
sx
}
@@ -259,8 +269,11 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- DefRegister(sx.info, node.name, uniquifyNamesType(sx.tpe, node.elts),
- sx.clock, sx.reset, sx.init)
+ val newType = uniquifyNamesType(sx.tpe, node.elts)
+ (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ DefRegister(sx.info, node.name, newType, sx.clock, sx.reset, sx.init)
} else {
sx
}
@@ -268,7 +281,11 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
- WDefInstance(sx.info, node.name, sx.module, sx.tpe)
+ val newType = portTypeMap(m.name)
+ (Utils.create_exps(sx.name, sx.tpe) zip Utils.create_exps(node.name, newType)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
+ WDefInstance(sx.info, node.name, sx.module, newType)
} else {
sx
}
@@ -280,6 +297,9 @@ object Uniquify extends Pass {
val mem = sx.copy(name = node.name, dataType = dataType)
// Create new mapping to handle references to memory data fields
val uniqueMemMap = createNameMapping(memType(sx), memType(mem))
+ (Utils.create_exps(sx.name, memType(sx)) zip Utils.create_exps(node.name, memType(mem))) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
nameMap(sx.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap)
mem
} else {
@@ -289,6 +309,9 @@ object Uniquify extends Pass {
sinfo = sx.info
if (nameMap.contains(sx.name)) {
val node = nameMap(sx.name)
+ (Utils.create_exps(sx.name, s.asInstanceOf[DefNode].value.tpe) zip Utils.create_exps(node.name, sx.value.tpe)) foreach {
+ case (from, to) => renames.rename(from.serialize, to.serialize)
+ }
DefNode(sx.info, node.name, sx.value)
} else {
sx
@@ -320,7 +343,8 @@ object Uniquify extends Pass {
}
}
- def uniquifyPorts(m: DefModule): DefModule = {
+ def uniquifyPorts(renames: RenameMap)(m: DefModule): DefModule = {
+ renames.setModule(m.name)
def uniquifyPorts(ports: Seq[Port]): Seq[Port] = {
val portsType = BundleType(ports map {
case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
@@ -331,6 +355,7 @@ object Uniquify extends Pass {
portTypeMap += (m.name -> uniquePortsType)
ports zip uniquePortsType.fields map { case (p, f) =>
+ renames.rename(p.name, f.name)
Port(p.info, f.name, p.direction, f.tpe)
}
}
@@ -344,7 +369,8 @@ object Uniquify extends Pass {
}
sinfo = c.info
- Circuit(c.info, c.modules map uniquifyPorts map uniquifyModule, c.main)
+ val result = Circuit(c.info, c.modules map uniquifyPorts(renames) map uniquifyModule(renames), c.main)
+ CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala
index a2ec9935..2472a0e5 100644
--- a/src/main/scala/firrtl/passes/ZeroWidth.scala
+++ b/src/main/scala/firrtl/passes/ZeroWidth.scala
@@ -10,8 +10,24 @@ import firrtl.Mappers._
import firrtl.Utils.throwInternalError
-object ZeroWidth extends Pass {
+object ZeroWidth extends Transform {
+ def inputForm = UnknownForm
+ def outputForm = UnknownForm
private val ZERO = BigInt(0)
+ private def getRemoved(x: IsDeclaration): Seq[String] = {
+ var removedNames: Seq[String] = Seq.empty
+ def onType(name: String)(t: Type): Type = {
+ removedNames = Utils.create_exps(name, t) map {e => (e, e.tpe)} collect {
+ case (e, GroundType(IntWidth(ZERO))) => e.serialize
+ }
+ t
+ }
+ x match {
+ case s: Statement => s map onType(s.name)
+ case Port(_, name, _, t) => onType(name)(t)
+ }
+ removedNames
+ }
private def removeZero(t: Type): Option[Type] = t match {
case GroundType(IntWidth(ZERO)) => None
case BundleType(fields) =>
@@ -34,35 +50,53 @@ object ZeroWidth extends Pass {
def replaceType(x: Type): Type = t
(e map replaceType) map onExp
}
- private def onStmt(s: Statement): Statement = s match {
+ private def onStmt(renames: RenameMap)(s: Statement): Statement = s match {
case (_: DefWire| _: DefRegister| _: DefMemory) =>
+ // List all removed expression names, and delete them from renames
+ renames.delete(getRemoved(s.asInstanceOf[IsDeclaration]))
+ // Create new types without zero-width wires
var removed = false
def applyRemoveZero(t: Type): Type = removeZero(t) match {
case None => removed = true; t
case Some(tx) => tx
}
val sxx = (s map onExp) map applyRemoveZero
+ // Return new declaration
if(removed) EmptyStmt else sxx
case Connect(info, loc, exp) => removeZero(loc.tpe) match {
case None => EmptyStmt
case Some(t) => Connect(info, loc, onExp(exp))
}
+ case IsInvalid(info, exp) => removeZero(exp.tpe) match {
+ case None => EmptyStmt
+ case Some(t) => IsInvalid(info, onExp(exp))
+ }
case DefNode(info, name, value) => removeZero(value.tpe) match {
case None => EmptyStmt
case Some(t) => DefNode(info, name, onExp(value))
}
- case sx => sx map onStmt
+ case sx => sx map onStmt(renames)
}
- private def onModule(m: DefModule): DefModule = {
- val ports = m.ports map (p => (p, removeZero(p.tpe))) collect {
- case (Port(info, name, dir, _), Some(t)) => Port(info, name, dir, t)
+ private def onModule(renames: RenameMap)(m: DefModule): DefModule = {
+ renames.setModule(m.name)
+ // For each port, record deleted subcomponents
+ m.ports.foreach{p => renames.delete(getRemoved(p))}
+ val ports = m.ports map (p => (p, removeZero(p.tpe))) flatMap {
+ case (Port(info, name, dir, _), Some(t)) => Seq(Port(info, name, dir, t))
+ case (Port(_, name, _, _), None) =>
+ renames.delete(name)
+ Nil
}
m match {
case ext: ExtModule => ext.copy(ports = ports)
- case in: Module => in.copy(ports = ports, body = onStmt(in.body))
+ case in: Module => in.copy(ports = ports, body = onStmt(renames)(in.body))
}
}
- def run(c: Circuit): Circuit = {
- InferTypes.run(c.copy(modules = c.modules map onModule))
+ def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ renames.setCircuit(c.main)
+ val result = InferTypes.run(c.copy(modules = c.modules map onModule(renames)))
+ CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index caaf430b..8cbf9da7 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -131,7 +131,7 @@ class ReplSeqMem extends Transform {
new SimpleMidTransform(RemoveEmpty),
new SimpleMidTransform(CheckInitialization),
new SimpleMidTransform(InferTypes),
- new SimpleMidTransform(Uniquify),
+ Uniquify,
new SimpleMidTransform(ResolveKinds),
new SimpleMidTransform(ResolveGenders))