diff options
| author | Albert Magyar | 2019-07-19 12:29:08 -0700 |
|---|---|---|
| committer | mergify[bot] | 2019-07-19 19:29:08 +0000 |
| commit | 71d52cecde697d0734d55694e2344c2fb7e55cbe (patch) | |
| tree | 012d498442457facc3407009782c0520949e647e /src | |
| parent | 21d5c808a818835f2f4745c1c8ba3ae6aa194b16 (diff) | |
Add SimplifyMems transform to lower memories without splitting (#1111)
* Add SimplifyMems transform to lower memories without splitting
* Remove spurious anonymous function
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/transforms/SimplifyMems.scala | 83 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/SimplifyMemsSpec.scala | 90 |
2 files changed, 173 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala new file mode 100644 index 00000000..140efc9f --- /dev/null +++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala @@ -0,0 +1,83 @@ +// See LICENSE for license details. + +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.annotations._ +import firrtl.passes._ +import firrtl.passes.memlib._ +import scala.collection.mutable + +import Utils._ +import AnalysisUtils._ +import MemPortUtils._ +import ResolveMaskGranularity._ + +/** + * Lowers memories without splitting them, but without the complexity of ReplaceMemMacros + */ +class SimplifyMems extends Transform { + def inputForm = MidForm + def outputForm = MidForm + + def onModule(c: Circuit, renames: RenameMap)(m: DefModule): DefModule = { + val moduleNS = Namespace(m) + val connects = getConnects(m) + val memAdapters = new mutable.LinkedHashMap[String, DefWire] + val mTarget = ModuleTarget(c.main, m.name) + + def onExpr(e: Expression): Expression = e.map(onExpr) match { + case WRef(name, tpe, MemKind, gender) if memAdapters.contains(name) => + WRef(name, tpe, WireKind, gender) + case e => e + } + + def simplifyMem(mem: DefMemory): Statement = { + val adapterDecl = DefWire(mem.info, mem.name, memType(mem)) + val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType)) + val oldRT = mTarget.ref(mem.name) + val adapterConnects = memType(simpleMemDecl).fields.flatMap { + case Field(pName, Flip, pType: BundleType) => + val memPort = WSubField(WRef(simpleMemDecl), pName) + val adapterPort = WSubField(WRef(adapterDecl), pName) + renames.delete(oldRT.field(pName)) + pType.fields.map { + case Field(name, Flip, _) if name.contains("data") => // read data + fromBits(WSubField(adapterPort, name), WSubField(memPort, name)) + case Field(name, Default, _) if name.contains("data") => // write data + Connect(mem.info, WSubField(memPort, name), toBits(WSubField(adapterPort, name))) + case Field(name, Default, _) if name.contains("mask") => // mask + Connect(mem.info, WSubField(memPort, name), Utils.one) + case Field(name, _, _) => // etc + Connect(mem.info, WSubField(memPort, name), WSubField(adapterPort, name)) + } + } + memAdapters(mem.name) = adapterDecl + renames.record(oldRT, oldRT.copy(ref = simpleMemDecl.name)) + Block(Seq(adapterDecl, simpleMemDecl) ++ adapterConnects) + } + + def canSimplify(mem: DefMemory) = mem.dataType match { + case at: AggregateType => + val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask"))) + val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask"))) + (wMasks ++ rwMasks).flatten.isEmpty + case _ => false + } + + def onStmt(s: Statement): Statement = s match { + case mem: DefMemory if canSimplify(mem) => simplifyMem(mem) + case s => s.map(onStmt).map(onExpr) + } + + m.map(onStmt) + } + + override def execute(state: CircuitState): CircuitState = { + val c = state.circuit + val renames = RenameMap() + CircuitState(c.map(onModule(c, renames)), outputForm, state.annotations, Some(renames)) + } +} diff --git a/src/test/scala/firrtlTests/SimplifyMemsSpec.scala b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala new file mode 100644 index 00000000..69f1a70e --- /dev/null +++ b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala @@ -0,0 +1,90 @@ +// See LICENSE for license details. + +package firrtlTests + +import java.io._ +import org.scalatest._ +import org.scalatest.prop._ +import firrtl.Parser +import firrtl.ir.Circuit +import firrtl.passes._ +import firrtl.transforms._ +import firrtl._ + +import CompilerUtils.getLoweringTransforms + +class SimplifyMemsSpec extends ConstantPropagationSpec { + override val transforms = getLoweringTransforms(ChirrtlForm, MidForm) ++ Seq(new SimplifyMems) + + "SimplifyMems" should "lower aggregate memories" in { + val input = + """circuit Test : + | module Test : + | input clock : Clock + | input wen : UInt<1> + | input wdata : { a : UInt<8>, b : UInt<8> } + | output rdata : { a : UInt<8>, b : UInt<8> } + | mem m : + | data-type => { a : UInt<8>, b : UInt<8>} + | depth => 32 + | read-latency => 1 + | write-latency => 1 + | reader => read + | writer => write + | m.read.clk <= clock + | m.read.en <= UInt<1>(1) + | m.read.addr is invalid + | rdata <= m.read.data + | m.write.clk <= clock + | m.write.en <= wen + | m.write.mask.a <= UInt<1>(1) + | m.write.mask.b <= UInt<1>(1) + | m.write.addr is invalid + | m.write.data <= wdata + + """.stripMargin + + val check = + """circuit Test : + | module Test : + | input clock : Clock + | input wen : UInt<1> + | input wdata : { a : UInt<8>, b : UInt<8>} + | output rdata : { a : UInt<8>, b : UInt<8>} + | + | wire m : { flip read : { addr : UInt<5>, en : UInt<1>, clk : Clock, flip data : { a : UInt<8>, b : UInt<8>}}, flip write : { addr : UInt<5>, en : UInt<1>, clk : Clock, data : { a : UInt<8>, b : UInt<8>}, mask : { a : UInt<1>, b : UInt<1>}}} + | mem m_flattened : + | data-type => UInt<16> + | depth => 32 + | read-latency => 1 + | write-latency => 1 + | reader => read + | writer => write + | read-under-write => undefined + | m_flattened.read.addr <= m.read.addr + | m_flattened.read.en <= m.read.en + | m_flattened.read.clk <= m.read.clk + | m.read.data.b <= asUInt(bits(m_flattened.read.data, 7, 0)) + | m.read.data.a <= asUInt(bits(m_flattened.read.data, 15, 8)) + | m_flattened.write.addr <= m.write.addr + | m_flattened.write.en <= m.write.en + | m_flattened.write.clk <= m.write.clk + | m_flattened.write.data <= cat(asUInt(m.write.data.a), asUInt(m.write.data.b)) + | m_flattened.write.mask <= UInt<1>("h1") + | rdata.a <= m.read.data.a + | rdata.b <= m.read.data.b + | m.read.addr is invalid + | m.read.en <= UInt<1>("h1") + | m.read.clk <= clock + | m.write.addr is invalid + | m.write.en <= wen + | m.write.clk <= clock + | m.write.data.a <= wdata.a + | m.write.data.b <= wdata.b + | m.write.mask.a <= UInt<1>("h1") + | m.write.mask.b <= UInt<1>("h1") + + """.stripMargin + (parse(exec(input))) should be (parse(check)) + } +} |
