aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJohn Wright2019-03-07 15:34:33 -0800
committermergify[bot]2019-03-07 23:34:33 +0000
commita97a81bc0f717f80bb70733795ac5337653b58c5 (patch)
tree23ea1801f2767aee9d26a3c5588891cee4d2aab8 /src
parent0ac1fa56da06e6b0590ed05ab1ea047188d54602 (diff)
Add a data structure for memory conf reading and writing (#1041)
* Copy MemConf.scala from ucb-bar/barstools#35 into memlib. This provides a data structure wrapper around the existing memory conf format which contains both reading and writing methods, making it easier to write code that needs to read the format. * Add MemConf tests and use a Map[MemPort, Int] for port lists instead of a Seq[MemPort] which is a bit less fragile.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemConf.scala69
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala14
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala61
3 files changed, 131 insertions, 13 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala
new file mode 100644
index 00000000..55600bf6
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala
@@ -0,0 +1,69 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import scala.util.matching._
+
+sealed abstract class MemPort(val name: String) { override def toString = name }
+
+case object ReadPort extends MemPort("read")
+case object WritePort extends MemPort("write")
+case object MaskedWritePort extends MemPort("mwrite")
+case object ReadWritePort extends MemPort("rw")
+case object MaskedReadWritePort extends MemPort("mrw")
+
+object MemPort {
+
+ val all = Set(ReadPort, WritePort, MaskedWritePort, ReadWritePort, MaskedReadWritePort)
+
+ def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s)
+
+ def fromString(s: String): Map[MemPort, Int] = {
+ s.split(",").toSeq.map(MemPort.apply).map(_ match {
+ case Some(x) => x
+ case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
+ }).groupBy(identity).mapValues(_.size)
+ }
+}
+
+case class MemConf(
+ name: String,
+ depth: Int,
+ width: Int,
+ ports: Map[MemPort, Int],
+ maskGranularity: Option[Int]
+) {
+
+ private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",")
+ private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("")
+
+ // Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes
+ // (otherwise an entry of XYZPort -> 0 would not be equivalent to another with no XYZPort despite being semantically the same)
+ ports.foreach { case (k, v) => require(v > 0, "Cannot have negative or zero entry in the port map") }
+
+ override def toString = s"name ${name} depth ${depth} width ${width} ports ${portsStr} ${maskGranStr} \n"
+}
+
+object MemConf {
+
+ val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r
+
+ def fromString(s: String): Seq[MemConf] = {
+ s.split("\n").toSeq.map(_ match {
+ case MemConf.regex(name, depth, width, ports, maskGran) => MemConf(name, depth.toInt, width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))
+ case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
+ })
+ }
+
+ def apply(name: String, depth: Int, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = {
+ val ports: Map[MemPort, Int] = (if (maskGranularity.isEmpty) {
+ (if (writePorts == 0) Map.empty[MemPort, Int] else Map(WritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Map.empty[MemPort, Int] else Map(ReadWritePort -> readWritePorts))
+ } else {
+ (if (writePorts == 0) Map.empty[MemPort, Int] else Map(MaskedWritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Map.empty[MemPort, Int] else Map(MaskedReadWritePort -> readWritePorts))
+ }) ++ (if (readPorts == 0) Map.empty[MemPort, Int] else Map(ReadPort -> readPorts))
+ return new MemConf(name, depth, width, ports, maskGranularity)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index 643f63c6..1f8e89be 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -50,15 +50,11 @@ class ConfWriter(filename: String) {
val outputBuffer = new CharArrayWriter
def append(m: DefAnnotatedMemory) = {
// legacy
- val maskGran = m.maskGran
- val readers = List.fill(m.readers.length)("read")
- val writers = List.fill(m.writers.length)(if (maskGran.isEmpty) "write" else "mwrite")
- val readwriters = List.fill(m.readwriters.length)(if (maskGran.isEmpty) "rw" else "mrw")
- val ports = (writers ++ readers ++ readwriters) mkString ","
- val maskGranConf = maskGran match { case None => "" case Some(p) => s"mask_gran $p" }
- val width = bitWidth(m.dataType)
- val conf = s"name ${m.name} depth ${m.depth} width $width ports $ports $maskGranConf \n"
- outputBuffer.append(conf)
+ // assert that we don't overflow going from BigInt to Int conversion
+ require(bitWidth(m.dataType) <= Int.MaxValue)
+ m.maskGran.foreach { case x => require(x <= Int.MaxValue) }
+ val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt))
+ outputBuffer.append(conf.toString)
}
def serialize() = {
val outputFile = new PrintWriter(filename)
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 6cedd3f0..b51e2271 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -27,6 +27,16 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
}
)
+ def checkMemConf(filename: String, mems: Set[MemConf]) {
+ // Read the mem conf
+ val file = scala.io.Source.fromFile(filename)
+ val text = try file.mkString finally file.close()
+ // Verify that this does not throw an exception
+ val fromConf = MemConf.fromString(text)
+ // Verify the mems in the conf are the same as the expected ones
+ require(Set(fromConf: _*) == mems, "Parsed conf set:\n {\n " + fromConf.mkString(" ") + " }\n must be the same as reference conf set: \n {\n " + mems.toSeq.mkString(" ") + " }\n")
+ }
+
"ReplSeqMem" should "generate blackbox wrappers for mems of bundle type" in {
val input = """
circuit Top :
@@ -63,11 +73,17 @@ circuit Top :
read mport R1 = entries_info2[head_ptr], clock
io2.commit_entry.bits.info <- R1
""".stripMargin
+ val mems = Set(
+ MemConf("entries_info_ext", 24, 30, Map(WritePort -> 1, ReadPort -> 1), None),
+ MemConf("entries_info2_ext", 24, 30, Map(MaskedWritePort -> 1, ReadPort -> 1), Some(10))
+ )
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
@@ -85,11 +101,14 @@ circuit Top :
when p_valid :
write mport T_155 = mem[p_address], clock
""".stripMargin
+ val mems = Set(MemConf("mem_ext", 32, 64, Map(MaskedWritePort -> 1), Some(64)))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
@@ -110,11 +129,14 @@ circuit CustomMemory :
_T_18 <= io.dI
skip
""".stripMargin
+ val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
@@ -135,11 +157,14 @@ circuit CustomMemory :
_T_18 <= io.dI
skip
""".stripMargin
+ val mems = Set(MemConf("mem_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
parse(res.getEmittedCircuit.value)
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
@@ -188,6 +213,7 @@ circuit Top :
tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) }
}
+
"ReplSeqMem" should "not de-duplicate memories with the nodedupe annotation " in {
val input = """
circuit CustomMemory :
@@ -209,6 +235,10 @@ circuit CustomMemory :
_T_20 <= io.dI
skip
"""
+ val mems = Set(
+ MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None),
+ MemConf("mem_1_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)
+ )
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(
ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
@@ -221,6 +251,8 @@ circuit CustomMemory :
case _ => false
}
numExtMods should be (2)
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
@@ -249,6 +281,10 @@ circuit CustomMemory :
_T_22 <= io.dI
skip
"""
+ val mems = Set(
+ MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None),
+ MemConf("mem_1_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)
+ )
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(
ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
@@ -261,6 +297,8 @@ circuit CustomMemory :
case _ => false
}
numExtMods should be (2)
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
@@ -300,6 +338,10 @@ circuit CustomMemory :
w1 <= io.dI
w2 <= io.dI
"""
+ val mems = Set(
+ MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None),
+ MemConf("mem_0_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None)
+ )
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(
ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
@@ -316,6 +358,8 @@ circuit CustomMemory :
// If the NoDedupMemAnnotation were handled incorrectly as it was prior to this test, there
// would be 3 ExtModules
numExtMods should be (2)
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
@@ -340,6 +384,7 @@ circuit CustomMemory :
_T_20 <= io.dI
skip
"""
+ val mems = Set(MemConf("mem_0_ext", 7, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
@@ -353,7 +398,7 @@ circuit CustomMemory :
(new java.io.File(confLoc)).delete()
}
- "ReplSeqMem" should "should not have a mask if there is none" in {
+ "ReplSeqMem" should "not have a mask if there is none" in {
val input = """
circuit CustomMemory :
module CustomMemory :
@@ -368,14 +413,17 @@ circuit CustomMemory :
write mport w = mem[io.waddr], clock
w <= io.wdata
"""
+ val mems = Set(MemConf("mem_ext", 1024, 16, Map(WritePort -> 1, ReadPort -> 1), None))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
res.getEmittedCircuit.value shouldNot include ("mask")
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
- "ReplSeqMem" should "should not conjoin enable signal with mask condition" in {
+ "ReplSeqMem" should "not conjoin enable signal with mask condition" in {
val input = """
circuit CustomMemory :
module CustomMemory :
@@ -393,16 +441,19 @@ circuit CustomMemory :
when io.mask[1] :
w[1] <= io.wdata[1]
"""
+ val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedWritePort -> 1, ReadPort -> 1), Some(8)))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
res should containLine ("mem.W0_mask_0 <= validif(io_en, io_mask_0)")
res should containLine ("mem.W0_mask_1 <= validif(io_en, io_mask_1)")
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
- "ReplSeqMem" should "should not conjoin enable signal with wmask condition (RW Port)" in {
+ "ReplSeqMem" should "not conjoin enable signal with wmask condition (RW Port)" in {
val input = """
circuit CustomMemory :
module CustomMemory :
@@ -424,6 +475,7 @@ circuit CustomMemory :
io.out <= r
"""
+ val mems = Set(MemConf("mem_ext", 1024, 16, Map(MaskedReadWritePort -> 1), Some(8)))
val confLoc = "ReplSeqMemTests.confTEMP"
val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
InferReadWriteAnnotation)
@@ -431,9 +483,10 @@ circuit CustomMemory :
// TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
res should containLine ("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)")
res should containLine ("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)")
+ // Check the emitted conf
+ checkMemConf(confLoc, mems)
(new java.io.File(confLoc)).delete()
}
}
// TODO: make more checks
-// conf