diff options
| author | Adam Izraelevitz | 2016-10-17 18:53:19 -0700 |
|---|---|---|
| committer | Angie Wang | 2016-10-17 18:53:19 -0700 |
| commit | 85baeda249e59c7d9d9f159aaf29ff46d685cf02 (patch) | |
| tree | cfb5f4a6a0a80f9033275de6e5e36b9d5b96faad /src | |
| parent | 7d08b9a1486fef0459481f6e542464a29fbe1db5 (diff) | |
Reorganized memory blackboxing (#336)
* Reorganized memory blackboxing
Moved to new package memlib
Added comments
Moved utility functions around
Removed unused AnnotateValidMemConfigs.scala
* Fixed tests to pass
* Use DefAnnotatedMemory instead of AppendableInfo
* Broke passes up into simpler passes
AnnotateMemMacros ->
(ToMemIR, ResolveMaskGranularity)
UpdateDuplicateMemMacros ->
(RenameAnnotatedMemoryPorts, ResolveMemoryReference)
* Fixed to make tests run
* Minor changes from code review
* Removed vim comments and renamed ReplSeqMem
Diffstat (limited to 'src')
17 files changed, 533 insertions, 637 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 9e4c2ec0..2a5a379a 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -81,7 +81,7 @@ Optional Arguments: passes.InferReadWriteAnnotation(value, TransID(-1)) def handleReplSeqMem(value: String) = - passes.ReplSeqMemAnnotation(value, TransID(-2)) + passes.memlib.ReplSeqMemAnnotation(value, TransID(-2)) run(args: Array[String], Map( "high" -> new HighFirrtlCompiler(), diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 307ef9d1..53491922 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -192,7 +192,7 @@ class LowFirrtlCompiler extends Compiler { new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new passes.InferReadWrite(TransID(-1)), - new passes.ReplSeqMem(TransID(-2)), + new passes.memlib.ReplSeqMem(TransID(-2)), new MiddleFirrtlToLowFirrtl, new EmitFirrtl(writer) ) @@ -206,7 +206,7 @@ class VerilogCompiler extends Compiler { new ResolveAndCheck, new HighFirrtlToMiddleFirrtl, new passes.InferReadWrite(TransID(-1)), - new passes.ReplSeqMem(TransID(-2)), + new passes.memlib.ReplSeqMem(TransID(-2)), new MiddleFirrtlToLowFirrtl, new passes.InlineInstances(TransID(0)), new EmitVerilogFromLowFirrtl(writer) diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala deleted file mode 100644 index 21287922..00000000 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ /dev/null @@ -1,143 +0,0 @@ -// See LICENSE for license details. - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.Mappers._ -import WrappedExpression.weq -import MemPortUtils.memPortField -import AnalysisUtils._ - -case class AppendableInfo(fields: Map[String, Any]) extends Info { - def append(a: Map[String, Any]) = this.copy(fields = fields ++ a) - def append(a: (String, Any)): AppendableInfo = append(Map(a)) - def get(f: String) = fields.get(f) - override def equals(b: Any) = b match { - case i: AppendableInfo => fields - "info" == i.fields - "info" - case _ => false - } -} - -object AnalysisUtils { - type Connects = collection.mutable.HashMap[String, Expression] - def getConnects(m: DefModule): Connects = { - def getConnects(connects: Connects)(s: Statement): Statement = { - s match { - case Connect(_, loc, expr) => - connects(loc.serialize) = expr - case DefNode(_, name, value) => - connects(name) = value - case _ => // do nothing - } - s map getConnects(connects) - } - val connects = new Connects - m map getConnects(connects) - connects - } - - // takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS. - // if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation, - // the function will try to search backwards past the primop/mux. - // use case: compare if two nodes have the same origin - // limitation: only works in a module (stops @ module inputs) - // TODO: more thorough (i.e. a + 0 = a) - def getConnectOrigin(connects: Connects)(node: String): Expression = - connects get node match { - case None => EmptyExpression - case Some(e) => getOrigin(connects, e) - } - def getConnectOrigin(connects: Connects, e: Expression): Expression = - getConnectOrigin(connects)(e.serialize) - - private def getOrigin(connects: Connects, e: Expression): Expression = e match { - case Mux(cond, tv, fv, _) => - val fvOrigin = getOrigin(connects, fv) - val tvOrigin = getOrigin(connects, tv) - val condOrigin = getOrigin(connects, cond) - if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin - else if (weq(condOrigin, one)) tvOrigin - else if (weq(condOrigin, zero)) fvOrigin - else if (weq(tvOrigin, fvOrigin)) tvOrigin - else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin - else e - case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one - case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero - case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => - val extractionWidth = (msb - lsb) + 1 - val nodeWidth = bitWidth(args.head.tpe) - // if you're extracting the full bitwidth, then keep searching for origin - if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e - case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => - getOrigin(connects, args.head) - // Todo: It's not clear it's ok to call remove validifs before mem passes... - case ValidIf(cond, value, ClockType) => getOrigin(connects, value) - // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) - case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => - connects get e.serialize match { - case Some(ex) => getOrigin(connects, ex) - case None => e - } - case _ => e - } - - def appendInfo[T <: Info](info: T, add: Map[String, Any]) = info match { - case i: AppendableInfo => i.append(add) - case _ => AppendableInfo(fields = add + ("info" -> info)) - } - def appendInfo[T <: Info](info: T, add: (String, Any)): AppendableInfo = appendInfo(info, Map(add)) - def getInfo[T <: Info](info: T, k: String) = info match { - case i: AppendableInfo => i.get(k) - case _ => None - } - def containsInfo[T <: Info](info: T, k: String) = info match { - case i: AppendableInfo => i.fields.contains(k) - case _ => false - } - - // memories equivalent as long as all fields (except name) are the same - def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name) -} - -object AnnotateMemMacros extends Pass { - def name = "Analyze sequential memories and tag with info for future passes(useMacro, maskGran)" - - // returns # of mask bits if used - def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { - val wenOrigin = getConnectOrigin(connects, wen) - val wmaskOrigin = connects.keys filter - (_ startsWith wmask.serialize) map getConnectOrigin(connects) - // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) - val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) - if (redundantMask) None else Some(wmaskOrigin.size) - } - - def updateStmts(connects: Connects)(s: Statement): Statement = s match { - // only annotate memories that are candidates for memory macro replacements - // i.e. rw, w + r (read, write 1 cycle delay) - case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && - (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => - val dataBits = bitWidth(m.dataType) - val rwMasks = m.readwriters map (rw => - getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) - val wMasks = m.writers map (w => - getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) - val memAnnotations = Map("useMacro" -> true) - val tempInfo = appendInfo(m.info, memAnnotations) - (rwMasks ++ wMasks).head match { - case None => - m copy (info = tempInfo) - case Some(maskBits) => - m.copy(info = tempInfo.append("maskGran" -> dataBits / maskBits)) - } - case sx => sx map updateStmts(connects) - } - - def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m)) - - def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) -} - -// TODO: Add floorplan info? diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala deleted file mode 100644 index b5149953..00000000 --- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala +++ /dev/null @@ -1,288 +0,0 @@ -// See LICENSE for license details. - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Mappers._ -import Utils.error -import AnalysisUtils._ - -import net.jcazevedo.moultingyaml._ -import scala.collection.mutable -import java.io.{File, CharArrayWriter, PrintWriter} - -object CustomYAMLProtocol extends DefaultYamlProtocol { - // bottom depends on top - implicit val dr = yamlFormat4(DimensionRules) - implicit val md = yamlFormat2(MemDimension) - implicit val sr = yamlFormat4(SRAMRules) - implicit val wm = yamlFormat2(WMaskArg) - implicit val sc = yamlFormat11(SRAMCompiler) -} - -case class DimensionRules( - min: Int, - // step size - inc: Int, - max: Int, - // these values should not be used, regardless of min,inc,max - illegal: Option[List[Int]]) { - def getValid = { - val range = (min to max by inc).toList - range.filterNot(illegal.getOrElse(List[Int]()).toSet) - } -} - -case class MemDimension( - rules: Option[DimensionRules], - set: Option[List[Int]]) { - require ( - if (rules.isEmpty) set.isDefined else set.isEmpty, - "Should specify either rules or a list of valid options, but not both" - ) - def getValid = set.getOrElse(rules.get.getValid).sorted -} - -case class SRAMConfig( - ymux: String = "", - ybank: String = "", - width: Int, - depth: Int, - xsplit: Int = 1, - ysplit: Int = 1) { - // how many duplicate copies of this SRAM are needed - def num = xsplit * ysplit - def serialize(pattern: String): String = { - val fieldMap = getClass.getDeclaredFields.map { f => - f.setAccessible(true) - f.getName -> f.get(this) - }.toMap - - val fieldDelimiter = """\[.*?\]""".r - val configOptions = fieldDelimiter.findAllIn(pattern).toList - - configOptions.foldLeft(pattern)((b, a) => { - // Expects the contents of [] are valid configuration fields (otherwise key match error) - val fieldVal = { - try fieldMap(a.substring(1, a.length-1)) - catch { case e: Exception => error("**SRAM config field incorrect**") } - } - b.replace(a, fieldVal.toString) - } ) - } -} - -// Ex: https://www.ece.cmu.edu/~ece548/hw/hw5/meml80.pdf -case class SRAMRules( - // column mux parameter (for adjusting aspect ratio) - ymux: (Int, String), - // vertical segmentation (banking -- tradeoff performance / area) - ybank: (Int, String), - width: MemDimension, - depth: MemDimension) { - def getValidWidths = width.getValid - def getValidDepths = depth.getValid - def getValidConfig(width: Int, depth: Int): Option[SRAMConfig] = { - if (getValidWidths.contains(width) && getValidDepths.contains(depth)) - Some(SRAMConfig(ymux = ymux._2, ybank = ybank._2, width = width, depth = depth)) - else - None - } - def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue, m.depth) -} - -case class WMaskArg( - t: String, - f: String) - -// vendor-specific compilers -case class SRAMCompiler( - vendor: String, - node: String, - // i.e. RF, SRAM, etc. - memType: String, - portType: String, - wMaskArg: Option[WMaskArg], - // rules for valid SRAM flavors - rules: Seq[SRAMRules], - // path to executable - path: Option[String], - // (output) config file path - configFile: Option[String], - // config pattern - configPattern: Option[String], - // read documentation for details - defaultArgs: Option[String], - // default behavior (if not used) is to have wmask port width = datawidth/maskgran - // if true: wmask port width pre-filled to datawidth - fillWMask: Boolean) { - require(portType == "RW" || portType == "R,W", "Memory must be single port RW or dual port R,W") - require( - (configFile.isDefined && configPattern.isDefined && wMaskArg.isDefined) || configFile.isEmpty, - "Config pattern must be provided with config file" - ) - def ymuxVals = rules.map(_.ymux._1).sortWith(_ < _) - def ybankVals = rules.map(_.ybank._1).sortWith(_ > _) - // TODO: verify this default ordering works out - // optimize search for better FoM (area,power,clk); ymux has more effect - def defaultSearchOrdering = for (x <- ymuxVals; y <- ybankVals) yield { - rules.find(r => r.ymux._1 == x && r.ybank._1 == y).get - } - - private val maskConfigOutputBuffer = new CharArrayWriter - private val noMaskConfigOutputBuffer = new CharArrayWriter - - def append(m: DefMemory) : DefMemory = { - val validCombos = (defaultSearchOrdering map (_ getValidConfig m) - collect { case Some(config) => config }) - // non empty if successfully found compiler option that supports depth/width - // TODO: don't just take first option - val usedConfig = { - if (validCombos.nonEmpty) validCombos.head - else getBestAlternative(m) - } - val usesMaskGran = containsInfo(m.info, "maskGran") - configPattern match { - case None => - case Some(p) => - val newConfig = usedConfig.serialize(p) + "\n" - val currentBuff = { - if (usesMaskGran) maskConfigOutputBuffer - else noMaskConfigOutputBuffer - } - if (!currentBuff.toString.contains(newConfig)) currentBuff append newConfig - } - val temp = appendInfo(m.info, "sramConfig" -> usedConfig) - val newInfo = if (usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp - m copy (info = newInfo) - } - - // TODO: Should you really be splitting in 2 if, say, depth is 1 more than allowed? should be thresholded and - // handled w/ a separate set of registers ? - // split memory until width, depth achievable via given memory compiler - private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = { - val validXRange = mutable.ArrayBuffer[SRAMRules]() - val validYRange = mutable.ArrayBuffer[SRAMRules]() - defaultSearchOrdering foreach { r => - if (m.width <= r.getValidWidths.max) validXRange += r - if (m.depth <= r.getValidDepths.max) validYRange += r - } - (validXRange.isEmpty, validYRange.isEmpty) match { - case (true, true) => - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2)) - case (true, false) => - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) - case (false, true) => - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) - case (false, false) if validXRange.intersect(validYRange).nonEmpty => - Seq(m) - case (false, false) => - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++ - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) - } - } - - private def getBestAlternative(m: DefMemory): SRAMConfig = { - val validConfigs = getInRange(SRAMConfig(width = bitWidth(m.dataType).intValue, depth = m.depth)) - val minNum = validConfigs.map(_.num).min - val validMinConfigs = validConfigs.filter(_.num == minNum) - val validMinConfigsSquareness = validMinConfigs.map( - x => math.abs(x.width.toDouble / x.depth - 1) -> x).toMap - val squarestAspectRatio = validMinConfigsSquareness.unzip._1.min - val validConfig = validMinConfigsSquareness(squarestAspectRatio) - val validRules = defaultSearchOrdering filter (r => - validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) - // TODO: don't just take first option - // TODO: More optimal split if particular value is in range but not supported - // TODO: Support up to 2 read ports, 2 write ports; should be power of 2? - val bestRule = validRules.head - val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get - val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get - (bestRule.getValidConfig(width = memWidth, depth = memDepth).get - copy (xsplit = validConfig.xsplit, ysplit = validConfig.ysplit)) - } - - // TODO - def serialize = ??? -} - -// TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true? -// Or is this where module-level transforms (rather than circuit-level) make sense? -class YamlFileReader(file: String) { - import CustomYAMLProtocol._ - def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { - if (new File(file).exists) { - val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n") - yamlString.parseYamls flatMap (x => - try Some(reader read x) - catch { case e: Exception => None } - ) - } - else error("Yaml file doesn't exist!") - } -} - -class YamlFileWriter(file: String) { - import CustomYAMLProtocol._ - val outputBuffer = new CharArrayWriter - val separator = "--- \n" - def append(in: YamlValue) { - outputBuffer append s"$separator${in.prettyPrint}" - } - def dump() { - val outputFile = new PrintWriter(file) - outputFile write outputBuffer.toString - outputFile.close() - } -} - -class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { - import CustomYAMLProtocol._ - def name = "Annotate memories with valid split depths, widths, #\'s" - - // TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist - // TODO: Don't get first available? - case class SRAMCompilerSet( - sp: Option[SRAMCompiler] = None, - dp: Option[SRAMCompiler] = None) { - def serialize() = { - sp match { - case None => - case Some(p) => p.serialize - } - dp match { - case None => - case Some(p) => p.serialize - } - } - } - - val sramCompilers = reader match { - case None => None - case Some(r) => - val compilers = r.parse[SRAMCompiler] - val sp = compilers find (_.portType == "RW") - val dp = compilers find (_.portType == "R,W") - Some(SRAMCompilerSet(sp = sp, dp = dp)) - } - - def updateStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => sramCompilers match { - case None => m - case Some(compiler) if m.readwriters.length == 1 => - compiler.sp match { - case None => error("Design needs RW port memory compiler!") - case Some(p) => p append m - } - case Some(compiler) => - compiler.dp match { - case None => error("Design needs R,W port memory compiler!") - case Some(p) => p append m - } - } - case sx => sx map updateStmts - } - - def run(c: Circuit) = c copy (modules = c.modules map (_ map updateStmts)) -} diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala index a1875ae7..9adbdd95 100644 --- a/src/main/scala/firrtl/passes/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/InferReadWrite.scala @@ -32,8 +32,9 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} +import firrtl.passes.memlib._ import MemPortUtils.memPortField -import AnalysisUtils.{Connects, getConnects, getConnectOrigin} +import AnalysisUtils.{Connects, getConnects, getOrigin} import WrappedExpression.weq import Annotations._ @@ -117,8 +118,8 @@ object InferReadWritePass extends Pass { for (w <- mem.writers ; r <- mem.readers) { val wp = getProductTerms(connects)(memPortField(mem, w, "en")) val rp = getProductTerms(connects)(memPortField(mem, r, "en")) - val wclk = getConnectOrigin(connects, memPortField(mem, w, "clk")) - val rclk = getConnectOrigin(connects, memPortField(mem, r, "clk")) + val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) + val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) { val rw = namespace newName "rw" val rwExp = createSubField(createRef(mem.name), rw) diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala index 92673433..8cd58afb 100644 --- a/src/main/scala/firrtl/passes/MemUtils.scala +++ b/src/main/scala/firrtl/passes/MemUtils.scala @@ -43,39 +43,55 @@ object seqCat { } } +/** Given an expression, return an expression consisting of all sub-expressions + * concatenated (or flattened). + */ object toBits { def apply(e: Expression): Expression = e match { - case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex, ex.tpe) + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex) case t => error("Invalid operand expression for toBits!") } - private def hiercat(e: Expression, dt: Type): Expression = dt match { + private def hiercat(e: Expression): Expression = e.tpe match { case t: VectorType => seqCat((0 until t.size) map (i => - hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER),t.tpe))) + hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER)))) case t: BundleType => seqCat(t.fields map (f => - hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe))) + hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER)))) case t: GroundType => e case t => error("Unknown type encountered in toBits!") } } -// TODO: make easier to understand +/** Given a mask, return a bitmask corresponding to the desired datatype. + * Requirements: + * - The mask type and datatype must be equivalent, except any ground type in + * datatype must be matched by a 1-bit wide UIntType. + * - The mask must be a reference, subfield, or subindex + * The bitmask is a series of concatenations of the single mask bit over the + * length of the corresponding ground type, e.g.: + *{{{ + * wire mask: {x: UInt<1>, y: UInt<1>} + * wire data: {x: UInt<2>, y: SInt<2>} + * // this would return: + * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) + * }}} + */ object toBitMask { - def apply(e: Expression, dataType: Type): Expression = e match { - case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, ex.tpe, dataType) + def apply(mask: Expression, dataType: Type): Expression = mask match { + case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType) case t => error("Invalid operand expression for toBits!") } - private def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = - (maskType, dataType) match { + private def hiermask(mask: Expression, dataType: Type): Expression = + (mask.tpe, dataType) match { case (mt: VectorType, dt: VectorType) => seqCat((0 until mt.size).reverse map { i => - hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe) + hiermask(WSubIndex(mask, i, mt.tpe, UNKNOWNGENDER), dt.tpe) }) case (mt: BundleType, dt: BundleType) => seqCat((mt.fields zip dt.fields) map { case (mf, df) => - hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) + hiermask(WSubField(mask, mf.name, mf.tpe, UNKNOWNGENDER), df.tpe) }) - case (mt: UIntType, dt: GroundType) => - seqCat(List.fill(bitWidth(dt).intValue)(e)) + case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) => + seqCat(List.fill(bitWidth(dt).intValue)(mask)) case (mt, dt) => error("Invalid type for mask component!") } } @@ -153,7 +169,7 @@ object createSubField { } object connectFields { - def apply(lref: Expression, lname: String, rref: Expression, rname: String) = + def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect = Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) } @@ -166,14 +182,14 @@ object MemPortUtils { type Memories = collection.mutable.ArrayBuffer[DefMemory] type Modules = collection.mutable.ArrayBuffer[DefModule] - def defaultPortSeq(mem: DefMemory) = Seq( + def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))), Field("en", Default, BoolType), Field("clk", Default, ClockType) ) // Todo: merge it with memToBundle - def memType(mem: DefMemory) = { + def memType(mem: DefMemory): Type = { val rType = BundleType(defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) val wType = BundleType(defaultPortSeq(mem) ++ Seq( @@ -190,7 +206,7 @@ object MemPortUtils { (mem.readwriters map (Field(_, Flip, rwType)))) } - def memPortField(s: DefMemory, p: String, f: String) = { + def memPortField(s: DefMemory, p: String, f: String): Expression = { val mem = WRef(s.name, memType(s), MemKind, UNKNOWNGENDER) val t1 = field_type(mem.tpe, p) val t2 = field_type(t1, f) diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala deleted file mode 100644 index 2e6c3338..00000000 --- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala +++ /dev/null @@ -1,155 +0,0 @@ -// See LICENSE for license details. - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.Mappers._ -import AnalysisUtils._ -import MemPortUtils._ -import MemTransformUtils._ - -object MemTransformUtils { - def getFillWMask(mem: DefMemory) = - getInfo(mem.info, "maskGran") match { - case None => false - case Some(maskGran) => maskGran == 1 - } - - def rPortToBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) - def rPortToFlattenBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) - - def wPortToBundle(mem: DefMemory) = BundleType( - (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ - (if (!containsInfo(mem.info, "maskGran")) Nil - else Seq(Field("mask", Default, createMask(mem.dataType)))) - ) - def wPortToFlattenBundle(mem: DefMemory) = BundleType( - (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ - (if (!containsInfo(mem.info, "maskGran")) Nil - else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType))) - else Seq(Field("mask", Default, flattenType(createMask(mem.dataType))))) - ) - // TODO: Don't use createMask??? - - def rwPortToBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) ++ Seq( - Field("wmode", Default, BoolType), - Field("wdata", Default, mem.dataType), - Field("rdata", Flip, mem.dataType) - ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil - else Seq(Field("wmask", Default, createMask(mem.dataType))) - ) - ) - - def rwPortToFlattenBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) ++ Seq( - Field("wmode", Default, BoolType), - Field("wdata", Default, flattenType(mem.dataType)), - Field("rdata", Flip, flattenType(mem.dataType)) - ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil - else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType))) - else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType)))) - ) - ) - - def memToBundle(s: DefMemory) = BundleType( - s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) - - def memToFlattenBundle(s: DefMemory) = BundleType( - s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) - - - def getMemPortMap(m: DefMemory) = { - val memPortMap = new MemPortMap - val defaultFields = Seq("addr", "en", "clk") - val rFields = defaultFields :+ "data" - val wFields = rFields :+ "mask" - val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask") - - def updateMemPortMap(ports: Seq[String], fields: Seq[String], portType: String) = - for ((p, i) <- ports.zipWithIndex; f <- fields) { - val newPort = createSubField(createRef(m.name), portType+i) - val field = createSubField(newPort, f) - memPortMap(s"${m.name}.$p.$f") = field - } - updateMemPortMap(m.readers, rFields, "R") - updateMemPortMap(m.writers, wFields, "W") - updateMemPortMap(m.readwriters, rwFields, "RW") - memPortMap - } - - def createMemProto(m: DefMemory) = { - val rports = m.readers.indices map (i => s"R$i") - val wports = m.writers.indices map (i => s"W$i") - val rwports = m.readwriters.indices map (i => s"RW$i") - m copy (readers = rports, writers = wports, readwriters = rwports) - } - - def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { - def updateRef(e: Expression): Expression = { - val ex = e map updateRef - repl getOrElse (ex.serialize, ex) - } - - def hasEmptyExpr(stmt: Statement): Boolean = { - var foundEmpty = false - def testEmptyExpr(e: Expression): Expression = { - e match { - case EmptyExpression => foundEmpty = true - case _ => - } - e map testEmptyExpr // map must return; no foreach - } - stmt map testEmptyExpr - foundEmpty - } - - def updateStmtRefs(s: Statement): Statement = - s map updateStmtRefs map updateRef match { - case c: Connect if hasEmptyExpr(c) => EmptyStmt - case sx => sx - } - - updateStmtRefs(s) - } - -} - -object UpdateDuplicateMemMacros extends Pass { - - def name = "Convert memory port names to be more meaningful and tag duplicate memories" - - def updateMemStmts(uniqueMems: Memories, - memPortMap: MemPortMap) - (s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => - val updatedMem = createMemProto(m) - memPortMap ++= getMemPortMap(m) - uniqueMems find (x => eqMems(x, updatedMem)) match { - case None => - uniqueMems += updatedMem - updatedMem - case Some(proto) => - updatedMem copy (info = appendInfo(updatedMem.info, "ref" -> proto.name)) - } - case sx => sx map updateMemStmts(uniqueMems, memPortMap) - } - - def updateMemMods(m: DefModule) = { - val uniqueMems = new Memories - val memPortMap = new MemPortMap - (m map updateMemStmts(uniqueMems, memPortMap) - map updateStmtRefs(memPortMap)) - } - - def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) -} -// TODO: Module namespace? diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala new file mode 100644 index 00000000..6dca5961 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import Utils.indent + +case class DefAnnotatedMemory( + info: Info, + name: String, + dataType: Type, + depth: Int, + writeLatency: Int, + readLatency: Int, + readers: Seq[String], + writers: Seq[String], + readwriters: Seq[String], + readUnderWrite: Option[String], + maskGran: Option[BigInt], + memRef: Option[String] + //pins: Seq[Pin], + ) extends Statement with IsDeclaration { + def serialize: String = this.toMem.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def toMem = DefMemory(info, name, dataType, depth, + writeLatency, readLatency, readers, writers, + readwriters, readUnderWrite) +} diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala new file mode 100644 index 00000000..78a386b2 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala @@ -0,0 +1,47 @@ +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import AnalysisUtils._ +import MemPortUtils.{MemPortMap} + +object MemTransformUtils { + + /** Replaces references to old memory port names with new memory port names + */ + def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { + //TODO(izraelevitz): check speed + def updateRef(e: Expression): Expression = { + val ex = e map updateRef + repl getOrElse (ex.serialize, ex) + } + + def hasEmptyExpr(stmt: Statement): Boolean = { + var foundEmpty = false + def testEmptyExpr(e: Expression): Expression = { + e match { + case EmptyExpression => foundEmpty = true + case _ => + } + e map testEmptyExpr // map must return; no foreach + } + stmt map testEmptyExpr + foundEmpty + } + + def updateStmtRefs(s: Statement): Statement = + s map updateStmtRefs map updateRef match { + case c: Connect if hasEmptyExpr(c) => EmptyStmt + case s => s + } + + updateStmtRefs(s) + } + + def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem) + def memPortField(s: DefAnnotatedMemory, p: String, f: String): Expression = + MemPortUtils.memPortField(s.toMem, p, f) +} diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala new file mode 100644 index 00000000..168a6a48 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -0,0 +1,77 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import AnalysisUtils._ +import MemPortUtils._ +import MemTransformUtils._ + + +/** Changes memory port names to standard port names (i.e. RW0 instead T_408) + */ +object RenameAnnotatedMemoryPorts extends Pass { + + def name = "Rename Annotated Memory Ports" + + /** Renames memory ports to a standard naming scheme: + * - R0, R1, ... for each read port + * - W0, W1, ... for each write port + * - RW0, RW1, ... for each readwrite port + */ + def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = { + val rports = m.readers.indices map (i => s"R$i") + val wports = m.writers.indices map (i => s"W$i") + val rwports = m.readwriters.indices map (i => s"RW$i") + m copy (readers = rports, writers = wports, readwriters = rwports) + } + + /** Maps the serialized form of all memory port field names to the + * corresponding new memory port field Expression. + * E.g.: + * - ("m.read.addr") becomes (m.R0.addr) + */ + def getMemPortMap(m: DefAnnotatedMemory): MemPortMap = { + val memPortMap = new MemPortMap + val defaultFields = Seq("addr", "en", "clk") + val rFields = defaultFields :+ "data" + val wFields = rFields :+ "mask" + val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask") + + def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit = + for ((p, i) <- ports.zipWithIndex; f <- fields) { + val newPort = createSubField(createRef(m.name), newPortKind + i) + val field = createSubField(newPort, f) + memPortMap(s"${m.name}.$p.$f") = field + } + updateMemPortMap(m.readers, rFields, "R") + updateMemPortMap(m.writers, wFields, "W") + updateMemPortMap(m.readwriters, rwFields, "RW") + memPortMap + } + + /** Replaces candidate memories with memories with standard port names + * Does not update the references (this is done via updateStmtRefs) + */ + def updateMemStmts(memPortMap: MemPortMap)(s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + val updatedMem = createMemProto(m) + memPortMap ++= getMemPortMap(m) + updatedMem + case s => s map updateMemStmts(memPortMap) + } + + /** Replaces candidate memories and their references with standard port names + */ + def updateMemMods(m: DefModule) = { + val memPortMap = new MemPortMap + (m map updateMemStmts(memPortMap) + map updateStmtRefs(memPortMap)) + } + + def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) +} diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index c3516283..3139ef21 100644 --- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -1,28 +1,94 @@ // See LICENSE for license details. package firrtl.passes +package memlib import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ -import MemPortUtils._ +import MemPortUtils.{MemPortMap, Modules} import MemTransformUtils._ import AnalysisUtils._ +/** Replace DefAnnotatedMemory with memory blackbox + wrapper + conf file. + * This will not generate wmask ports if not needed. + * Creates the minimum # of black boxes needed by the design. + */ class ReplaceMemMacros(writer: ConfWriter) extends Pass { - def name = "Replace memories with black box wrappers" + - " (optimizes when write mask isn't needed) + configuration file" + def name = "Replace Memory Macros" - // from Albert - def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = { + /** Return true if mask granularity is per bit, false if per byte or unspecified + */ + private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match { + case None => false + case Some(v) => v == 1 + } + + private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + + private def wPortToBundle(mem: DefAnnotatedMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ (mem.maskGran match { + case None => Nil + case Some(_) => Seq(Field("mask", Default, createMask(mem.dataType))) + }) + ) + private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match { + case None => Nil + case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType))) + case Some(_) => Seq(Field("mask", Default, flattenType(createMask(mem.dataType)))) + }) + ) + // TODO(shunshou): Don't use createMask??? + + private def rwPortToBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), + Field("wdata", Default, mem.dataType), + Field("rdata", Flip, mem.dataType) + ) ++ (mem.maskGran match { + case None => Nil + case Some(_) => Seq(Field("wmask", Default, createMask(mem.dataType))) + }) + ) + private def rwPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), + Field("wdata", Default, flattenType(mem.dataType)), + Field("rdata", Flip, flattenType(mem.dataType)) + ) ++ (mem.maskGran match { + case None => Nil + case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType))) + case Some(_) => Seq(Field("wmask", Default, flattenType(createMask(mem.dataType)))) + }) + ) + + def memToBundle(s: DefAnnotatedMemory) = BundleType( + s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) + def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType( + s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) + + /** Creates a wrapper module and external module to replace a candidate memory + * The wrapper module has the same type as the memory it replaces + * The external module + */ + def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = { assert(m.dataType != UnknownType) val wrapperIoType = memToBundle(m) val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + // Creates a type with the write/readwrite masks omitted if necessary val bbIoType = memToFlattenBundle(m) val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) val bbRef = createRef(m.name, bbIoType) - val hasMask = containsInfo(m.info, "maskGran") + val hasMask = m.maskGran.isDefined val fillMask = getFillWMask(m) def portRef(p: String) = createRef(p, field_type(wrapperIoType, p)) val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ @@ -38,18 +104,20 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { Seq(bb, wrapper) } - // TODO: get rid of copy pasta - def defaultConnects(wrapperPort: WRef, bbPort: WSubField) = + // TODO(shunshou): get rid of copy pasta + // Connects the clk, en, and addr fields from the wrapperPort to the bbPort + def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) - def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean) = + // Connects the clk, en, and addr fields from the wrapperPort to the bbPort + def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = if (fillMask) toBitMask(mask, dataType) else toBits(mask) - def adaptReader(wrapperPort: WRef, bbPort: WSubField) = + def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = defaultConnects(wrapperPort, bbPort) :+ fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data")) - def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = { + def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { val wrapperData = createSubField(wrapperPort, "data") val defaultSeq = defaultConnects(wrapperPort, bbPort) :+ Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData)) @@ -63,7 +131,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { } } - def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = { + def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { val wrapperWData = createSubField(wrapperPort, "wdata") val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")), @@ -83,25 +151,21 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { memPortMap: MemPortMap, memMods: Modules) (s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => - if (!containsInfo(m.info, "maskGran")) { + case m: DefAnnotatedMemory => + if (m.maskGran.isEmpty) { m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } } - val info = getInfo(m.info, "info") match { - case None => NoInfo - case Some(p: Info) => p - } - getInfo(m.info, "ref") match { + m.memRef match { case None => // prototype mem val newWrapperName = namespace newName m.name val newMemBBName = namespace newName s"${m.name}_ext" val newMem = m copy (name = newMemBBName) memMods ++= createMemModule(newMem, newWrapperName) - WDefInstance(info, m.name, newWrapperName, UnknownType) + WDefInstance(m.info, m.name, newWrapperName, UnknownType) case Some(ref: String) => - WDefInstance(info, m.name, ref, UnknownType) + WDefInstance(m.info, m.name, ref, UnknownType) } case sx => sx map updateMemStmts(namespace, memPortMap, memMods) } diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 62546a84..dfa828c9 100644 --- a/src/main/scala/firrtl/passes/ReplSeqMem.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -1,6 +1,7 @@ // See LICENSE for license details. package firrtl.passes +package memlib import firrtl._ import firrtl.ir._ @@ -40,9 +41,9 @@ object PassConfigUtil { class ConfWriter(filename: String) { val outputBuffer = new CharArrayWriter - def append(m: DefMemory) = { + def append(m: DefAnnotatedMemory) = { // legacy - val maskGran = getInfo(m.info, "maskGran") + val maskGran = m.maskGran val readers = List.fill(m.readers.length)("read") val writers = List.fill(m.writers.length)(if (maskGran.isEmpty) "write" else "mwrite") val readwriters = List.fill(m.readwriters.length)(if (maskGran.isEmpty) "rw" else "mrw") @@ -94,13 +95,16 @@ Optional Arguments: class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) = Seq(Legalize, - AnnotateMemMacros, - UpdateDuplicateMemMacros, - new AnnotateValidMemConfigs(inConfigFile), + ToMemIR, + ResolveMaskGranularity, + RenameAnnotatedMemoryPorts, + ResolveMemoryReference, + //new AnnotateValidMemConfigs(inConfigFile), new ReplaceMemMacros(outConfigFile), RemoveEmpty, CheckInitialization, InferTypes, + Uniquify, ResolveKinds, // Must be run for the transform to work! ResolveGenders) diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala new file mode 100644 index 00000000..a8ff9fe3 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -0,0 +1,124 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import WrappedExpression.weq +import AnalysisUtils._ +import MemTransformUtils._ + +object AnalysisUtils { + type Connects = collection.mutable.HashMap[String, Expression] + + /** Builds a map from named component to assigned value + * Named components are serialized LHS of connections, nodes, invalids + */ + def getConnects(m: DefModule): Connects = { + def getConnects(connects: Connects)(s: Statement): Statement = { + s match { + case Connect(_, loc, expr) => + connects(loc.serialize) = expr + case DefNode(_, name, value) => + connects(name) = value + case IsInvalid(_, value) => + connects(value.serialize) = WInvalid + case _ => // do nothing + } + s map getConnects(connects) + } + val connects = new Connects + m map getConnects(connects) + connects + } + + /** Find a connection LHS's origin from a module's list of node-to-node connections + * regardless of whether constant propagation has been run. + * Will search past trivial primop/mux's which do not affect its origin. + * Limitations: + * - Only works in a module (stops @ module inputs) + * - Only does trivial primop/mux's (is not complete) + * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a) + */ + def getOrigin(connects: Connects, s: String): Expression = + getOrigin(connects)(WRef(s, UnknownType, ExpKind, UNKNOWNGENDER)) + def getOrigin(connects: Connects)(e: Expression): Expression = e match { + case Mux(cond, tv, fv, _) => + val fvOrigin = getOrigin(connects)(fv) + val tvOrigin = getOrigin(connects)(tv) + val condOrigin = getOrigin(connects)(cond) + if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin + else if (weq(condOrigin, one)) tvOrigin + else if (weq(condOrigin, zero)) fvOrigin + else if (weq(tvOrigin, fvOrigin)) tvOrigin + else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin + else e + case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero + case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => + val extractionWidth = (msb - lsb) + 1 + val nodeWidth = bitWidth(args.head.tpe) + // if you're extracting the full bitwidth, then keep searching for origin + if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e + case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => + getOrigin(connects)(args.head) + case ValidIf(cond, value, ClockType) => getOrigin(connects)(value) + // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => + connects get e.serialize match { + case Some(ex) => getOrigin(connects)(ex) + case None => e + } + case _ => e + } + + /** Checks whether the two memories are equivalent in all respects except name + */ + def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) = a == b.copy(name = a.name) +} + +/** Determines if a write mask is needed (wmode/en and wmask are equivalent). + * Populates the maskGran field of DefAnnotatedMemory + * Annotations: + * - maskGran = (dataType size) / (number of mask bits) + * - i.e. 1 if bitmask, 8 if bytemask, absent for no mask + * TODO(shunshou): Add floorplan info? + */ +object ResolveMaskGranularity extends Pass { + def name = "Resolve Mask Granularity" + + /** Returns the number of mask bits, if used + */ + def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { + val wenOrigin = getOrigin(connects)(wen) + val wmaskOrigin = connects.keys filter + (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)} + // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) + val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + if (redundantMask) None else Some(wmaskOrigin.size) + } + + /** Only annotate memories that are candidates for memory macro replacements + * i.e. rw, w + r (read, write 1 cycle delay) + */ + def updateStmts(connects: Connects)(s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + val dataBits = bitWidth(m.dataType) + val rwMasks = m.readwriters map (rw => + getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers map (w => + getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val maskGran = (rwMasks ++ wMasks).head match { + case None => None + case Some(maskBits) => Some(dataBits / maskBits) + } + m.copy(maskGran = maskGran) + case sx => sx map updateStmts(connects) + } + + def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m)) + def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) +} diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala new file mode 100644 index 00000000..783c179f --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -0,0 +1,38 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib +import firrtl.ir._ +import AnalysisUtils.eqMems +import firrtl.Mappers._ + + +/** Resolves annotation ref to memories that exactly match (except name) another memory + */ +object ResolveMemoryReference extends Pass { + + def name = "Resolve Memory Reference" + + type AnnotatedMemories = collection.mutable.ArrayBuffer[DefAnnotatedMemory] + + /** If a candidate memory is identical except for name to another, add an + * annotation that references the name of the other memory. + */ + def updateMemStmts(uniqueMems: AnnotatedMemories)(s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + uniqueMems find (x => eqMems(x, m)) match { + case None => + uniqueMems += m + m + case Some(proto) => m copy (memRef = Some(proto.name)) + } + case s => s map updateMemStmts(uniqueMems) + } + + def updateMemMods(m: DefModule) = { + val uniqueMems = new AnnotatedMemories + (m map updateMemStmts(uniqueMems)) + } + + def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) +} diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala new file mode 100644 index 00000000..741ea5ef --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -0,0 +1,41 @@ +package firrtl.passes +package memlib + +import firrtl.Mappers._ +import firrtl.ir._ + +/** Annotates sequential memories that are candidates for macro replacement. + * Requirements for macro replacement: + * - read latency and write latency of one + * - only one readwrite port or write port + * - zero or one read port + */ +object ToMemIR extends Pass { + def name = "To Memory IR" + + /** Only annotate memories that are candidates for memory macro replacements + * i.e. rw, w + r (read, write 1 cycle delay) + */ + def updateStmts(s: Statement): Statement = s match { + case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && + (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => + DefAnnotatedMemory( + m.info, + m.name, + m.dataType, + m.depth, + m.writeLatency, + m.readLatency, + m.readers, + m.writers, + m.readwriters, + m.readUnderWrite, + None, // mask granularity annotation + None // No reference yet to another memory + ) + case sx => sx map updateStmts + } + + def annotateModMems(m: DefModule) = m map updateStmts + def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) +} diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala new file mode 100644 index 00000000..a1088300 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala @@ -0,0 +1,36 @@ +package firrtl.passes +package memlib +import net.jcazevedo.moultingyaml._ +import java.io.{File, CharArrayWriter, PrintWriter} + +object CustomYAMLProtocol extends DefaultYamlProtocol { + // bottom depends on top +} + +class YamlFileReader(file: String) { + import CustomYAMLProtocol._ + def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { + if (new File(file).exists) { + val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n") + yamlString.parseYamls flatMap (x => + try Some(reader read x) + catch { case e: Exception => None } + ) + } + else error("Yaml file doesn't exist!") + } +} + +class YamlFileWriter(file: String) { + import CustomYAMLProtocol._ + val outputBuffer = new CharArrayWriter + val separator = "--- \n" + def append(in: YamlValue) { + outputBuffer append s"$separator${in.prettyPrint}" + } + def dump() { + val outputFile = new PrintWriter(file) + outputFile write outputBuffer.toString + outputFile.close() + } +} diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 8aeafc9e..277623cf 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -2,6 +2,7 @@ package firrtlTests import firrtl._ import firrtl.passes._ +import firrtl.passes.memlib._ import Annotations._ class ReplSeqMemSpec extends SimpleTransformSpec { @@ -13,7 +14,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), new passes.InferReadWrite(TransID(-1)), - new passes.ReplSeqMem(TransID(-2)), + new passes.memlib.ReplSeqMem(TransID(-2)), new MiddleFirrtlToLowFirrtl(), (new Transform with SimpleRun { def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) } ), @@ -107,7 +108,7 @@ circuit Top : val circuit = InferTypes.run(ToWorkingIR.run(parse(input))) val m = circuit.modules.head.asInstanceOf[ir.Module] val connects = AnalysisUtils.getConnects(m) - val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects)("f").serialize + val calculatedOrigin = AnalysisUtils.getOrigin(connects, "f").serialize require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !") } |
