aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-10-17 18:53:19 -0700
committerAngie Wang2016-10-17 18:53:19 -0700
commit85baeda249e59c7d9d9f159aaf29ff46d685cf02 (patch)
treecfb5f4a6a0a80f9033275de6e5e36b9d5b96faad /src
parent7d08b9a1486fef0459481f6e542464a29fbe1db5 (diff)
Reorganized memory blackboxing (#336)
* Reorganized memory blackboxing Moved to new package memlib Added comments Moved utility functions around Removed unused AnnotateValidMemConfigs.scala * Fixed tests to pass * Use DefAnnotatedMemory instead of AppendableInfo * Broke passes up into simpler passes AnnotateMemMacros -> (ToMemIR, ResolveMaskGranularity) UpdateDuplicateMemMacros -> (RenameAnnotatedMemoryPorts, ResolveMemoryReference) * Fixed to make tests run * Minor changes from code review * Removed vim comments and renamed ReplSeqMem
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Driver.scala2
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala4
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala143
-rw-r--r--src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala288
-rw-r--r--src/main/scala/firrtl/passes/InferReadWrite.scala7
-rw-r--r--src/main/scala/firrtl/passes/MemUtils.scala50
-rw-r--r--src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala155
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala33
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala47
-rw-r--r--src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala77
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala (renamed from src/main/scala/firrtl/passes/ReplaceMemMacros.scala)106
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala (renamed from src/main/scala/firrtl/passes/ReplSeqMem.scala)14
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala124
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala38
-rw-r--r--src/main/scala/firrtl/passes/memlib/ToMemIR.scala41
-rw-r--r--src/main/scala/firrtl/passes/memlib/YamlUtils.scala36
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala5
17 files changed, 533 insertions, 637 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala
index 9e4c2ec0..2a5a379a 100644
--- a/src/main/scala/firrtl/Driver.scala
+++ b/src/main/scala/firrtl/Driver.scala
@@ -81,7 +81,7 @@ Optional Arguments:
passes.InferReadWriteAnnotation(value, TransID(-1))
def handleReplSeqMem(value: String) =
- passes.ReplSeqMemAnnotation(value, TransID(-2))
+ passes.memlib.ReplSeqMemAnnotation(value, TransID(-2))
run(args: Array[String],
Map( "high" -> new HighFirrtlCompiler(),
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 307ef9d1..53491922 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -192,7 +192,7 @@ class LowFirrtlCompiler extends Compiler {
new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new passes.InferReadWrite(TransID(-1)),
- new passes.ReplSeqMem(TransID(-2)),
+ new passes.memlib.ReplSeqMem(TransID(-2)),
new MiddleFirrtlToLowFirrtl,
new EmitFirrtl(writer)
)
@@ -206,7 +206,7 @@ class VerilogCompiler extends Compiler {
new ResolveAndCheck,
new HighFirrtlToMiddleFirrtl,
new passes.InferReadWrite(TransID(-1)),
- new passes.ReplSeqMem(TransID(-2)),
+ new passes.memlib.ReplSeqMem(TransID(-2)),
new MiddleFirrtlToLowFirrtl,
new passes.InlineInstances(TransID(0)),
new EmitVerilogFromLowFirrtl(writer)
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
deleted file mode 100644
index 21287922..00000000
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ /dev/null
@@ -1,143 +0,0 @@
-// See LICENSE for license details.
-
-package firrtl.passes
-
-import firrtl._
-import firrtl.ir._
-import firrtl.Utils._
-import firrtl.Mappers._
-import WrappedExpression.weq
-import MemPortUtils.memPortField
-import AnalysisUtils._
-
-case class AppendableInfo(fields: Map[String, Any]) extends Info {
- def append(a: Map[String, Any]) = this.copy(fields = fields ++ a)
- def append(a: (String, Any)): AppendableInfo = append(Map(a))
- def get(f: String) = fields.get(f)
- override def equals(b: Any) = b match {
- case i: AppendableInfo => fields - "info" == i.fields - "info"
- case _ => false
- }
-}
-
-object AnalysisUtils {
- type Connects = collection.mutable.HashMap[String, Expression]
- def getConnects(m: DefModule): Connects = {
- def getConnects(connects: Connects)(s: Statement): Statement = {
- s match {
- case Connect(_, loc, expr) =>
- connects(loc.serialize) = expr
- case DefNode(_, name, value) =>
- connects(name) = value
- case _ => // do nothing
- }
- s map getConnects(connects)
- }
- val connects = new Connects
- m map getConnects(connects)
- connects
- }
-
- // takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS.
- // if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation,
- // the function will try to search backwards past the primop/mux.
- // use case: compare if two nodes have the same origin
- // limitation: only works in a module (stops @ module inputs)
- // TODO: more thorough (i.e. a + 0 = a)
- def getConnectOrigin(connects: Connects)(node: String): Expression =
- connects get node match {
- case None => EmptyExpression
- case Some(e) => getOrigin(connects, e)
- }
- def getConnectOrigin(connects: Connects, e: Expression): Expression =
- getConnectOrigin(connects)(e.serialize)
-
- private def getOrigin(connects: Connects, e: Expression): Expression = e match {
- case Mux(cond, tv, fv, _) =>
- val fvOrigin = getOrigin(connects, fv)
- val tvOrigin = getOrigin(connects, tv)
- val condOrigin = getOrigin(connects, cond)
- if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin
- else if (weq(condOrigin, one)) tvOrigin
- else if (weq(condOrigin, zero)) fvOrigin
- else if (weq(tvOrigin, fvOrigin)) tvOrigin
- else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
- else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
- case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
- val extractionWidth = (msb - lsb) + 1
- val nodeWidth = bitWidth(args.head.tpe)
- // if you're extracting the full bitwidth, then keep searching for origin
- if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e
- case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
- getOrigin(connects, args.head)
- // Todo: It's not clear it's ok to call remove validifs before mem passes...
- case ValidIf(cond, value, ClockType) => getOrigin(connects, value)
- // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
- case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
- connects get e.serialize match {
- case Some(ex) => getOrigin(connects, ex)
- case None => e
- }
- case _ => e
- }
-
- def appendInfo[T <: Info](info: T, add: Map[String, Any]) = info match {
- case i: AppendableInfo => i.append(add)
- case _ => AppendableInfo(fields = add + ("info" -> info))
- }
- def appendInfo[T <: Info](info: T, add: (String, Any)): AppendableInfo = appendInfo(info, Map(add))
- def getInfo[T <: Info](info: T, k: String) = info match {
- case i: AppendableInfo => i.get(k)
- case _ => None
- }
- def containsInfo[T <: Info](info: T, k: String) = info match {
- case i: AppendableInfo => i.fields.contains(k)
- case _ => false
- }
-
- // memories equivalent as long as all fields (except name) are the same
- def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name)
-}
-
-object AnnotateMemMacros extends Pass {
- def name = "Analyze sequential memories and tag with info for future passes(useMacro, maskGran)"
-
- // returns # of mask bits if used
- def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
- val wenOrigin = getConnectOrigin(connects, wen)
- val wmaskOrigin = connects.keys filter
- (_ startsWith wmask.serialize) map getConnectOrigin(connects)
- // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
- if (redundantMask) None else Some(wmaskOrigin.size)
- }
-
- def updateStmts(connects: Connects)(s: Statement): Statement = s match {
- // only annotate memories that are candidates for memory macro replacements
- // i.e. rw, w + r (read, write 1 cycle delay)
- case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
- (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
- val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (rw =>
- getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
- val wMasks = m.writers map (w =>
- getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
- val memAnnotations = Map("useMacro" -> true)
- val tempInfo = appendInfo(m.info, memAnnotations)
- (rwMasks ++ wMasks).head match {
- case None =>
- m copy (info = tempInfo)
- case Some(maskBits) =>
- m.copy(info = tempInfo.append("maskGran" -> dataBits / maskBits))
- }
- case sx => sx map updateStmts(connects)
- }
-
- def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
-
- def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
-}
-
-// TODO: Add floorplan info?
diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
deleted file mode 100644
index b5149953..00000000
--- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
+++ /dev/null
@@ -1,288 +0,0 @@
-// See LICENSE for license details.
-
-package firrtl.passes
-
-import firrtl._
-import firrtl.ir._
-import firrtl.Mappers._
-import Utils.error
-import AnalysisUtils._
-
-import net.jcazevedo.moultingyaml._
-import scala.collection.mutable
-import java.io.{File, CharArrayWriter, PrintWriter}
-
-object CustomYAMLProtocol extends DefaultYamlProtocol {
- // bottom depends on top
- implicit val dr = yamlFormat4(DimensionRules)
- implicit val md = yamlFormat2(MemDimension)
- implicit val sr = yamlFormat4(SRAMRules)
- implicit val wm = yamlFormat2(WMaskArg)
- implicit val sc = yamlFormat11(SRAMCompiler)
-}
-
-case class DimensionRules(
- min: Int,
- // step size
- inc: Int,
- max: Int,
- // these values should not be used, regardless of min,inc,max
- illegal: Option[List[Int]]) {
- def getValid = {
- val range = (min to max by inc).toList
- range.filterNot(illegal.getOrElse(List[Int]()).toSet)
- }
-}
-
-case class MemDimension(
- rules: Option[DimensionRules],
- set: Option[List[Int]]) {
- require (
- if (rules.isEmpty) set.isDefined else set.isEmpty,
- "Should specify either rules or a list of valid options, but not both"
- )
- def getValid = set.getOrElse(rules.get.getValid).sorted
-}
-
-case class SRAMConfig(
- ymux: String = "",
- ybank: String = "",
- width: Int,
- depth: Int,
- xsplit: Int = 1,
- ysplit: Int = 1) {
- // how many duplicate copies of this SRAM are needed
- def num = xsplit * ysplit
- def serialize(pattern: String): String = {
- val fieldMap = getClass.getDeclaredFields.map { f =>
- f.setAccessible(true)
- f.getName -> f.get(this)
- }.toMap
-
- val fieldDelimiter = """\[.*?\]""".r
- val configOptions = fieldDelimiter.findAllIn(pattern).toList
-
- configOptions.foldLeft(pattern)((b, a) => {
- // Expects the contents of [] are valid configuration fields (otherwise key match error)
- val fieldVal = {
- try fieldMap(a.substring(1, a.length-1))
- catch { case e: Exception => error("**SRAM config field incorrect**") }
- }
- b.replace(a, fieldVal.toString)
- } )
- }
-}
-
-// Ex: https://www.ece.cmu.edu/~ece548/hw/hw5/meml80.pdf
-case class SRAMRules(
- // column mux parameter (for adjusting aspect ratio)
- ymux: (Int, String),
- // vertical segmentation (banking -- tradeoff performance / area)
- ybank: (Int, String),
- width: MemDimension,
- depth: MemDimension) {
- def getValidWidths = width.getValid
- def getValidDepths = depth.getValid
- def getValidConfig(width: Int, depth: Int): Option[SRAMConfig] = {
- if (getValidWidths.contains(width) && getValidDepths.contains(depth))
- Some(SRAMConfig(ymux = ymux._2, ybank = ybank._2, width = width, depth = depth))
- else
- None
- }
- def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue, m.depth)
-}
-
-case class WMaskArg(
- t: String,
- f: String)
-
-// vendor-specific compilers
-case class SRAMCompiler(
- vendor: String,
- node: String,
- // i.e. RF, SRAM, etc.
- memType: String,
- portType: String,
- wMaskArg: Option[WMaskArg],
- // rules for valid SRAM flavors
- rules: Seq[SRAMRules],
- // path to executable
- path: Option[String],
- // (output) config file path
- configFile: Option[String],
- // config pattern
- configPattern: Option[String],
- // read documentation for details
- defaultArgs: Option[String],
- // default behavior (if not used) is to have wmask port width = datawidth/maskgran
- // if true: wmask port width pre-filled to datawidth
- fillWMask: Boolean) {
- require(portType == "RW" || portType == "R,W", "Memory must be single port RW or dual port R,W")
- require(
- (configFile.isDefined && configPattern.isDefined && wMaskArg.isDefined) || configFile.isEmpty,
- "Config pattern must be provided with config file"
- )
- def ymuxVals = rules.map(_.ymux._1).sortWith(_ < _)
- def ybankVals = rules.map(_.ybank._1).sortWith(_ > _)
- // TODO: verify this default ordering works out
- // optimize search for better FoM (area,power,clk); ymux has more effect
- def defaultSearchOrdering = for (x <- ymuxVals; y <- ybankVals) yield {
- rules.find(r => r.ymux._1 == x && r.ybank._1 == y).get
- }
-
- private val maskConfigOutputBuffer = new CharArrayWriter
- private val noMaskConfigOutputBuffer = new CharArrayWriter
-
- def append(m: DefMemory) : DefMemory = {
- val validCombos = (defaultSearchOrdering map (_ getValidConfig m)
- collect { case Some(config) => config })
- // non empty if successfully found compiler option that supports depth/width
- // TODO: don't just take first option
- val usedConfig = {
- if (validCombos.nonEmpty) validCombos.head
- else getBestAlternative(m)
- }
- val usesMaskGran = containsInfo(m.info, "maskGran")
- configPattern match {
- case None =>
- case Some(p) =>
- val newConfig = usedConfig.serialize(p) + "\n"
- val currentBuff = {
- if (usesMaskGran) maskConfigOutputBuffer
- else noMaskConfigOutputBuffer
- }
- if (!currentBuff.toString.contains(newConfig)) currentBuff append newConfig
- }
- val temp = appendInfo(m.info, "sramConfig" -> usedConfig)
- val newInfo = if (usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp
- m copy (info = newInfo)
- }
-
- // TODO: Should you really be splitting in 2 if, say, depth is 1 more than allowed? should be thresholded and
- // handled w/ a separate set of registers ?
- // split memory until width, depth achievable via given memory compiler
- private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = {
- val validXRange = mutable.ArrayBuffer[SRAMRules]()
- val validYRange = mutable.ArrayBuffer[SRAMRules]()
- defaultSearchOrdering foreach { r =>
- if (m.width <= r.getValidWidths.max) validXRange += r
- if (m.depth <= r.getValidDepths.max) validYRange += r
- }
- (validXRange.isEmpty, validYRange.isEmpty) match {
- case (true, true) =>
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2))
- case (true, false) =>
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
- case (false, true) =>
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2))
- case (false, false) if validXRange.intersect(validYRange).nonEmpty =>
- Seq(m)
- case (false, false) =>
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
- }
- }
-
- private def getBestAlternative(m: DefMemory): SRAMConfig = {
- val validConfigs = getInRange(SRAMConfig(width = bitWidth(m.dataType).intValue, depth = m.depth))
- val minNum = validConfigs.map(_.num).min
- val validMinConfigs = validConfigs.filter(_.num == minNum)
- val validMinConfigsSquareness = validMinConfigs.map(
- x => math.abs(x.width.toDouble / x.depth - 1) -> x).toMap
- val squarestAspectRatio = validMinConfigsSquareness.unzip._1.min
- val validConfig = validMinConfigsSquareness(squarestAspectRatio)
- val validRules = defaultSearchOrdering filter (r =>
- validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max)
- // TODO: don't just take first option
- // TODO: More optimal split if particular value is in range but not supported
- // TODO: Support up to 2 read ports, 2 write ports; should be power of 2?
- val bestRule = validRules.head
- val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get
- val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get
- (bestRule.getValidConfig(width = memWidth, depth = memDepth).get
- copy (xsplit = validConfig.xsplit, ysplit = validConfig.ysplit))
- }
-
- // TODO
- def serialize = ???
-}
-
-// TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true?
-// Or is this where module-level transforms (rather than circuit-level) make sense?
-class YamlFileReader(file: String) {
- import CustomYAMLProtocol._
- def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
- if (new File(file).exists) {
- val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n")
- yamlString.parseYamls flatMap (x =>
- try Some(reader read x)
- catch { case e: Exception => None }
- )
- }
- else error("Yaml file doesn't exist!")
- }
-}
-
-class YamlFileWriter(file: String) {
- import CustomYAMLProtocol._
- val outputBuffer = new CharArrayWriter
- val separator = "--- \n"
- def append(in: YamlValue) {
- outputBuffer append s"$separator${in.prettyPrint}"
- }
- def dump() {
- val outputFile = new PrintWriter(file)
- outputFile write outputBuffer.toString
- outputFile.close()
- }
-}
-
-class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
- import CustomYAMLProtocol._
- def name = "Annotate memories with valid split depths, widths, #\'s"
-
- // TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist
- // TODO: Don't get first available?
- case class SRAMCompilerSet(
- sp: Option[SRAMCompiler] = None,
- dp: Option[SRAMCompiler] = None) {
- def serialize() = {
- sp match {
- case None =>
- case Some(p) => p.serialize
- }
- dp match {
- case None =>
- case Some(p) => p.serialize
- }
- }
- }
-
- val sramCompilers = reader match {
- case None => None
- case Some(r) =>
- val compilers = r.parse[SRAMCompiler]
- val sp = compilers find (_.portType == "RW")
- val dp = compilers find (_.portType == "R,W")
- Some(SRAMCompilerSet(sp = sp, dp = dp))
- }
-
- def updateStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") => sramCompilers match {
- case None => m
- case Some(compiler) if m.readwriters.length == 1 =>
- compiler.sp match {
- case None => error("Design needs RW port memory compiler!")
- case Some(p) => p append m
- }
- case Some(compiler) =>
- compiler.dp match {
- case None => error("Design needs R,W port memory compiler!")
- case Some(p) => p append m
- }
- }
- case sx => sx map updateStmts
- }
-
- def run(c: Circuit) = c copy (modules = c.modules map (_ map updateStmts))
-}
diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala
index a1875ae7..9adbdd95 100644
--- a/src/main/scala/firrtl/passes/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/InferReadWrite.scala
@@ -32,8 +32,9 @@ import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.Utils.{one, zero, BoolType}
+import firrtl.passes.memlib._
import MemPortUtils.memPortField
-import AnalysisUtils.{Connects, getConnects, getConnectOrigin}
+import AnalysisUtils.{Connects, getConnects, getOrigin}
import WrappedExpression.weq
import Annotations._
@@ -117,8 +118,8 @@ object InferReadWritePass extends Pass {
for (w <- mem.writers ; r <- mem.readers) {
val wp = getProductTerms(connects)(memPortField(mem, w, "en"))
val rp = getProductTerms(connects)(memPortField(mem, r, "en"))
- val wclk = getConnectOrigin(connects, memPortField(mem, w, "clk"))
- val rclk = getConnectOrigin(connects, memPortField(mem, r, "clk"))
+ val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
+ val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) {
val rw = namespace newName "rw"
val rwExp = createSubField(createRef(mem.name), rw)
diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala
index 92673433..8cd58afb 100644
--- a/src/main/scala/firrtl/passes/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/MemUtils.scala
@@ -43,39 +43,55 @@ object seqCat {
}
}
+/** Given an expression, return an expression consisting of all sub-expressions
+ * concatenated (or flattened).
+ */
object toBits {
def apply(e: Expression): Expression = e match {
- case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex, ex.tpe)
+ case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiercat(ex)
case t => error("Invalid operand expression for toBits!")
}
- private def hiercat(e: Expression, dt: Type): Expression = dt match {
+ private def hiercat(e: Expression): Expression = e.tpe match {
case t: VectorType => seqCat((0 until t.size) map (i =>
- hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER),t.tpe)))
+ hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER))))
case t: BundleType => seqCat(t.fields map (f =>
- hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe)))
+ hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER))))
case t: GroundType => e
case t => error("Unknown type encountered in toBits!")
}
}
-// TODO: make easier to understand
+/** Given a mask, return a bitmask corresponding to the desired datatype.
+ * Requirements:
+ * - The mask type and datatype must be equivalent, except any ground type in
+ * datatype must be matched by a 1-bit wide UIntType.
+ * - The mask must be a reference, subfield, or subindex
+ * The bitmask is a series of concatenations of the single mask bit over the
+ * length of the corresponding ground type, e.g.:
+ *{{{
+ * wire mask: {x: UInt<1>, y: UInt<1>}
+ * wire data: {x: UInt<2>, y: SInt<2>}
+ * // this would return:
+ * cat(cat(mask.x, mask.x), cat(mask.y, mask.y))
+ * }}}
+ */
object toBitMask {
- def apply(e: Expression, dataType: Type): Expression = e match {
- case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, ex.tpe, dataType)
+ def apply(mask: Expression, dataType: Type): Expression = mask match {
+ case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType)
case t => error("Invalid operand expression for toBits!")
}
- private def hiermask(e: Expression, maskType: Type, dataType: Type): Expression =
- (maskType, dataType) match {
+ private def hiermask(mask: Expression, dataType: Type): Expression =
+ (mask.tpe, dataType) match {
case (mt: VectorType, dt: VectorType) =>
seqCat((0 until mt.size).reverse map { i =>
- hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe)
+ hiermask(WSubIndex(mask, i, mt.tpe, UNKNOWNGENDER), dt.tpe)
})
case (mt: BundleType, dt: BundleType) =>
seqCat((mt.fields zip dt.fields) map { case (mf, df) =>
- hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe)
+ hiermask(WSubField(mask, mf.name, mf.tpe, UNKNOWNGENDER), df.tpe)
})
- case (mt: UIntType, dt: GroundType) =>
- seqCat(List.fill(bitWidth(dt).intValue)(e))
+ case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) =>
+ seqCat(List.fill(bitWidth(dt).intValue)(mask))
case (mt, dt) => error("Invalid type for mask component!")
}
}
@@ -153,7 +169,7 @@ object createSubField {
}
object connectFields {
- def apply(lref: Expression, lname: String, rref: Expression, rname: String) =
+ def apply(lref: Expression, lname: String, rref: Expression, rname: String): Connect =
Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname))
}
@@ -166,14 +182,14 @@ object MemPortUtils {
type Memories = collection.mutable.ArrayBuffer[DefMemory]
type Modules = collection.mutable.ArrayBuffer[DefModule]
- def defaultPortSeq(mem: DefMemory) = Seq(
+ def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq(
Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))),
Field("en", Default, BoolType),
Field("clk", Default, ClockType)
)
// Todo: merge it with memToBundle
- def memType(mem: DefMemory) = {
+ def memType(mem: DefMemory): Type = {
val rType = BundleType(defaultPortSeq(mem) :+
Field("data", Flip, mem.dataType))
val wType = BundleType(defaultPortSeq(mem) ++ Seq(
@@ -190,7 +206,7 @@ object MemPortUtils {
(mem.readwriters map (Field(_, Flip, rwType))))
}
- def memPortField(s: DefMemory, p: String, f: String) = {
+ def memPortField(s: DefMemory, p: String, f: String): Expression = {
val mem = WRef(s.name, memType(s), MemKind, UNKNOWNGENDER)
val t1 = field_type(mem.tpe, p)
val t2 = field_type(t1, f)
diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
deleted file mode 100644
index 2e6c3338..00000000
--- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
+++ /dev/null
@@ -1,155 +0,0 @@
-// See LICENSE for license details.
-
-package firrtl.passes
-
-import firrtl._
-import firrtl.ir._
-import firrtl.Utils._
-import firrtl.Mappers._
-import AnalysisUtils._
-import MemPortUtils._
-import MemTransformUtils._
-
-object MemTransformUtils {
- def getFillWMask(mem: DefMemory) =
- getInfo(mem.info, "maskGran") match {
- case None => false
- case Some(maskGran) => maskGran == 1
- }
-
- def rPortToBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
- def rPortToFlattenBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
-
- def wPortToBundle(mem: DefMemory) = BundleType(
- (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++
- (if (!containsInfo(mem.info, "maskGran")) Nil
- else Seq(Field("mask", Default, createMask(mem.dataType))))
- )
- def wPortToFlattenBundle(mem: DefMemory) = BundleType(
- (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++
- (if (!containsInfo(mem.info, "maskGran")) Nil
- else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType)))
- else Seq(Field("mask", Default, flattenType(createMask(mem.dataType)))))
- )
- // TODO: Don't use createMask???
-
- def rwPortToBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) ++ Seq(
- Field("wmode", Default, BoolType),
- Field("wdata", Default, mem.dataType),
- Field("rdata", Flip, mem.dataType)
- ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
- else Seq(Field("wmask", Default, createMask(mem.dataType)))
- )
- )
-
- def rwPortToFlattenBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) ++ Seq(
- Field("wmode", Default, BoolType),
- Field("wdata", Default, flattenType(mem.dataType)),
- Field("rdata", Flip, flattenType(mem.dataType))
- ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
- else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType)))
- else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType))))
- )
- )
-
- def memToBundle(s: DefMemory) = BundleType(
- s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
-
- def memToFlattenBundle(s: DefMemory) = BundleType(
- s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
-
-
- def getMemPortMap(m: DefMemory) = {
- val memPortMap = new MemPortMap
- val defaultFields = Seq("addr", "en", "clk")
- val rFields = defaultFields :+ "data"
- val wFields = rFields :+ "mask"
- val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")
-
- def updateMemPortMap(ports: Seq[String], fields: Seq[String], portType: String) =
- for ((p, i) <- ports.zipWithIndex; f <- fields) {
- val newPort = createSubField(createRef(m.name), portType+i)
- val field = createSubField(newPort, f)
- memPortMap(s"${m.name}.$p.$f") = field
- }
- updateMemPortMap(m.readers, rFields, "R")
- updateMemPortMap(m.writers, wFields, "W")
- updateMemPortMap(m.readwriters, rwFields, "RW")
- memPortMap
- }
-
- def createMemProto(m: DefMemory) = {
- val rports = m.readers.indices map (i => s"R$i")
- val wports = m.writers.indices map (i => s"W$i")
- val rwports = m.readwriters.indices map (i => s"RW$i")
- m copy (readers = rports, writers = wports, readwriters = rwports)
- }
-
- def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
- def updateRef(e: Expression): Expression = {
- val ex = e map updateRef
- repl getOrElse (ex.serialize, ex)
- }
-
- def hasEmptyExpr(stmt: Statement): Boolean = {
- var foundEmpty = false
- def testEmptyExpr(e: Expression): Expression = {
- e match {
- case EmptyExpression => foundEmpty = true
- case _ =>
- }
- e map testEmptyExpr // map must return; no foreach
- }
- stmt map testEmptyExpr
- foundEmpty
- }
-
- def updateStmtRefs(s: Statement): Statement =
- s map updateStmtRefs map updateRef match {
- case c: Connect if hasEmptyExpr(c) => EmptyStmt
- case sx => sx
- }
-
- updateStmtRefs(s)
- }
-
-}
-
-object UpdateDuplicateMemMacros extends Pass {
-
- def name = "Convert memory port names to be more meaningful and tag duplicate memories"
-
- def updateMemStmts(uniqueMems: Memories,
- memPortMap: MemPortMap)
- (s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") =>
- val updatedMem = createMemProto(m)
- memPortMap ++= getMemPortMap(m)
- uniqueMems find (x => eqMems(x, updatedMem)) match {
- case None =>
- uniqueMems += updatedMem
- updatedMem
- case Some(proto) =>
- updatedMem copy (info = appendInfo(updatedMem.info, "ref" -> proto.name))
- }
- case sx => sx map updateMemStmts(uniqueMems, memPortMap)
- }
-
- def updateMemMods(m: DefModule) = {
- val uniqueMems = new Memories
- val memPortMap = new MemPortMap
- (m map updateMemStmts(uniqueMems, memPortMap)
- map updateStmtRefs(memPortMap))
- }
-
- def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
-}
-// TODO: Module namespace?
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
new file mode 100644
index 00000000..6dca5961
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -0,0 +1,33 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import Utils.indent
+
+case class DefAnnotatedMemory(
+ info: Info,
+ name: String,
+ dataType: Type,
+ depth: Int,
+ writeLatency: Int,
+ readLatency: Int,
+ readers: Seq[String],
+ writers: Seq[String],
+ readwriters: Seq[String],
+ readUnderWrite: Option[String],
+ maskGran: Option[BigInt],
+ memRef: Option[String]
+ //pins: Seq[Pin],
+ ) extends Statement with IsDeclaration {
+ def serialize: String = this.toMem.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def toMem = DefMemory(info, name, dataType, depth,
+ writeLatency, readLatency, readers, writers,
+ readwriters, readUnderWrite)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
new file mode 100644
index 00000000..78a386b2
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
@@ -0,0 +1,47 @@
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import AnalysisUtils._
+import MemPortUtils.{MemPortMap}
+
+object MemTransformUtils {
+
+ /** Replaces references to old memory port names with new memory port names
+ */
+ def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
+ //TODO(izraelevitz): check speed
+ def updateRef(e: Expression): Expression = {
+ val ex = e map updateRef
+ repl getOrElse (ex.serialize, ex)
+ }
+
+ def hasEmptyExpr(stmt: Statement): Boolean = {
+ var foundEmpty = false
+ def testEmptyExpr(e: Expression): Expression = {
+ e match {
+ case EmptyExpression => foundEmpty = true
+ case _ =>
+ }
+ e map testEmptyExpr // map must return; no foreach
+ }
+ stmt map testEmptyExpr
+ foundEmpty
+ }
+
+ def updateStmtRefs(s: Statement): Statement =
+ s map updateStmtRefs map updateRef match {
+ case c: Connect if hasEmptyExpr(c) => EmptyStmt
+ case s => s
+ }
+
+ updateStmtRefs(s)
+ }
+
+ def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem)
+ def memPortField(s: DefAnnotatedMemory, p: String, f: String): Expression =
+ MemPortUtils.memPortField(s.toMem, p, f)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
new file mode 100644
index 00000000..168a6a48
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
@@ -0,0 +1,77 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import AnalysisUtils._
+import MemPortUtils._
+import MemTransformUtils._
+
+
+/** Changes memory port names to standard port names (i.e. RW0 instead T_408)
+ */
+object RenameAnnotatedMemoryPorts extends Pass {
+
+ def name = "Rename Annotated Memory Ports"
+
+ /** Renames memory ports to a standard naming scheme:
+ * - R0, R1, ... for each read port
+ * - W0, W1, ... for each write port
+ * - RW0, RW1, ... for each readwrite port
+ */
+ def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = {
+ val rports = m.readers.indices map (i => s"R$i")
+ val wports = m.writers.indices map (i => s"W$i")
+ val rwports = m.readwriters.indices map (i => s"RW$i")
+ m copy (readers = rports, writers = wports, readwriters = rwports)
+ }
+
+ /** Maps the serialized form of all memory port field names to the
+ * corresponding new memory port field Expression.
+ * E.g.:
+ * - ("m.read.addr") becomes (m.R0.addr)
+ */
+ def getMemPortMap(m: DefAnnotatedMemory): MemPortMap = {
+ val memPortMap = new MemPortMap
+ val defaultFields = Seq("addr", "en", "clk")
+ val rFields = defaultFields :+ "data"
+ val wFields = rFields :+ "mask"
+ val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")
+
+ def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit =
+ for ((p, i) <- ports.zipWithIndex; f <- fields) {
+ val newPort = createSubField(createRef(m.name), newPortKind + i)
+ val field = createSubField(newPort, f)
+ memPortMap(s"${m.name}.$p.$f") = field
+ }
+ updateMemPortMap(m.readers, rFields, "R")
+ updateMemPortMap(m.writers, wFields, "W")
+ updateMemPortMap(m.readwriters, rwFields, "RW")
+ memPortMap
+ }
+
+ /** Replaces candidate memories with memories with standard port names
+ * Does not update the references (this is done via updateStmtRefs)
+ */
+ def updateMemStmts(memPortMap: MemPortMap)(s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ val updatedMem = createMemProto(m)
+ memPortMap ++= getMemPortMap(m)
+ updatedMem
+ case s => s map updateMemStmts(memPortMap)
+ }
+
+ /** Replaces candidate memories and their references with standard port names
+ */
+ def updateMemMods(m: DefModule) = {
+ val memPortMap = new MemPortMap
+ (m map updateMemStmts(memPortMap)
+ map updateStmtRefs(memPortMap))
+ }
+
+ def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
+}
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
index c3516283..3139ef21 100644
--- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -1,28 +1,94 @@
// See LICENSE for license details.
package firrtl.passes
+package memlib
import firrtl._
import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
-import MemPortUtils._
+import MemPortUtils.{MemPortMap, Modules}
import MemTransformUtils._
import AnalysisUtils._
+/** Replace DefAnnotatedMemory with memory blackbox + wrapper + conf file.
+ * This will not generate wmask ports if not needed.
+ * Creates the minimum # of black boxes needed by the design.
+ */
class ReplaceMemMacros(writer: ConfWriter) extends Pass {
- def name = "Replace memories with black box wrappers" +
- " (optimizes when write mask isn't needed) + configuration file"
+ def name = "Replace Memory Macros"
- // from Albert
- def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = {
+ /** Return true if mask granularity is per bit, false if per byte or unspecified
+ */
+ private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match {
+ case None => false
+ case Some(v) => v == 1
+ }
+
+ private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+
+ private def wPortToBundle(mem: DefAnnotatedMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) => Seq(Field("mask", Default, createMask(mem.dataType)))
+ })
+ )
+ private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType)))
+ case Some(_) => Seq(Field("mask", Default, flattenType(createMask(mem.dataType))))
+ })
+ )
+ // TODO(shunshou): Don't use createMask???
+
+ private def rwPortToBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, mem.dataType),
+ Field("rdata", Flip, mem.dataType)
+ ) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) => Seq(Field("wmask", Default, createMask(mem.dataType)))
+ })
+ )
+ private def rwPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, flattenType(mem.dataType)),
+ Field("rdata", Flip, flattenType(mem.dataType))
+ ) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType)))
+ case Some(_) => Seq(Field("wmask", Default, flattenType(createMask(mem.dataType))))
+ })
+ )
+
+ def memToBundle(s: DefAnnotatedMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
+ def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
+
+ /** Creates a wrapper module and external module to replace a candidate memory
+ * The wrapper module has the same type as the memory it replaces
+ * The external module
+ */
+ def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = {
assert(m.dataType != UnknownType)
val wrapperIoType = memToBundle(m)
val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ // Creates a type with the write/readwrite masks omitted if necessary
val bbIoType = memToFlattenBundle(m)
val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
val bbRef = createRef(m.name, bbIoType)
- val hasMask = containsInfo(m.info, "maskGran")
+ val hasMask = m.maskGran.isDefined
val fillMask = getFillWMask(m)
def portRef(p: String) = createRef(p, field_type(wrapperIoType, p))
val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
@@ -38,18 +104,20 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
Seq(bb, wrapper)
}
- // TODO: get rid of copy pasta
- def defaultConnects(wrapperPort: WRef, bbPort: WSubField) =
+ // TODO(shunshou): get rid of copy pasta
+ // Connects the clk, en, and addr fields from the wrapperPort to the bbPort
+ def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] =
Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))
- def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean) =
+ // Connects the clk, en, and addr fields from the wrapperPort to the bbPort
+ def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression =
if (fillMask) toBitMask(mask, dataType) else toBits(mask)
- def adaptReader(wrapperPort: WRef, bbPort: WSubField) =
+ def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
defaultConnects(wrapperPort, bbPort) :+
fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data"))
- def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
+ def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
val wrapperData = createSubField(wrapperPort, "data")
val defaultSeq = defaultConnects(wrapperPort, bbPort) :+
Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData))
@@ -63,7 +131,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
}
}
- def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
+ def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
val wrapperWData = createSubField(wrapperPort, "wdata")
val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")),
@@ -83,25 +151,21 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
memPortMap: MemPortMap,
memMods: Modules)
(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") =>
- if (!containsInfo(m.info, "maskGran")) {
+ case m: DefAnnotatedMemory =>
+ if (m.maskGran.isEmpty) {
m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
}
- val info = getInfo(m.info, "info") match {
- case None => NoInfo
- case Some(p: Info) => p
- }
- getInfo(m.info, "ref") match {
+ m.memRef match {
case None =>
// prototype mem
val newWrapperName = namespace newName m.name
val newMemBBName = namespace newName s"${m.name}_ext"
val newMem = m copy (name = newMemBBName)
memMods ++= createMemModule(newMem, newWrapperName)
- WDefInstance(info, m.name, newWrapperName, UnknownType)
+ WDefInstance(m.info, m.name, newWrapperName, UnknownType)
case Some(ref: String) =>
- WDefInstance(info, m.name, ref, UnknownType)
+ WDefInstance(m.info, m.name, ref, UnknownType)
}
case sx => sx map updateMemStmts(namespace, memPortMap, memMods)
}
diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index 62546a84..dfa828c9 100644
--- a/src/main/scala/firrtl/passes/ReplSeqMem.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -1,6 +1,7 @@
// See LICENSE for license details.
package firrtl.passes
+package memlib
import firrtl._
import firrtl.ir._
@@ -40,9 +41,9 @@ object PassConfigUtil {
class ConfWriter(filename: String) {
val outputBuffer = new CharArrayWriter
- def append(m: DefMemory) = {
+ def append(m: DefAnnotatedMemory) = {
// legacy
- val maskGran = getInfo(m.info, "maskGran")
+ val maskGran = m.maskGran
val readers = List.fill(m.readers.length)("read")
val writers = List.fill(m.writers.length)(if (maskGran.isEmpty) "write" else "mwrite")
val readwriters = List.fill(m.readwriters.length)(if (maskGran.isEmpty) "rw" else "mrw")
@@ -94,13 +95,16 @@ Optional Arguments:
class ReplSeqMem(transID: TransID) extends Transform with SimpleRun {
def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) =
Seq(Legalize,
- AnnotateMemMacros,
- UpdateDuplicateMemMacros,
- new AnnotateValidMemConfigs(inConfigFile),
+ ToMemIR,
+ ResolveMaskGranularity,
+ RenameAnnotatedMemoryPorts,
+ ResolveMemoryReference,
+ //new AnnotateValidMemConfigs(inConfigFile),
new ReplaceMemMacros(outConfigFile),
RemoveEmpty,
CheckInitialization,
InferTypes,
+ Uniquify,
ResolveKinds, // Must be run for the transform to work!
ResolveGenders)
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
new file mode 100644
index 00000000..a8ff9fe3
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -0,0 +1,124 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import WrappedExpression.weq
+import AnalysisUtils._
+import MemTransformUtils._
+
+object AnalysisUtils {
+ type Connects = collection.mutable.HashMap[String, Expression]
+
+ /** Builds a map from named component to assigned value
+ * Named components are serialized LHS of connections, nodes, invalids
+ */
+ def getConnects(m: DefModule): Connects = {
+ def getConnects(connects: Connects)(s: Statement): Statement = {
+ s match {
+ case Connect(_, loc, expr) =>
+ connects(loc.serialize) = expr
+ case DefNode(_, name, value) =>
+ connects(name) = value
+ case IsInvalid(_, value) =>
+ connects(value.serialize) = WInvalid
+ case _ => // do nothing
+ }
+ s map getConnects(connects)
+ }
+ val connects = new Connects
+ m map getConnects(connects)
+ connects
+ }
+
+ /** Find a connection LHS's origin from a module's list of node-to-node connections
+ * regardless of whether constant propagation has been run.
+ * Will search past trivial primop/mux's which do not affect its origin.
+ * Limitations:
+ * - Only works in a module (stops @ module inputs)
+ * - Only does trivial primop/mux's (is not complete)
+ * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a)
+ */
+ def getOrigin(connects: Connects, s: String): Expression =
+ getOrigin(connects)(WRef(s, UnknownType, ExpKind, UNKNOWNGENDER))
+ def getOrigin(connects: Connects)(e: Expression): Expression = e match {
+ case Mux(cond, tv, fv, _) =>
+ val fvOrigin = getOrigin(connects)(fv)
+ val tvOrigin = getOrigin(connects)(tv)
+ val condOrigin = getOrigin(connects)(cond)
+ if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin
+ else if (weq(condOrigin, one)) tvOrigin
+ else if (weq(condOrigin, zero)) fvOrigin
+ else if (weq(tvOrigin, fvOrigin)) tvOrigin
+ else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
+ else e
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
+ case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
+ val extractionWidth = (msb - lsb) + 1
+ val nodeWidth = bitWidth(args.head.tpe)
+ // if you're extracting the full bitwidth, then keep searching for origin
+ if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e
+ case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
+ getOrigin(connects)(args.head)
+ case ValidIf(cond, value, ClockType) => getOrigin(connects)(value)
+ // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
+ case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
+ connects get e.serialize match {
+ case Some(ex) => getOrigin(connects)(ex)
+ case None => e
+ }
+ case _ => e
+ }
+
+ /** Checks whether the two memories are equivalent in all respects except name
+ */
+ def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) = a == b.copy(name = a.name)
+}
+
+/** Determines if a write mask is needed (wmode/en and wmask are equivalent).
+ * Populates the maskGran field of DefAnnotatedMemory
+ * Annotations:
+ * - maskGran = (dataType size) / (number of mask bits)
+ * - i.e. 1 if bitmask, 8 if bytemask, absent for no mask
+ * TODO(shunshou): Add floorplan info?
+ */
+object ResolveMaskGranularity extends Pass {
+ def name = "Resolve Mask Granularity"
+
+ /** Returns the number of mask bits, if used
+ */
+ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
+ val wenOrigin = getOrigin(connects)(wen)
+ val wmaskOrigin = connects.keys filter
+ (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)}
+ // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
+ val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ if (redundantMask) None else Some(wmaskOrigin.size)
+ }
+
+ /** Only annotate memories that are candidates for memory macro replacements
+ * i.e. rw, w + r (read, write 1 cycle delay)
+ */
+ def updateStmts(connects: Connects)(s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ val dataBits = bitWidth(m.dataType)
+ val rwMasks = m.readwriters map (rw =>
+ getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers map (w =>
+ getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val maskGran = (rwMasks ++ wMasks).head match {
+ case None => None
+ case Some(maskBits) => Some(dataBits / maskBits)
+ }
+ m.copy(maskGran = maskGran)
+ case sx => sx map updateStmts(connects)
+ }
+
+ def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
+ def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
new file mode 100644
index 00000000..783c179f
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
@@ -0,0 +1,38 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+import firrtl.ir._
+import AnalysisUtils.eqMems
+import firrtl.Mappers._
+
+
+/** Resolves annotation ref to memories that exactly match (except name) another memory
+ */
+object ResolveMemoryReference extends Pass {
+
+ def name = "Resolve Memory Reference"
+
+ type AnnotatedMemories = collection.mutable.ArrayBuffer[DefAnnotatedMemory]
+
+ /** If a candidate memory is identical except for name to another, add an
+ * annotation that references the name of the other memory.
+ */
+ def updateMemStmts(uniqueMems: AnnotatedMemories)(s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ uniqueMems find (x => eqMems(x, m)) match {
+ case None =>
+ uniqueMems += m
+ m
+ case Some(proto) => m copy (memRef = Some(proto.name))
+ }
+ case s => s map updateMemStmts(uniqueMems)
+ }
+
+ def updateMemMods(m: DefModule) = {
+ val uniqueMems = new AnnotatedMemories
+ (m map updateMemStmts(uniqueMems))
+ }
+
+ def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
new file mode 100644
index 00000000..741ea5ef
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
@@ -0,0 +1,41 @@
+package firrtl.passes
+package memlib
+
+import firrtl.Mappers._
+import firrtl.ir._
+
+/** Annotates sequential memories that are candidates for macro replacement.
+ * Requirements for macro replacement:
+ * - read latency and write latency of one
+ * - only one readwrite port or write port
+ * - zero or one read port
+ */
+object ToMemIR extends Pass {
+ def name = "To Memory IR"
+
+ /** Only annotate memories that are candidates for memory macro replacements
+ * i.e. rw, w + r (read, write 1 cycle delay)
+ */
+ def updateStmts(s: Statement): Statement = s match {
+ case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
+ (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
+ DefAnnotatedMemory(
+ m.info,
+ m.name,
+ m.dataType,
+ m.depth,
+ m.writeLatency,
+ m.readLatency,
+ m.readers,
+ m.writers,
+ m.readwriters,
+ m.readUnderWrite,
+ None, // mask granularity annotation
+ None // No reference yet to another memory
+ )
+ case sx => sx map updateStmts
+ }
+
+ def annotateModMems(m: DefModule) = m map updateStmts
+ def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
new file mode 100644
index 00000000..a1088300
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
@@ -0,0 +1,36 @@
+package firrtl.passes
+package memlib
+import net.jcazevedo.moultingyaml._
+import java.io.{File, CharArrayWriter, PrintWriter}
+
+object CustomYAMLProtocol extends DefaultYamlProtocol {
+ // bottom depends on top
+}
+
+class YamlFileReader(file: String) {
+ import CustomYAMLProtocol._
+ def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
+ if (new File(file).exists) {
+ val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n")
+ yamlString.parseYamls flatMap (x =>
+ try Some(reader read x)
+ catch { case e: Exception => None }
+ )
+ }
+ else error("Yaml file doesn't exist!")
+ }
+}
+
+class YamlFileWriter(file: String) {
+ import CustomYAMLProtocol._
+ val outputBuffer = new CharArrayWriter
+ val separator = "--- \n"
+ def append(in: YamlValue) {
+ outputBuffer append s"$separator${in.prettyPrint}"
+ }
+ def dump() {
+ val outputFile = new PrintWriter(file)
+ outputFile write outputBuffer.toString
+ outputFile.close()
+ }
+}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 8aeafc9e..277623cf 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -2,6 +2,7 @@ package firrtlTests
import firrtl._
import firrtl.passes._
+import firrtl.passes.memlib._
import Annotations._
class ReplSeqMemSpec extends SimpleTransformSpec {
@@ -13,7 +14,7 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new ResolveAndCheck(),
new HighFirrtlToMiddleFirrtl(),
new passes.InferReadWrite(TransID(-1)),
- new passes.ReplSeqMem(TransID(-2)),
+ new passes.memlib.ReplSeqMem(TransID(-2)),
new MiddleFirrtlToLowFirrtl(),
(new Transform with SimpleRun {
def execute(c: ir.Circuit, a: AnnotationMap) = run(c, passSeq) } ),
@@ -107,7 +108,7 @@ circuit Top :
val circuit = InferTypes.run(ToWorkingIR.run(parse(input)))
val m = circuit.modules.head.asInstanceOf[ir.Module]
val connects = AnalysisUtils.getConnects(m)
- val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects)("f").serialize
+ val calculatedOrigin = AnalysisUtils.getOrigin(connects, "f").serialize
require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !")
}