aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDonggyu Kim2016-09-16 17:34:56 -0700
committerDonggyu Kim2016-09-21 13:19:42 -0700
commitb83203f00d11ca61017fbbc847c290e2d56e29e9 (patch)
treeb5f109cac4a5620fbd91a58c1629a225c6ff9ac5
parenta142551bfcce6b05e445bc75dd284d994c8e91f2 (diff)
refactor AnnotateValidMemConfigs
-rw-r--r--src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala174
1 files changed, 86 insertions, 88 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
index 816a179b..f80d4a0c 100644
--- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
+++ b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
@@ -2,13 +2,16 @@
package firrtl.passes
-import firrtl.ir._
import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import Utils.error
+import AnalysisUtils._
+
import net.jcazevedo.moultingyaml._
import net.jcazevedo.moultingyaml.DefaultYamlProtocol._
-import AnalysisUtils._
import scala.collection.mutable
-import firrtl.Mappers._
+import java.io.{File, CharArrayWriter, PrintWriter}
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
@@ -36,7 +39,7 @@ case class MemDimension(
rules: Option[DimensionRules],
set: Option[List[Int]]) {
require (
- if(rules == None) set != None else set == None,
+ if (rules == None) set != None else set == None,
"Should specify either rules or a list of valid options, but not both"
)
def getValid = set.getOrElse(rules.get.getValid).sorted
@@ -128,15 +131,12 @@ case class SRAMCompiler(
rules.find(r => r.ymux._1 == x && r.ybank._1 == y).get
}
- private val maskConfigOutputBuffer = new java.io.CharArrayWriter
- private val noMaskConfigOutputBuffer = new java.io.CharArrayWriter
+ private val maskConfigOutputBuffer = new CharArrayWriter
+ private val noMaskConfigOutputBuffer = new CharArrayWriter
def append(m: DefMemory) : DefMemory = {
- val validCombos = mutable.ArrayBuffer[SRAMConfig]()
- defaultSearchOrdering foreach { r =>
- val config = r.getValidConfig(m)
- if (config != None) validCombos += config.get
- }
+ val validCombos = (defaultSearchOrdering map (_ getValidConfig m)
+ collect { case Some(config) => config })
// non empty if successfully found compiler option that supports depth/width
// TODO: don't just take first option
val usedConfig = {
@@ -144,18 +144,19 @@ case class SRAMCompiler(
else getBestAlternative(m)
}
val usesMaskGran = containsInfo(m.info, "maskGran")
- if (configPattern != None) {
- val newConfig = usedConfig.serialize(configPattern.get) + "\n"
- val currentBuff = {
- if (usesMaskGran) maskConfigOutputBuffer
- else noMaskConfigOutputBuffer
- }
- if (!currentBuff.toString.contains(newConfig))
- currentBuff.append(newConfig)
+ configPattern match {
+ case None =>
+ case Some(p) =>
+ val newConfig = usedConfig.serialize(p) + "\n"
+ val currentBuff = {
+ if (usesMaskGran) maskConfigOutputBuffer
+ else noMaskConfigOutputBuffer
+ }
+ if (!currentBuff.toString.contains(newConfig)) currentBuff append newConfig
}
val temp = appendInfo(m.info, "sramConfig" -> usedConfig)
- val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp
- m.copy(info = newInfo)
+ val newInfo = if (usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp
+ m copy (info = newInfo)
}
// TODO: Should you really be splitting in 2 if, say, depth is 1 more than allowed? should be thresholded and
@@ -164,47 +165,47 @@ case class SRAMCompiler(
private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = {
val validXRange = mutable.ArrayBuffer[SRAMRules]()
val validYRange = mutable.ArrayBuffer[SRAMRules]()
- defaultSearchOrdering foreach { r =>
+ defaultSearchOrdering foreach { r =>
if (m.width <= r.getValidWidths.max) validXRange += r
if (m.depth <= r.getValidDepths.max) validYRange += r
}
-
- if (validXRange.isEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2))
- else if (validXRange.isEmpty && validYRange.nonEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
- else if (validXRange.nonEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2))
- else if (validXRange.intersect(validYRange).nonEmpty)
- Seq(m)
- else
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++
+ (validXRange.isEmpty, validYRange.isEmpty) match {
+ case (true, true) =>
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2))
+ case (true, false) =>
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
+ case (false, true) =>
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2))
+ case (false, false) if validXRange.intersect(validYRange).nonEmpty =>
+ Seq(m)
+ case (false, false) =>
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++
getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
+ }
}
private def getBestAlternative(m: DefMemory): SRAMConfig = {
val validConfigs = getInRange(SRAMConfig(width = bitWidth(m.dataType).intValue, depth = m.depth))
- val minNum = validConfigs.map(x => x.num).min
+ val minNum = validConfigs.map(_.num).min
val validMinConfigs = validConfigs.filter(_.num == minNum)
- val validMinConfigsSquareness = validMinConfigs.map(x => math.abs(x.width.toDouble/x.depth - 1) -> x).toMap
- val squarestAspectRatio = validMinConfigsSquareness.map { case (aspectRatioDiff, _) => aspectRatioDiff } min
+ val validMinConfigsSquareness = validMinConfigs.map(
+ x => math.abs(x.width.toDouble / x.depth - 1) -> x).toMap
+ val squarestAspectRatio = validMinConfigsSquareness.unzip._1.min
val validConfig = validMinConfigsSquareness(squarestAspectRatio)
- val validRules = mutable.ArrayBuffer[SRAMRules]()
- defaultSearchOrdering foreach { r =>
- if (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) validRules += r
- }
+ val validRules = defaultSearchOrdering filter (r =>
+ (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max))
// TODO: don't just take first option
// TODO: More optimal split if particular value is in range but not supported
// TODO: Support up to 2 read ports, 2 write ports; should be power of 2?
val bestRule = validRules.head
val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get
val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get
- bestRule.getValidConfig(width = memWidth, depth = memDepth).get.copy(xsplit = validConfig.xsplit, ysplit = validConfig.ysplit)
+ (bestRule.getValidConfig(width = memWidth, depth = memDepth).get
+ copy (xsplit = validConfig.xsplit, ysplit = validConfig.ysplit))
}
// TODO
def serialize = ???
-
}
// TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true?
@@ -212,11 +213,11 @@ case class SRAMCompiler(
class YamlFileReader(file: String) {
import CustomYAMLProtocol._
def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
- if (new java.io.File(file).exists) {
+ if (new File(file).exists) {
val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n")
- yamlString.parseYamls.flatMap(x =>
- try Some(reader.read(x))
- catch { case e: Exception => None }
+ yamlString.parseYamls flatMap (x =>
+ try Some(reader read x)
+ catch { case e: Exception => None }
)
}
else error("Yaml file doesn't exist!")
@@ -225,22 +226,20 @@ class YamlFileReader(file: String) {
class YamlFileWriter(file: String) {
import CustomYAMLProtocol._
- val outputBuffer = new java.io.CharArrayWriter
+ val outputBuffer = new CharArrayWriter
val separator = "--- \n"
- def append(in: YamlValue) = {
- outputBuffer.append(separator + in.prettyPrint)
+ def append(in: YamlValue) {
+ outputBuffer append s"$separator${in.prettyPrint}"
}
- def dump = {
- val outputFile = new java.io.PrintWriter(file)
- outputFile.write(outputBuffer.toString)
- outputFile.close()
+ def dump {
+ val outputFile = new PrintWriter(file)
+ outputFile write outputBuffer.toString
+ outputFile.close
}
}
class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
-
import CustomYAMLProtocol._
-
def name = "Annotate memories with valid split depths, widths, #\'s"
// TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist
@@ -249,43 +248,42 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
sp: Option[SRAMCompiler] = None,
dp: Option[SRAMCompiler] = None) {
def serialize = {
- if (sp != None) sp.get.serialize
- if (dp != None) dp.get.serialize
+ sp match {
+ case None =>
+ case Some(p) => p.serialize
+ }
+ dp match {
+ case None =>
+ case Some(p) => p.serialize
+ }
}
}
- val sramCompilers = {
- if (reader == None) None
- else {
- val compilers = reader.get.parse[SRAMCompiler]
- val sp = compilers.find(_.portType == "RW")
- val dp = compilers.find(_.portType == "R,W")
+
+ val sramCompilers = reader match {
+ case None => None
+ case Some(r) =>
+ val compilers = r.parse[SRAMCompiler]
+ val sp = compilers find (_.portType == "RW")
+ val dp = compilers find (_.portType == "R,W")
Some(SRAMCompilerSet(sp = sp, dp = dp))
- }
}
- def run(c: Circuit) = {
- def annotateModMems(m: Module) = {
- def updateStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") => {
- if (sramCompilers == None) m
- else {
- if (m.readwriters.length == 1)
- if (sramCompilers.get.sp == None) error("Design needs RW port memory compiler!")
- else sramCompilers.get.sp.get.append(m)
- else
- if (sramCompilers.get.dp == None) error("Design needs R,W port memory compiler!")
- else sramCompilers.get.dp.get.append(m)
- }
- }
- case b: Block => b map updateStmts
- case s => s
- }
- m.copy(body=updateStmts(m.body))
- }
- val updatedMods = c.modules map {
- case m: Module => annotateModMems(m)
- case m: ExtModule => m
+ def updateStmts(s: Statement): Statement = s match {
+ case m: DefMemory if containsInfo(m.info, "useMacro") => sramCompilers match {
+ case None => m
+ case Some(compiler) if (m.readwriters.length == 1) =>
+ compiler.sp match {
+ case None => error("Design needs RW port memory compiler!")
+ case Some(p) => p append m
+ }
+ case Some(compiler) =>
+ compiler.dp match {
+ case None => error("Design needs R,W port memory compiler!")
+ case Some(p) => p append m
+ }
}
- c.copy(modules = updatedMods)
- }
+ case s => s map updateStmts
+ }
+
+ def run(c: Circuit) = c copy (modules = (c.modules map (_ map updateStmts)))
}