diff options
| author | Albert Magyar | 2019-07-19 12:29:08 -0700 |
|---|---|---|
| committer | mergify[bot] | 2019-07-19 19:29:08 +0000 |
| commit | 71d52cecde697d0734d55694e2344c2fb7e55cbe (patch) | |
| tree | 012d498442457facc3407009782c0520949e647e /src/main | |
| parent | 21d5c808a818835f2f4745c1c8ba3ae6aa194b16 (diff) | |
Add SimplifyMems transform to lower memories without splitting (#1111)
* Add SimplifyMems transform to lower memories without splitting
* Remove spurious anonymous function
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/scala/firrtl/transforms/SimplifyMems.scala | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala new file mode 100644 index 00000000..140efc9f --- /dev/null +++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala @@ -0,0 +1,83 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.annotations._ +import firrtl.passes._ +import firrtl.passes.memlib._ +import scala.collection.mutable + +import Utils._ +import AnalysisUtils._ +import MemPortUtils._ +import ResolveMaskGranularity._ + +/** + * Lowers memories without splitting them, but without the complexity of ReplaceMemMacros + */ +class SimplifyMems extends Transform { + def inputForm = MidForm + def outputForm = MidForm + + def onModule(c: Circuit, renames: RenameMap)(m: DefModule): DefModule = { + val moduleNS = Namespace(m) + val connects = getConnects(m) + val memAdapters = new mutable.LinkedHashMap[String, DefWire] + val mTarget = ModuleTarget(c.main, m.name) + + def onExpr(e: Expression): Expression = e.map(onExpr) match { + case WRef(name, tpe, MemKind, gender) if memAdapters.contains(name) => + WRef(name, tpe, WireKind, gender) + case e => e + } + + def simplifyMem(mem: DefMemory): Statement = { + val adapterDecl = DefWire(mem.info, mem.name, memType(mem)) + val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType)) + val oldRT = mTarget.ref(mem.name) + val adapterConnects = memType(simpleMemDecl).fields.flatMap { + case Field(pName, Flip, pType: BundleType) => + val memPort = WSubField(WRef(simpleMemDecl), pName) + val adapterPort = WSubField(WRef(adapterDecl), pName) + renames.delete(oldRT.field(pName)) + pType.fields.map { + case Field(name, Flip, _) if name.contains("data") => // read data + fromBits(WSubField(adapterPort, name), WSubField(memPort, name)) + case Field(name, Default, _) if name.contains("data") => // write data + Connect(mem.info, WSubField(memPort, name), toBits(WSubField(adapterPort, name))) + case Field(name, Default, _) if name.contains("mask") => // mask + Connect(mem.info, WSubField(memPort, name), Utils.one) + case Field(name, _, _) => // etc + Connect(mem.info, WSubField(memPort, name), WSubField(adapterPort, name)) + } + } + memAdapters(mem.name) = adapterDecl + renames.record(oldRT, oldRT.copy(ref = simpleMemDecl.name)) + Block(Seq(adapterDecl, simpleMemDecl) ++ adapterConnects) + } + + def canSimplify(mem: DefMemory) = mem.dataType match { + case at: AggregateType => + val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask"))) + val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask"))) + (wMasks ++ rwMasks).flatten.isEmpty + case _ => false + } + + def onStmt(s: Statement): Statement = s match { + case mem: DefMemory if canSimplify(mem) => simplifyMem(mem) + case s => s.map(onStmt).map(onExpr) + } + + m.map(onStmt) + } + + override def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + CircuitState(c.map(onModule(c, renames)), outputForm, state.annotations, Some(renames)) + } +} |
