aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
diff options
context:
space:
mode:
authorAngie2016-08-30 00:27:30 -0700
committerjackkoenig2016-09-06 00:17:18 -0700
commit6a05468ed0ece1ace3019666b16f2ae83ef76ef9 (patch)
tree5d4e4244c61845334184a45f4df960c2d7ccb313 /src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
parenta82f30d90940fd3c0386dee6f1ef21850c3c91c9 (diff)
Address style feedback and add tests for getConnectOrigin utility
Diffstat (limited to 'src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala')
-rw-r--r--src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala138
1 files changed, 67 insertions, 71 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
index a7f7703b..816a179b 100644
--- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
+++ b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
@@ -1,3 +1,5 @@
+// See LICENSE for license details.
+
package firrtl.passes
import firrtl.ir._
@@ -5,7 +7,8 @@ import firrtl._
import net.jcazevedo.moultingyaml._
import net.jcazevedo.moultingyaml.DefaultYamlProtocol._
import AnalysisUtils._
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+import firrtl.Mappers._
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
@@ -17,13 +20,12 @@ object CustomYAMLProtocol extends DefaultYamlProtocol {
}
case class DimensionRules(
- min: Int,
- // step size
- inc: Int,
- max: Int,
- // these values should not be used, regardless of min,inc,max
- illegal: Option[List[Int]]
-){
+ min: Int,
+ // step size
+ inc: Int,
+ max: Int,
+ // these values should not be used, regardless of min,inc,max
+ illegal: Option[List[Int]]) {
def getValid = {
val range = (min to max by inc).toList
range.filterNot(illegal.getOrElse(List[Int]()).toSet)
@@ -31,9 +33,8 @@ case class DimensionRules(
}
case class MemDimension(
- rules: Option[DimensionRules],
- set: Option[List[Int]]
-){
+ rules: Option[DimensionRules],
+ set: Option[List[Int]]) {
require (
if(rules == None) set != None else set == None,
"Should specify either rules or a list of valid options, but not both"
@@ -42,44 +43,42 @@ case class MemDimension(
}
case class SRAMConfig(
- ymux: String = "",
- ybank: String = "",
- width: Int,
- depth: Int,
- xsplit: Int = 1,
- ysplit: Int = 1
-){
+ ymux: String = "",
+ ybank: String = "",
+ width: Int,
+ depth: Int,
+ xsplit: Int = 1,
+ ysplit: Int = 1) {
// how many duplicate copies of this SRAM are needed
def num = xsplit * ysplit
def serialize(pattern: String): String = {
- val fieldMap = getClass.getDeclaredFields.map{f =>
+ val fieldMap = getClass.getDeclaredFields.map { f =>
f.setAccessible(true)
f.getName -> f.get(this)
- }.toMap
+ } toMap
val fieldDelimiter = """\[.*?\]""".r
val configOptions = fieldDelimiter.findAllIn(pattern).toList
- configOptions.foldLeft(pattern)((b,a) => {
+ configOptions.foldLeft(pattern)((b, a) => {
// Expects the contents of [] are valid configuration fields (otherwise key match error)
val fieldVal = {
- try fieldMap(a.substring(1,a.length-1))
- catch { case e: Exception => Error("**SRAM config field incorrect**") }
+ try fieldMap(a.substring(1, a.length-1))
+ catch { case e: Exception => error("**SRAM config field incorrect**") }
}
- b.replace(a,fieldVal.toString)
- })
+ b.replace(a, fieldVal.toString)
+ } )
}
}
// Ex: https://www.ece.cmu.edu/~ece548/hw/hw5/meml80.pdf
case class SRAMRules(
- // column mux parameter (for adjusting aspect ratio)
- ymux: (Int,String),
- // vertical segmentation (banking -- tradeoff performance / area)
- ybank: (Int,String),
- width: MemDimension,
- depth: MemDimension
-){
+ // column mux parameter (for adjusting aspect ratio)
+ ymux: (Int, String),
+ // vertical segmentation (banking -- tradeoff performance / area)
+ ybank: (Int, String),
+ width: MemDimension,
+ depth: MemDimension) {
def getValidWidths = width.getValid
def getValidDepths = depth.getValid
def getValidConfig(width: Int, depth: Int): Option[SRAMConfig] = {
@@ -88,13 +87,12 @@ case class SRAMRules(
else
None
}
- def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue,m.depth)
+ def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue, m.depth)
}
case class WMaskArg(
- t: String,
- f: String
-)
+ t: String,
+ f: String)
// vendor-specific compilers
case class SRAMCompiler(
@@ -116,8 +114,7 @@ case class SRAMCompiler(
defaultArgs: Option[String],
// default behavior (if not used) is to have wmask port width = datawidth/maskgran
// if true: wmask port width pre-filled to datawidth
- fillWMask: Boolean
-){
+ fillWMask: Boolean) {
require(portType == "RW" || portType == "R,W", "Memory must be single port RW or dual port R,W")
require(
(configFile != None && configPattern != None && wMaskArg != None) || configFile == None,
@@ -135,7 +132,7 @@ case class SRAMCompiler(
private val noMaskConfigOutputBuffer = new java.io.CharArrayWriter
def append(m: DefMemory) : DefMemory = {
- val validCombos = ArrayBuffer[SRAMConfig]()
+ val validCombos = mutable.ArrayBuffer[SRAMConfig]()
defaultSearchOrdering foreach { r =>
val config = r.getValidConfig(m)
if (config != None) validCombos += config.get
@@ -146,7 +143,7 @@ case class SRAMCompiler(
if (validCombos.nonEmpty) validCombos.head
else getBestAlternative(m)
}
- val usesMaskGran = containsInfo(m.info,"maskGran")
+ val usesMaskGran = containsInfo(m.info, "maskGran")
if (configPattern != None) {
val newConfig = usedConfig.serialize(configPattern.get) + "\n"
val currentBuff = {
@@ -156,8 +153,8 @@ case class SRAMCompiler(
if (!currentBuff.toString.contains(newConfig))
currentBuff.append(newConfig)
}
- val temp = appendInfo(m.info,"sramConfig" -> usedConfig)
- val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp,"maskGran" -> 1) else temp
+ val temp = appendInfo(m.info, "sramConfig" -> usedConfig)
+ val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp
m.copy(info = newInfo)
}
@@ -165,24 +162,24 @@ case class SRAMCompiler(
// handled w/ a separate set of registers ?
// split memory until width, depth achievable via given memory compiler
private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = {
- val validXRange = ArrayBuffer[SRAMRules]()
- val validYRange = ArrayBuffer[SRAMRules]()
+ val validXRange = mutable.ArrayBuffer[SRAMRules]()
+ val validYRange = mutable.ArrayBuffer[SRAMRules]()
defaultSearchOrdering foreach { r =>
if (m.width <= r.getValidWidths.max) validXRange += r
if (m.depth <= r.getValidDepths.max) validYRange += r
}
if (validXRange.isEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2,depth = m.depth/2))
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2))
else if (validXRange.isEmpty && validYRange.nonEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2,depth = m.depth))
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
else if (validXRange.nonEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width,depth = m.depth/2))
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2))
else if (validXRange.intersect(validYRange).nonEmpty)
Seq(m)
else
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width,depth = m.depth/2)) ++
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2,depth = m.depth))
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
}
private def getBestAlternative(m: DefMemory): SRAMConfig = {
@@ -190,39 +187,39 @@ case class SRAMCompiler(
val minNum = validConfigs.map(x => x.num).min
val validMinConfigs = validConfigs.filter(_.num == minNum)
val validMinConfigsSquareness = validMinConfigs.map(x => math.abs(x.width.toDouble/x.depth - 1) -> x).toMap
- val squarestAspectRatio = validMinConfigsSquareness.map(x => x._1).min
+ val squarestAspectRatio = validMinConfigsSquareness.map { case (aspectRatioDiff, _) => aspectRatioDiff } min
val validConfig = validMinConfigsSquareness(squarestAspectRatio)
- val validRules = ArrayBuffer[SRAMRules]()
+ val validRules = mutable.ArrayBuffer[SRAMRules]()
defaultSearchOrdering foreach { r =>
if (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) validRules += r
}
// TODO: don't just take first option
+ // TODO: More optimal split if particular value is in range but not supported
+ // TODO: Support up to 2 read ports, 2 write ports; should be power of 2?
val bestRule = validRules.head
val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get
val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get
bestRule.getValidConfig(width = memWidth, depth = memDepth).get.copy(xsplit = validConfig.xsplit, ysplit = validConfig.ysplit)
}
- def serialize() = {
- // TODO
- }
+ // TODO
+ def serialize = ???
}
// TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true?
// Or is this where module-level transforms (rather than circuit-level) make sense?
-class YamlFileReader(file: String){
+class YamlFileReader(file: String) {
import CustomYAMLProtocol._
def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
if (new java.io.File(file).exists) {
val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n")
- val optionOut = yamlString.parseYamls.map(x =>
+ yamlString.parseYamls.flatMap(x =>
try Some(reader.read(x))
- catch {case e: Exception => None}
+ catch { case e: Exception => None }
)
- optionOut.filter(_ != None).map(_.get)
}
- else Error("Yaml file doesn't exist!")
+ else error("Yaml file doesn't exist!")
}
}
@@ -233,7 +230,7 @@ class YamlFileWriter(file: String) {
def append(in: YamlValue) = {
outputBuffer.append(separator + in.prettyPrint)
}
- def serialize = {
+ def dump = {
val outputFile = new java.io.PrintWriter(file)
outputFile.write(outputBuffer.toString)
outputFile.close()
@@ -249,12 +246,11 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
// TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist
// TODO: Don't get first available?
case class SRAMCompilerSet(
- sp: Option[SRAMCompiler] = None,
- dp: Option[SRAMCompiler] = None
- ){
- def serialize() = {
- if (sp != None) sp.get.serialize()
- if (dp != None) dp.get.serialize()
+ sp: Option[SRAMCompiler] = None,
+ dp: Option[SRAMCompiler] = None) {
+ def serialize = {
+ if (sp != None) sp.get.serialize
+ if (dp != None) dp.get.serialize
}
}
val sramCompilers = {
@@ -270,18 +266,18 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
def run(c: Circuit) = {
def annotateModMems(m: Module) = {
def updateStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info,"useMacro") => {
+ case m: DefMemory if containsInfo(m.info, "useMacro") => {
if (sramCompilers == None) m
else {
if (m.readwriters.length == 1)
- if (sramCompilers.get.sp == None) Error("Design needs RW port memory compiler!")
+ if (sramCompilers.get.sp == None) error("Design needs RW port memory compiler!")
else sramCompilers.get.sp.get.append(m)
else
- if (sramCompilers.get.dp == None) Error("Design needs R,W port memory compiler!")
+ if (sramCompilers.get.dp == None) error("Design needs R,W port memory compiler!")
else sramCompilers.get.dp.get.append(m)
}
}
- case b: Block => Block(b.stmts map updateStmts)
+ case b: Block => b map updateStmts
case s => s
}
m.copy(body=updateStmts(m.body))
@@ -292,4 +288,4 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
}
c.copy(modules = updatedMods)
}
-} \ No newline at end of file
+}