aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorAlbert Magyar2019-07-19 12:29:08 -0700
committermergify[bot]2019-07-19 19:29:08 +0000
commit71d52cecde697d0734d55694e2344c2fb7e55cbe (patch)
tree012d498442457facc3407009782c0520949e647e /src/main
parent21d5c808a818835f2f4745c1c8ba3ae6aa194b16 (diff)
Add SimplifyMems transform to lower memories without splitting (#1111)
* Add SimplifyMems transform to lower memories without splitting * Remove spurious anonymous function
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/transforms/SimplifyMems.scala83
1 files changed, 83 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala
new file mode 100644
index 00000000..140efc9f
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala
@@ -0,0 +1,83 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.annotations._
+import firrtl.passes._
+import firrtl.passes.memlib._
+import scala.collection.mutable
+
+import Utils._
+import AnalysisUtils._
+import MemPortUtils._
+import ResolveMaskGranularity._
+
+/**
+ * Lowers memories without splitting them, but without the complexity of ReplaceMemMacros
+ */
+class SimplifyMems extends Transform {
+ def inputForm = MidForm
+ def outputForm = MidForm
+
+ def onModule(c: Circuit, renames: RenameMap)(m: DefModule): DefModule = {
+ val moduleNS = Namespace(m)
+ val connects = getConnects(m)
+ val memAdapters = new mutable.LinkedHashMap[String, DefWire]
+ val mTarget = ModuleTarget(c.main, m.name)
+
+ def onExpr(e: Expression): Expression = e.map(onExpr) match {
+ case WRef(name, tpe, MemKind, gender) if memAdapters.contains(name) =>
+ WRef(name, tpe, WireKind, gender)
+ case e => e
+ }
+
+ def simplifyMem(mem: DefMemory): Statement = {
+ val adapterDecl = DefWire(mem.info, mem.name, memType(mem))
+ val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType))
+ val oldRT = mTarget.ref(mem.name)
+ val adapterConnects = memType(simpleMemDecl).fields.flatMap {
+ case Field(pName, Flip, pType: BundleType) =>
+ val memPort = WSubField(WRef(simpleMemDecl), pName)
+ val adapterPort = WSubField(WRef(adapterDecl), pName)
+ renames.delete(oldRT.field(pName))
+ pType.fields.map {
+ case Field(name, Flip, _) if name.contains("data") => // read data
+ fromBits(WSubField(adapterPort, name), WSubField(memPort, name))
+ case Field(name, Default, _) if name.contains("data") => // write data
+ Connect(mem.info, WSubField(memPort, name), toBits(WSubField(adapterPort, name)))
+ case Field(name, Default, _) if name.contains("mask") => // mask
+ Connect(mem.info, WSubField(memPort, name), Utils.one)
+ case Field(name, _, _) => // etc
+ Connect(mem.info, WSubField(memPort, name), WSubField(adapterPort, name))
+ }
+ }
+ memAdapters(mem.name) = adapterDecl
+ renames.record(oldRT, oldRT.copy(ref = simpleMemDecl.name))
+ Block(Seq(adapterDecl, simpleMemDecl) ++ adapterConnects)
+ }
+
+ def canSimplify(mem: DefMemory) = mem.dataType match {
+ case at: AggregateType =>
+ val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask")))
+ val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask")))
+ (wMasks ++ rwMasks).flatten.isEmpty
+ case _ => false
+ }
+
+ def onStmt(s: Statement): Statement = s match {
+ case mem: DefMemory if canSimplify(mem) => simplifyMem(mem)
+ case s => s.map(onStmt).map(onExpr)
+ }
+
+ m.map(onStmt)
+ }
+
+ override def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ CircuitState(c.map(onModule(c, renames)), outputForm, state.annotations, Some(renames))
+ }
+}