aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/transforms/SimplifyMems.scala83
-rw-r--r--src/test/scala/firrtlTests/SimplifyMemsSpec.scala90
2 files changed, 173 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/transforms/SimplifyMems.scala b/src/main/scala/firrtl/transforms/SimplifyMems.scala
new file mode 100644
index 00000000..140efc9f
--- /dev/null
+++ b/src/main/scala/firrtl/transforms/SimplifyMems.scala
@@ -0,0 +1,83 @@
+// See LICENSE for license details.
+
+package firrtl
+package transforms
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.annotations._
+import firrtl.passes._
+import firrtl.passes.memlib._
+import scala.collection.mutable
+
+import Utils._
+import AnalysisUtils._
+import MemPortUtils._
+import ResolveMaskGranularity._
+
+/**
+ * Lowers memories without splitting them, but without the complexity of ReplaceMemMacros
+ */
+class SimplifyMems extends Transform {
+ def inputForm = MidForm
+ def outputForm = MidForm
+
+ def onModule(c: Circuit, renames: RenameMap)(m: DefModule): DefModule = {
+ val moduleNS = Namespace(m)
+ val connects = getConnects(m)
+ val memAdapters = new mutable.LinkedHashMap[String, DefWire]
+ val mTarget = ModuleTarget(c.main, m.name)
+
+ def onExpr(e: Expression): Expression = e.map(onExpr) match {
+ case WRef(name, tpe, MemKind, gender) if memAdapters.contains(name) =>
+ WRef(name, tpe, WireKind, gender)
+ case e => e
+ }
+
+ def simplifyMem(mem: DefMemory): Statement = {
+ val adapterDecl = DefWire(mem.info, mem.name, memType(mem))
+ val simpleMemDecl = mem.copy(name = moduleNS.newName(s"${mem.name}_flattened"), dataType = flattenType(mem.dataType))
+ val oldRT = mTarget.ref(mem.name)
+ val adapterConnects = memType(simpleMemDecl).fields.flatMap {
+ case Field(pName, Flip, pType: BundleType) =>
+ val memPort = WSubField(WRef(simpleMemDecl), pName)
+ val adapterPort = WSubField(WRef(adapterDecl), pName)
+ renames.delete(oldRT.field(pName))
+ pType.fields.map {
+ case Field(name, Flip, _) if name.contains("data") => // read data
+ fromBits(WSubField(adapterPort, name), WSubField(memPort, name))
+ case Field(name, Default, _) if name.contains("data") => // write data
+ Connect(mem.info, WSubField(memPort, name), toBits(WSubField(adapterPort, name)))
+ case Field(name, Default, _) if name.contains("mask") => // mask
+ Connect(mem.info, WSubField(memPort, name), Utils.one)
+ case Field(name, _, _) => // etc
+ Connect(mem.info, WSubField(memPort, name), WSubField(adapterPort, name))
+ }
+ }
+ memAdapters(mem.name) = adapterDecl
+ renames.record(oldRT, oldRT.copy(ref = simpleMemDecl.name))
+ Block(Seq(adapterDecl, simpleMemDecl) ++ adapterConnects)
+ }
+
+ def canSimplify(mem: DefMemory) = mem.dataType match {
+ case at: AggregateType =>
+ val wMasks = mem.writers.map(w => getMaskBits(connects, memPortField(mem, w, "en"), memPortField(mem, w, "mask")))
+ val rwMasks = mem.readwriters.map(w => getMaskBits(connects, memPortField(mem, w, "wmode"), memPortField(mem, w, "wmask")))
+ (wMasks ++ rwMasks).flatten.isEmpty
+ case _ => false
+ }
+
+ def onStmt(s: Statement): Statement = s match {
+ case mem: DefMemory if canSimplify(mem) => simplifyMem(mem)
+ case s => s.map(onStmt).map(onExpr)
+ }
+
+ m.map(onStmt)
+ }
+
+ override def execute(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val renames = RenameMap()
+ CircuitState(c.map(onModule(c, renames)), outputForm, state.annotations, Some(renames))
+ }
+}
diff --git a/src/test/scala/firrtlTests/SimplifyMemsSpec.scala b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala
new file mode 100644
index 00000000..69f1a70e
--- /dev/null
+++ b/src/test/scala/firrtlTests/SimplifyMemsSpec.scala
@@ -0,0 +1,90 @@
+// See LICENSE for license details.
+
+package firrtlTests
+
+import java.io._
+import org.scalatest._
+import org.scalatest.prop._
+import firrtl.Parser
+import firrtl.ir.Circuit
+import firrtl.passes._
+import firrtl.transforms._
+import firrtl._
+
+import CompilerUtils.getLoweringTransforms
+
+class SimplifyMemsSpec extends ConstantPropagationSpec {
+ override val transforms = getLoweringTransforms(ChirrtlForm, MidForm) ++ Seq(new SimplifyMems)
+
+ "SimplifyMems" should "lower aggregate memories" in {
+ val input =
+ """circuit Test :
+ | module Test :
+ | input clock : Clock
+ | input wen : UInt<1>
+ | input wdata : { a : UInt<8>, b : UInt<8> }
+ | output rdata : { a : UInt<8>, b : UInt<8> }
+ | mem m :
+ | data-type => { a : UInt<8>, b : UInt<8>}
+ | depth => 32
+ | read-latency => 1
+ | write-latency => 1
+ | reader => read
+ | writer => write
+ | m.read.clk <= clock
+ | m.read.en <= UInt<1>(1)
+ | m.read.addr is invalid
+ | rdata <= m.read.data
+ | m.write.clk <= clock
+ | m.write.en <= wen
+ | m.write.mask.a <= UInt<1>(1)
+ | m.write.mask.b <= UInt<1>(1)
+ | m.write.addr is invalid
+ | m.write.data <= wdata
+
+ """.stripMargin
+
+ val check =
+ """circuit Test :
+ | module Test :
+ | input clock : Clock
+ | input wen : UInt<1>
+ | input wdata : { a : UInt<8>, b : UInt<8>}
+ | output rdata : { a : UInt<8>, b : UInt<8>}
+ |
+ | wire m : { flip read : { addr : UInt<5>, en : UInt<1>, clk : Clock, flip data : { a : UInt<8>, b : UInt<8>}}, flip write : { addr : UInt<5>, en : UInt<1>, clk : Clock, data : { a : UInt<8>, b : UInt<8>}, mask : { a : UInt<1>, b : UInt<1>}}}
+ | mem m_flattened :
+ | data-type => UInt<16>
+ | depth => 32
+ | read-latency => 1
+ | write-latency => 1
+ | reader => read
+ | writer => write
+ | read-under-write => undefined
+ | m_flattened.read.addr <= m.read.addr
+ | m_flattened.read.en <= m.read.en
+ | m_flattened.read.clk <= m.read.clk
+ | m.read.data.b <= asUInt(bits(m_flattened.read.data, 7, 0))
+ | m.read.data.a <= asUInt(bits(m_flattened.read.data, 15, 8))
+ | m_flattened.write.addr <= m.write.addr
+ | m_flattened.write.en <= m.write.en
+ | m_flattened.write.clk <= m.write.clk
+ | m_flattened.write.data <= cat(asUInt(m.write.data.a), asUInt(m.write.data.b))
+ | m_flattened.write.mask <= UInt<1>("h1")
+ | rdata.a <= m.read.data.a
+ | rdata.b <= m.read.data.b
+ | m.read.addr is invalid
+ | m.read.en <= UInt<1>("h1")
+ | m.read.clk <= clock
+ | m.write.addr is invalid
+ | m.write.en <= wen
+ | m.write.clk <= clock
+ | m.write.data.a <= wdata.a
+ | m.write.data.b <= wdata.b
+ | m.write.mask.a <= UInt<1>("h1")
+ | m.write.mask.b <= UInt<1>("h1")
+
+ """.stripMargin
+ (parse(exec(input))) should be (parse(check))
+ }
+}