diff options
| author | Donggyu Kim | 2016-09-16 17:34:56 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-21 13:19:42 -0700 |
| commit | b83203f00d11ca61017fbbc847c290e2d56e29e9 (patch) | |
| tree | b5f109cac4a5620fbd91a58c1629a225c6ff9ac5 /src | |
| parent | a142551bfcce6b05e445bc75dd284d994c8e91f2 (diff) | |
refactor AnnotateValidMemConfigs
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala | 174 |
1 files changed, 86 insertions, 88 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala index 816a179b..f80d4a0c 100644 --- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala +++ b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala @@ -2,13 +2,16 @@ package firrtl.passes -import firrtl.ir._ import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import Utils.error +import AnalysisUtils._ + import net.jcazevedo.moultingyaml._ import net.jcazevedo.moultingyaml.DefaultYamlProtocol._ -import AnalysisUtils._ import scala.collection.mutable -import firrtl.Mappers._ +import java.io.{File, CharArrayWriter, PrintWriter} object CustomYAMLProtocol extends DefaultYamlProtocol { // bottom depends on top @@ -36,7 +39,7 @@ case class MemDimension( rules: Option[DimensionRules], set: Option[List[Int]]) { require ( - if(rules == None) set != None else set == None, + if (rules == None) set != None else set == None, "Should specify either rules or a list of valid options, but not both" ) def getValid = set.getOrElse(rules.get.getValid).sorted @@ -128,15 +131,12 @@ case class SRAMCompiler( rules.find(r => r.ymux._1 == x && r.ybank._1 == y).get } - private val maskConfigOutputBuffer = new java.io.CharArrayWriter - private val noMaskConfigOutputBuffer = new java.io.CharArrayWriter + private val maskConfigOutputBuffer = new CharArrayWriter + private val noMaskConfigOutputBuffer = new CharArrayWriter def append(m: DefMemory) : DefMemory = { - val validCombos = mutable.ArrayBuffer[SRAMConfig]() - defaultSearchOrdering foreach { r => - val config = r.getValidConfig(m) - if (config != None) validCombos += config.get - } + val validCombos = (defaultSearchOrdering map (_ getValidConfig m) + collect { case Some(config) => config }) // non empty if successfully found compiler option that supports depth/width // TODO: don't just take first option val usedConfig = { @@ -144,18 +144,19 @@ case class SRAMCompiler( else getBestAlternative(m) } val usesMaskGran = containsInfo(m.info, "maskGran") - if (configPattern != None) { - val newConfig = usedConfig.serialize(configPattern.get) + "\n" - val currentBuff = { - if (usesMaskGran) maskConfigOutputBuffer - else noMaskConfigOutputBuffer - } - if (!currentBuff.toString.contains(newConfig)) - currentBuff.append(newConfig) + configPattern match { + case None => + case Some(p) => + val newConfig = usedConfig.serialize(p) + "\n" + val currentBuff = { + if (usesMaskGran) maskConfigOutputBuffer + else noMaskConfigOutputBuffer + } + if (!currentBuff.toString.contains(newConfig)) currentBuff append newConfig } val temp = appendInfo(m.info, "sramConfig" -> usedConfig) - val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp - m.copy(info = newInfo) + val newInfo = if (usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp + m copy (info = newInfo) } // TODO: Should you really be splitting in 2 if, say, depth is 1 more than allowed? should be thresholded and @@ -164,47 +165,47 @@ case class SRAMCompiler( private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = { val validXRange = mutable.ArrayBuffer[SRAMRules]() val validYRange = mutable.ArrayBuffer[SRAMRules]() - defaultSearchOrdering foreach { r => + defaultSearchOrdering foreach { r => if (m.width <= r.getValidWidths.max) validXRange += r if (m.depth <= r.getValidDepths.max) validYRange += r } - - if (validXRange.isEmpty && validYRange.isEmpty) - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2)) - else if (validXRange.isEmpty && validYRange.nonEmpty) - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) - else if (validXRange.nonEmpty && validYRange.isEmpty) - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) - else if (validXRange.intersect(validYRange).nonEmpty) - Seq(m) - else - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++ + (validXRange.isEmpty, validYRange.isEmpty) match { + case (true, true) => + getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2)) + case (true, false) => + getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) + case (false, true) => + getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) + case (false, false) if validXRange.intersect(validYRange).nonEmpty => + Seq(m) + case (false, false) => + getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) + } } private def getBestAlternative(m: DefMemory): SRAMConfig = { val validConfigs = getInRange(SRAMConfig(width = bitWidth(m.dataType).intValue, depth = m.depth)) - val minNum = validConfigs.map(x => x.num).min + val minNum = validConfigs.map(_.num).min val validMinConfigs = validConfigs.filter(_.num == minNum) - val validMinConfigsSquareness = validMinConfigs.map(x => math.abs(x.width.toDouble/x.depth - 1) -> x).toMap - val squarestAspectRatio = validMinConfigsSquareness.map { case (aspectRatioDiff, _) => aspectRatioDiff } min + val validMinConfigsSquareness = validMinConfigs.map( + x => math.abs(x.width.toDouble / x.depth - 1) -> x).toMap + val squarestAspectRatio = validMinConfigsSquareness.unzip._1.min val validConfig = validMinConfigsSquareness(squarestAspectRatio) - val validRules = mutable.ArrayBuffer[SRAMRules]() - defaultSearchOrdering foreach { r => - if (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) validRules += r - } + val validRules = defaultSearchOrdering filter (r => + (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max)) // TODO: don't just take first option // TODO: More optimal split if particular value is in range but not supported // TODO: Support up to 2 read ports, 2 write ports; should be power of 2? val bestRule = validRules.head val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get - bestRule.getValidConfig(width = memWidth, depth = memDepth).get.copy(xsplit = validConfig.xsplit, ysplit = validConfig.ysplit) + (bestRule.getValidConfig(width = memWidth, depth = memDepth).get + copy (xsplit = validConfig.xsplit, ysplit = validConfig.ysplit)) } // TODO def serialize = ??? - } // TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true? @@ -212,11 +213,11 @@ case class SRAMCompiler( class YamlFileReader(file: String) { import CustomYAMLProtocol._ def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { - if (new java.io.File(file).exists) { + if (new File(file).exists) { val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n") - yamlString.parseYamls.flatMap(x => - try Some(reader.read(x)) - catch { case e: Exception => None } + yamlString.parseYamls flatMap (x => + try Some(reader read x) + catch { case e: Exception => None } ) } else error("Yaml file doesn't exist!") @@ -225,22 +226,20 @@ class YamlFileReader(file: String) { class YamlFileWriter(file: String) { import CustomYAMLProtocol._ - val outputBuffer = new java.io.CharArrayWriter + val outputBuffer = new CharArrayWriter val separator = "--- \n" - def append(in: YamlValue) = { - outputBuffer.append(separator + in.prettyPrint) + def append(in: YamlValue) { + outputBuffer append s"$separator${in.prettyPrint}" } - def dump = { - val outputFile = new java.io.PrintWriter(file) - outputFile.write(outputBuffer.toString) - outputFile.close() + def dump { + val outputFile = new PrintWriter(file) + outputFile write outputBuffer.toString + outputFile.close } } class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { - import CustomYAMLProtocol._ - def name = "Annotate memories with valid split depths, widths, #\'s" // TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist @@ -249,43 +248,42 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { sp: Option[SRAMCompiler] = None, dp: Option[SRAMCompiler] = None) { def serialize = { - if (sp != None) sp.get.serialize - if (dp != None) dp.get.serialize + sp match { + case None => + case Some(p) => p.serialize + } + dp match { + case None => + case Some(p) => p.serialize + } } } - val sramCompilers = { - if (reader == None) None - else { - val compilers = reader.get.parse[SRAMCompiler] - val sp = compilers.find(_.portType == "RW") - val dp = compilers.find(_.portType == "R,W") + + val sramCompilers = reader match { + case None => None + case Some(r) => + val compilers = r.parse[SRAMCompiler] + val sp = compilers find (_.portType == "RW") + val dp = compilers find (_.portType == "R,W") Some(SRAMCompilerSet(sp = sp, dp = dp)) - } } - def run(c: Circuit) = { - def annotateModMems(m: Module) = { - def updateStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => { - if (sramCompilers == None) m - else { - if (m.readwriters.length == 1) - if (sramCompilers.get.sp == None) error("Design needs RW port memory compiler!") - else sramCompilers.get.sp.get.append(m) - else - if (sramCompilers.get.dp == None) error("Design needs R,W port memory compiler!") - else sramCompilers.get.dp.get.append(m) - } - } - case b: Block => b map updateStmts - case s => s - } - m.copy(body=updateStmts(m.body)) - } - val updatedMods = c.modules map { - case m: Module => annotateModMems(m) - case m: ExtModule => m + def updateStmts(s: Statement): Statement = s match { + case m: DefMemory if containsInfo(m.info, "useMacro") => sramCompilers match { + case None => m + case Some(compiler) if (m.readwriters.length == 1) => + compiler.sp match { + case None => error("Design needs RW port memory compiler!") + case Some(p) => p append m + } + case Some(compiler) => + compiler.dp match { + case None => error("Design needs R,W port memory compiler!") + case Some(p) => p append m + } } - c.copy(modules = updatedMods) - } + case s => s map updateStmts + } + + def run(c: Circuit) = c copy (modules = (c.modules map (_ map updateStmts))) } |
