aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Namespace.scala15
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala141
-rw-r--r--src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala174
-rw-r--r--src/main/scala/firrtl/passes/InferReadWrite.scala249
-rw-r--r--src/main/scala/firrtl/passes/MemUtils.scala80
-rw-r--r--src/main/scala/firrtl/passes/ReplSeqMem.scala96
-rw-r--r--src/main/scala/firrtl/passes/ReplaceMemMacros.scala229
-rw-r--r--src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala148
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala2
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala2
10 files changed, 530 insertions, 606 deletions
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala
index 952670cf..1e922673 100644
--- a/src/main/scala/firrtl/Namespace.scala
+++ b/src/main/scala/firrtl/Namespace.scala
@@ -57,8 +57,6 @@ class Namespace private {
}
object Namespace {
- def apply(): Namespace = new Namespace
-
// Initializes a namespace from a Module
def apply(m: DefModule): Namespace = {
val namespace = new Namespace
@@ -69,7 +67,7 @@ object Namespace {
case s: Block => s.stmts flatMap buildNamespaceStmt
case _ => Nil
}
- namespace.namespace ++= (m.ports collect { case dec: IsDeclaration => dec.name })
+ namespace.namespace ++= m.ports map (_.name)
m match {
case in: Module =>
namespace.namespace ++= buildNamespaceStmt(in.body)
@@ -82,9 +80,14 @@ object Namespace {
/** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */
def apply(c: Circuit): Namespace = {
val namespace = new Namespace
- c.modules foreach { m =>
- namespace.namespace += m.name
- }
+ namespace.namespace ++= c.modules map (_.name)
+ namespace
+ }
+
+ /** Initializes a [[Namespace]] from arbitrary strings **/
+ def apply(names: Seq[String] = Nil): Namespace = {
+ val namespace = new Namespace
+ namespace.namespace ++= names
namespace
}
}
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
index 58e10a66..7ced7a99 100644
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
@@ -2,13 +2,13 @@
package firrtl.passes
-import scala.collection.mutable
-import AnalysisUtils._
-import firrtl.WrappedExpression._
-import firrtl.ir._
import firrtl._
+import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
+import WrappedExpression.weq
+import MemPortUtils.memPortField
+import AnalysisUtils._
case class AppendableInfo(fields: Map[String, Any]) extends Info {
def append(a: Map[String, Any]) = this.copy(fields = fields ++ a)
@@ -21,22 +21,22 @@ case class AppendableInfo(fields: Map[String, Any]) extends Info {
}
object AnalysisUtils {
-
- def getConnects(m: Module) = {
- val connects = mutable.HashMap[String, Expression]()
- def getConnects(s: Statement): Statement = {
- s map getConnects match {
+ type Connects = collection.mutable.HashMap[String, Expression]
+ def getConnects(m: DefModule): Connects = {
+ def getConnects(connects: Connects)(s: Statement): Statement = {
+ s match {
case Connect(_, loc, expr) =>
connects(loc.serialize) = expr
case DefNode(_, name, value) =>
connects(name) = value
case _ => // do nothing
}
- s // return because we only have map and not foreach
+ s map getConnects(connects)
}
- getConnects(m.body)
- connects.toMap
- }
+ val connects = new Connects
+ m map getConnects(connects)
+ connects
+ }
// takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS.
// if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation,
@@ -44,35 +44,40 @@ object AnalysisUtils {
// use case: compare if two nodes have the same origin
// limitation: only works in a module (stops @ module inputs)
// TODO: more thorough (i.e. a + 0 = a)
- def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = {
- if (connects contains node) getOrigin(connects, connects(node))
- else EmptyExpression
- }
+ def getConnectOrigin(connects: Connects)(node: String): Expression =
+ connects get node match {
+ case None => EmptyExpression
+ case Some(e) => getOrigin(connects, e)
+ }
+ def getConnectOrigin(connects: Connects, e: Expression): Expression =
+ getConnectOrigin(connects)(e.serialize)
- private def getOrigin(connects: Map[String, Expression], e: Expression): Expression = e match {
+ private def getOrigin(connects: Connects, e: Expression): Expression = e match {
case Mux(cond, tv, fv, _) =>
val fvOrigin = getOrigin(connects, fv)
val tvOrigin = getOrigin(connects, tv)
val condOrigin = getOrigin(connects, cond)
- if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin
- else if (we(condOrigin) == we(one)) tvOrigin
- else if (we(condOrigin) == we(zero)) fvOrigin
- else if (we(tvOrigin) == we(fvOrigin)) tvOrigin
- else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin
+ if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin
+ else if (weq(condOrigin, one)) tvOrigin
+ else if (weq(condOrigin, zero)) fvOrigin
+ else if (weq(tvOrigin, fvOrigin)) tvOrigin
+ else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args.contains(one) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args.contains(zero) => zero
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
- val extractionWidth = (msb-lsb)+1
+ val extractionWidth = (msb - lsb) + 1
val nodeWidth = bitWidth(args.head.tpe)
// if you're extracting the full bitwidth, then keep searching for origin
- if (nodeWidth == extractionWidth) getOrigin(connects, args.head)
- else e
+ if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e
case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
getOrigin(connects, args.head)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
- case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects.contains(e.serialize) && kind(e) != RegKind =>
- getConnectOrigin(connects, e.serialize)
+ case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
+ connects get e.serialize match {
+ case Some(ex) => getOrigin(connects, ex)
+ case None => e
+ }
case _ => e
}
@@ -92,55 +97,45 @@ object AnalysisUtils {
// memories equivalent as long as all fields (except name) are the same
def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name)
-
}
object AnnotateMemMacros extends Pass {
+ def name = "Analyze sequential memories and tag with info for future passes(useMacro, maskGran)"
+
+ // returns # of mask bits if used
+ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
+ val wenOrigin = getConnectOrigin(connects, wen)
+ val wmaskOrigin = connects.keys filter
+ (_ startsWith wmask.serialize) map getConnectOrigin(connects)
+ // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
+ val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ if (redundantMask) None else Some(wmaskOrigin.size)
+ }
- def name = "Analyze sequential memories and tag with info for future passes (useMacro,maskGran)"
-
- def run(c: Circuit) = {
-
- def annotateModMems(m: Module) = {
- val connects = getConnects(m)
-
- // returns # of mask bits if used
- def getMaskBits(wen: String, wmask: String): Option[Int] = {
- val wenOrigin = we(getConnectOrigin(connects, wen))
- val one1 = we(one)
- val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects, x)))
- // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin.map( x => (x == wenOrigin) || (x == one1) ).foldLeft(true)(_ && _)
- if (redundantMask) None else Some(wmaskOrigin.length)
- }
-
- def updateStmts(s: Statement): Statement = s match {
- // only annotate memories that are candidates for memory macro replacements
- // i.e. rw, w + r (read, write 1 cycle delay)
- case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
- (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
- val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode", s"${m.name}.$w.wmask"))
- val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en", s"${m.name}.$w.mask"))
- val maskBits = (rwMasks ++ wMasks).head
- val memAnnotations = Map("useMacro" -> true)
- val tempInfo = appendInfo(m.info, memAnnotations)
- if (maskBits == None) m.copy(info = tempInfo)
- else m.copy(info = tempInfo.append("maskGran" -> dataBits/maskBits.get))
- case b: Block => b map updateStmts
- case s => s
+ def updateStmts(connects: Connects)(s: Statement): Statement = s match {
+ // only annotate memories that are candidates for memory macro replacements
+ // i.e. rw, w + r (read, write 1 cycle delay)
+ case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
+ (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
+ val dataBits = bitWidth(m.dataType)
+ val rwMasks = m.readwriters map (rw =>
+ getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers map (w =>
+ getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val memAnnotations = Map("useMacro" -> true)
+ val tempInfo = appendInfo(m.info, memAnnotations)
+ (rwMasks ++ wMasks).head match {
+ case None =>
+ m copy (info = tempInfo)
+ case Some(maskBits) =>
+ m.copy(info = tempInfo.append("maskGran" -> dataBits / maskBits))
}
- m.copy(body=updateStmts(m.body))
- }
-
- val updatedMods = c.modules map {
- case m: Module => annotateModMems(m)
- case m: ExtModule => m
- }
- c.copy(modules = updatedMods)
+ case s => s map updateStmts(connects)
+ }
- }
+ def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
+ def run(c: Circuit) = c copy (modules = (c.modules map annotateModMems))
}
-// TODO: Add floorplan info? \ No newline at end of file
+// TODO: Add floorplan info?
diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
index 816a179b..f80d4a0c 100644
--- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
+++ b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
@@ -2,13 +2,16 @@
package firrtl.passes
-import firrtl.ir._
import firrtl._
+import firrtl.ir._
+import firrtl.Mappers._
+import Utils.error
+import AnalysisUtils._
+
import net.jcazevedo.moultingyaml._
import net.jcazevedo.moultingyaml.DefaultYamlProtocol._
-import AnalysisUtils._
import scala.collection.mutable
-import firrtl.Mappers._
+import java.io.{File, CharArrayWriter, PrintWriter}
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
@@ -36,7 +39,7 @@ case class MemDimension(
rules: Option[DimensionRules],
set: Option[List[Int]]) {
require (
- if(rules == None) set != None else set == None,
+ if (rules == None) set != None else set == None,
"Should specify either rules or a list of valid options, but not both"
)
def getValid = set.getOrElse(rules.get.getValid).sorted
@@ -128,15 +131,12 @@ case class SRAMCompiler(
rules.find(r => r.ymux._1 == x && r.ybank._1 == y).get
}
- private val maskConfigOutputBuffer = new java.io.CharArrayWriter
- private val noMaskConfigOutputBuffer = new java.io.CharArrayWriter
+ private val maskConfigOutputBuffer = new CharArrayWriter
+ private val noMaskConfigOutputBuffer = new CharArrayWriter
def append(m: DefMemory) : DefMemory = {
- val validCombos = mutable.ArrayBuffer[SRAMConfig]()
- defaultSearchOrdering foreach { r =>
- val config = r.getValidConfig(m)
- if (config != None) validCombos += config.get
- }
+ val validCombos = (defaultSearchOrdering map (_ getValidConfig m)
+ collect { case Some(config) => config })
// non empty if successfully found compiler option that supports depth/width
// TODO: don't just take first option
val usedConfig = {
@@ -144,18 +144,19 @@ case class SRAMCompiler(
else getBestAlternative(m)
}
val usesMaskGran = containsInfo(m.info, "maskGran")
- if (configPattern != None) {
- val newConfig = usedConfig.serialize(configPattern.get) + "\n"
- val currentBuff = {
- if (usesMaskGran) maskConfigOutputBuffer
- else noMaskConfigOutputBuffer
- }
- if (!currentBuff.toString.contains(newConfig))
- currentBuff.append(newConfig)
+ configPattern match {
+ case None =>
+ case Some(p) =>
+ val newConfig = usedConfig.serialize(p) + "\n"
+ val currentBuff = {
+ if (usesMaskGran) maskConfigOutputBuffer
+ else noMaskConfigOutputBuffer
+ }
+ if (!currentBuff.toString.contains(newConfig)) currentBuff append newConfig
}
val temp = appendInfo(m.info, "sramConfig" -> usedConfig)
- val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp
- m.copy(info = newInfo)
+ val newInfo = if (usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp
+ m copy (info = newInfo)
}
// TODO: Should you really be splitting in 2 if, say, depth is 1 more than allowed? should be thresholded and
@@ -164,47 +165,47 @@ case class SRAMCompiler(
private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = {
val validXRange = mutable.ArrayBuffer[SRAMRules]()
val validYRange = mutable.ArrayBuffer[SRAMRules]()
- defaultSearchOrdering foreach { r =>
+ defaultSearchOrdering foreach { r =>
if (m.width <= r.getValidWidths.max) validXRange += r
if (m.depth <= r.getValidDepths.max) validYRange += r
}
-
- if (validXRange.isEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2))
- else if (validXRange.isEmpty && validYRange.nonEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
- else if (validXRange.nonEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2))
- else if (validXRange.intersect(validYRange).nonEmpty)
- Seq(m)
- else
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++
+ (validXRange.isEmpty, validYRange.isEmpty) match {
+ case (true, true) =>
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2))
+ case (true, false) =>
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
+ case (false, true) =>
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2))
+ case (false, false) if validXRange.intersect(validYRange).nonEmpty =>
+ Seq(m)
+ case (false, false) =>
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++
getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
+ }
}
private def getBestAlternative(m: DefMemory): SRAMConfig = {
val validConfigs = getInRange(SRAMConfig(width = bitWidth(m.dataType).intValue, depth = m.depth))
- val minNum = validConfigs.map(x => x.num).min
+ val minNum = validConfigs.map(_.num).min
val validMinConfigs = validConfigs.filter(_.num == minNum)
- val validMinConfigsSquareness = validMinConfigs.map(x => math.abs(x.width.toDouble/x.depth - 1) -> x).toMap
- val squarestAspectRatio = validMinConfigsSquareness.map { case (aspectRatioDiff, _) => aspectRatioDiff } min
+ val validMinConfigsSquareness = validMinConfigs.map(
+ x => math.abs(x.width.toDouble / x.depth - 1) -> x).toMap
+ val squarestAspectRatio = validMinConfigsSquareness.unzip._1.min
val validConfig = validMinConfigsSquareness(squarestAspectRatio)
- val validRules = mutable.ArrayBuffer[SRAMRules]()
- defaultSearchOrdering foreach { r =>
- if (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) validRules += r
- }
+ val validRules = defaultSearchOrdering filter (r =>
+ (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max))
// TODO: don't just take first option
// TODO: More optimal split if particular value is in range but not supported
// TODO: Support up to 2 read ports, 2 write ports; should be power of 2?
val bestRule = validRules.head
val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get
val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get
- bestRule.getValidConfig(width = memWidth, depth = memDepth).get.copy(xsplit = validConfig.xsplit, ysplit = validConfig.ysplit)
+ (bestRule.getValidConfig(width = memWidth, depth = memDepth).get
+ copy (xsplit = validConfig.xsplit, ysplit = validConfig.ysplit))
}
// TODO
def serialize = ???
-
}
// TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true?
@@ -212,11 +213,11 @@ case class SRAMCompiler(
class YamlFileReader(file: String) {
import CustomYAMLProtocol._
def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
- if (new java.io.File(file).exists) {
+ if (new File(file).exists) {
val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n")
- yamlString.parseYamls.flatMap(x =>
- try Some(reader.read(x))
- catch { case e: Exception => None }
+ yamlString.parseYamls flatMap (x =>
+ try Some(reader read x)
+ catch { case e: Exception => None }
)
}
else error("Yaml file doesn't exist!")
@@ -225,22 +226,20 @@ class YamlFileReader(file: String) {
class YamlFileWriter(file: String) {
import CustomYAMLProtocol._
- val outputBuffer = new java.io.CharArrayWriter
+ val outputBuffer = new CharArrayWriter
val separator = "--- \n"
- def append(in: YamlValue) = {
- outputBuffer.append(separator + in.prettyPrint)
+ def append(in: YamlValue) {
+ outputBuffer append s"$separator${in.prettyPrint}"
}
- def dump = {
- val outputFile = new java.io.PrintWriter(file)
- outputFile.write(outputBuffer.toString)
- outputFile.close()
+ def dump {
+ val outputFile = new PrintWriter(file)
+ outputFile write outputBuffer.toString
+ outputFile.close
}
}
class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
-
import CustomYAMLProtocol._
-
def name = "Annotate memories with valid split depths, widths, #\'s"
// TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist
@@ -249,43 +248,42 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
sp: Option[SRAMCompiler] = None,
dp: Option[SRAMCompiler] = None) {
def serialize = {
- if (sp != None) sp.get.serialize
- if (dp != None) dp.get.serialize
+ sp match {
+ case None =>
+ case Some(p) => p.serialize
+ }
+ dp match {
+ case None =>
+ case Some(p) => p.serialize
+ }
}
}
- val sramCompilers = {
- if (reader == None) None
- else {
- val compilers = reader.get.parse[SRAMCompiler]
- val sp = compilers.find(_.portType == "RW")
- val dp = compilers.find(_.portType == "R,W")
+
+ val sramCompilers = reader match {
+ case None => None
+ case Some(r) =>
+ val compilers = r.parse[SRAMCompiler]
+ val sp = compilers find (_.portType == "RW")
+ val dp = compilers find (_.portType == "R,W")
Some(SRAMCompilerSet(sp = sp, dp = dp))
- }
}
- def run(c: Circuit) = {
- def annotateModMems(m: Module) = {
- def updateStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") => {
- if (sramCompilers == None) m
- else {
- if (m.readwriters.length == 1)
- if (sramCompilers.get.sp == None) error("Design needs RW port memory compiler!")
- else sramCompilers.get.sp.get.append(m)
- else
- if (sramCompilers.get.dp == None) error("Design needs R,W port memory compiler!")
- else sramCompilers.get.dp.get.append(m)
- }
- }
- case b: Block => b map updateStmts
- case s => s
- }
- m.copy(body=updateStmts(m.body))
- }
- val updatedMods = c.modules map {
- case m: Module => annotateModMems(m)
- case m: ExtModule => m
+ def updateStmts(s: Statement): Statement = s match {
+ case m: DefMemory if containsInfo(m.info, "useMacro") => sramCompilers match {
+ case None => m
+ case Some(compiler) if (m.readwriters.length == 1) =>
+ compiler.sp match {
+ case None => error("Design needs RW port memory compiler!")
+ case Some(p) => p append m
+ }
+ case Some(compiler) =>
+ compiler.dp match {
+ case None => error("Design needs R,W port memory compiler!")
+ case Some(p) => p append m
+ }
}
- c.copy(modules = updatedMods)
- }
+ case s => s map updateStmts
+ }
+
+ def run(c: Circuit) = c copy (modules = (c.modules map (_ map updateStmts)))
}
diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala
index 9fbd6ab3..ec996fdb 100644
--- a/src/main/scala/firrtl/passes/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/InferReadWrite.scala
@@ -27,13 +27,14 @@ MODIFICATIONS.
package firrtl.passes
-import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap}
-import com.typesafe.scalalogging.LazyLogging
-
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.PrimOps._
+import firrtl.Utils.{one, zero, BoolType}
+import MemPortUtils.memPortField
+import AnalysisUtils.{Connects, getConnects}
+import WrappedExpression.weq
import Annotations._
case class InferReadWriteAnnotation(t: String, tID: TransID)
@@ -50,153 +51,133 @@ case class InferReadWriteAnnotation(t: String, tID: TransID)
object InferReadWritePass extends Pass {
def name = "Infer ReadWrite Ports"
- def inferReadWrite(m: Module) = {
- import WrappedExpression.we
- val connects = HashMap[String, Expression]()
- val repl = HashMap[String, Expression]()
- val stmts = ArrayBuffer[Statement]()
- val zero = we(UIntLiteral(0, IntWidth(1)))
- val one = we(UIntLiteral(1, IntWidth(1)))
+ type Netlist = collection.mutable.HashMap[String, Expression]
+ type Statements = collection.mutable.ArrayBuffer[Statement]
+ type PortSet = collection.mutable.HashSet[String]
+
+ private implicit def toString(e: Expression) = e.serialize
- // find all wire connections
- def analyze(s: Statement): Unit = s match {
- case s: Connect =>
- connects(s.loc.serialize) = s.expr
- case s: PartialConnect =>
- connects(s.loc.serialize) = s.expr
- case s: DefNode =>
- connects(s.name) = s.value
- case s: Block =>
- s.stmts foreach analyze
- case _ =>
+ def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match {
+ // No ConstProp yet...
+ case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) =>
+ getProductTerms(connects)(cond)
+ // Visit each term of AND operation
+ case DoPrim(op, args, consts, tpe) if op == And =>
+ e +: (args flatMap getProductTerms(connects))
+ // Visit connected nodes to references
+ case _: WRef | _: WSubField | _: WSubIndex => connects get e match {
+ case None => Seq(e)
+ case Some(ex) => e +: getProductTerms(connects)(ex)
}
+ // Otherwise just return itself
+ case _ => Seq(e)
+ }
- def getProductTermsFromExp(e: Expression): Seq[Expression] =
- e match {
- // No ConstProp yet...
- case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero =>
- cond +: getProductTerms(cond.serialize)
- // Visit each term of AND operation
- case DoPrim(op, args, consts, tpe) if op == And =>
- e +: (args flatMap getProductTermsFromExp)
- // Visit connected nodes to references
- case _: WRef | _: SubField | _: SubIndex | _: SubAccess =>
- e +: getProductTerms(e.serialize)
- // Otherwise just return itselt
- case _ =>
- List(e)
- }
+ def checkComplement(a: Expression, b: Expression) = (a, b) match {
+ // b ?= Not(a)
+ case (_, DoPrim(Not, args, _, _)) => weq(args.head, a)
+ // a ?= Not(b)
+ case (DoPrim(Not, args, _, _), _) => weq(args.head, b)
+ // b ?= Eq(a, 0) or b ?= Eq(0, a)
+ case (_, DoPrim(Eq, args, _, _)) =>
+ weq(args(0), a) && weq(args(1), zero) ||
+ weq(args(1), a) && weq(args(0), zero)
+ // a ?= Eq(b, 0) or b ?= Eq(0, a)
+ case (DoPrim(Eq, args, _, _), _) =>
+ weq(args(0), b) && weq(args(1), zero) ||
+ weq(args(1), b) && weq(args(0), zero)
+ case _ => false
+ }
- def getProductTerms(node: String): Seq[Expression] =
- if (connects contains node) getProductTermsFromExp(connects(node)) else Nil
- def checkComplement(a: Expression, b: Expression) = (a, b) match {
- // b ?= Not(a)
- case (_, DoPrim(op, args, _, _)) if op == Not =>
- args.head.serialize == a.serialize
- // a ?= Not(b)
- case (DoPrim(op, args, _, _), _) if op == Not =>
- args.head.serialize == b.serialize
- // b ?= Eq(a, 0) or b ?= Eq(0, a)
- case (_, DoPrim(op, args, _, _)) if op == Eq =>
- args(0).serialize == a.serialize && we(args(1)) == zero ||
- args(1).serialize == a.serialize && we(args(0)) == zero
- // a ?= Eq(b, 0) or b ?= Eq(0, a)
- case (DoPrim(op, args, _, _), _) if op == Eq =>
- args(0).serialize == b.serialize && we(args(1)) == zero ||
- args(1).serialize == b.serialize && we(args(0)) == zero
- case _ => false
+ def replaceExp(repl: Netlist)(e: Expression): Expression =
+ e map replaceExp(repl) match {
+ case e: WSubField => repl getOrElse (e.serialize, e)
+ case e => e
}
- def inferReadWrite(s: Statement): Statement = s map inferReadWrite match {
- // infer readwrite ports only for non combinational memories
- case mem: DefMemory if mem.readLatency > 0 =>
- val bt = UIntType(IntWidth(1))
- val ut = UnknownType
- val ug = UNKNOWNGENDER
- val readers = HashSet[String]()
- val writers = HashSet[String]()
- val readwriters = ArrayBuffer[String]()
- for (w <- mem.writers ; r <- mem.readers) {
- val wp = getProductTerms(s"${mem.name}.$w.en")
- val rp = getProductTerms(s"${mem.name}.$r.en")
- if (wp exists (a => rp exists (b => checkComplement(a, b)))) {
- val allPorts = (mem.readers ++ mem.writers ++ mem.readwriters ++ readwriters).toSet
- // Uniquify names by examining all ports of the memory
- var rw = (for {
- idx <- Stream from 0
- newName = s"rw_$idx"
- if !allPorts(newName)
- } yield newName).head
- val rw_exp = WSubField(WRef(mem.name, ut, MemKind, ug), rw, ut, ug)
- readwriters += rw
- readers += r
- writers += w
- repl(s"${mem.name}.$r.en") = EmptyExpression
- repl(s"${mem.name}.$r.clk") = EmptyExpression
- repl(s"${mem.name}.$r.addr") = EmptyExpression
- repl(s"${mem.name}.$r.data") = WSubField(rw_exp, "rdata", mem.dataType, MALE)
- repl(s"${mem.name}.$w.en") = WSubField(rw_exp, "wmode", bt, FEMALE)
- repl(s"${mem.name}.$w.clk") = EmptyExpression
- repl(s"${mem.name}.$w.addr") = EmptyExpression
- repl(s"${mem.name}.$w.data") = WSubField(rw_exp, "wdata", mem.dataType, FEMALE)
- repl(s"${mem.name}.$w.mask") = WSubField(rw_exp, "wmask", ut, FEMALE)
- stmts += Connect(NoInfo, WSubField(rw_exp, "clk", ClockType, FEMALE),
- WRef("clk", ClockType, NodeKind, MALE))
- stmts += Connect(NoInfo, WSubField(rw_exp, "en", bt, FEMALE),
- DoPrim(Or, List(connects(s"${mem.name}.$r.en"), connects(s"${mem.name}.$w.en")), Nil, bt))
- stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE),
- Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"),
- connects(s"${mem.name}.$r.addr"), ut))
- }
- }
- if (readwriters.isEmpty) mem else DefMemory(mem.info,
- mem.name, mem.dataType, mem.depth, mem.writeLatency, mem.readLatency,
- mem.readers filterNot readers, mem.writers filterNot writers,
- mem.readwriters ++ readwriters)
+ def replaceStmt(repl: Netlist)(s: Statement): Statement =
+ s map replaceStmt(repl) map replaceExp(repl) match {
+ case Connect(_, EmptyExpression, _) => EmptyStmt
case s => s
}
-
- def replaceExp(e: Expression): Expression =
- e map replaceExp match {
- case e: WSubField => repl getOrElse (e.serialize, e)
- case e => e
+
+ def inferReadWriteStmt(connects: Connects,
+ repl: Netlist,
+ stmts: Statements)
+ (s: Statement): Statement = s match {
+ // infer readwrite ports only for non combinational memories
+ case mem: DefMemory if mem.readLatency > 0 =>
+ val ut = UnknownType
+ val ug = UNKNOWNGENDER
+ val readers = new PortSet
+ val writers = new PortSet
+ val readwriters = collection.mutable.ArrayBuffer[String]()
+ val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
+ for (w <- mem.writers ; r <- mem.readers) {
+ val wp = getProductTerms(connects)(memPortField(mem, w, "en"))
+ val rp = getProductTerms(connects)(memPortField(mem, r, "en"))
+ if (wp exists (a => rp exists (b => checkComplement(a, b)))) {
+ val rw = namespace newName "rw"
+ val rwExp = createSubField(createRef(mem.name), rw)
+ readwriters += rw
+ readers += r
+ writers += w
+ repl(memPortField(mem, r, "clk")) = EmptyExpression
+ repl(memPortField(mem, r, "en")) = EmptyExpression
+ repl(memPortField(mem, r, "addr")) = EmptyExpression
+ repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata")
+ repl(memPortField(mem, w, "clk")) = EmptyExpression
+ repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode")
+ repl(memPortField(mem, w, "addr")) = EmptyExpression
+ repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata")
+ repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask")
+ stmts += Connect(NoInfo, createSubField(rwExp, "clk"), createRef("clk")) // TODO: fix it
+ stmts += Connect(NoInfo, createSubField(rwExp, "en"),
+ DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
+ connects(memPortField(mem, w, "en"))), Nil, BoolType))
+ stmts += Connect(NoInfo, createSubField(rwExp, "addr"),
+ Mux(connects(memPortField(mem, w, "en")),
+ connects(memPortField(mem, w, "addr")),
+ connects(memPortField(mem, r, "addr")), UnknownType))
+ }
}
+ if (readwriters.isEmpty) mem else mem copy (
+ readers = mem.readers filterNot readers,
+ writers = mem.writers filterNot writers,
+ readwriters = mem.readwriters ++ readwriters)
+ case s => s map inferReadWriteStmt(connects, repl, stmts)
+ }
- def replaceStmt(s: Statement): Statement =
- s map replaceStmt map replaceExp match {
- case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt
- case s => s
- }
-
- analyze(m.body)
- Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq))
+ def inferReadWrite(m: DefModule) = {
+ val connects = getConnects(m)
+ val repl = new Netlist
+ val stmts = new Statements
+ (m map inferReadWriteStmt(connects, repl, stmts)
+ map replaceStmt(repl)) match {
+ case m: ExtModule => m
+ case m: Module => m copy (body = Block(m.body +: stmts))
+ }
}
- def run (c:Circuit) = Circuit(c.info, c.modules map {
- case m: Module => inferReadWrite(m)
- case m: ExtModule => m
- }, c.main)
+ def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite)
}
// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
// To use this transform, circuit name should be annotated with its TransId.
-class InferReadWrite(transID: TransID) extends Transform with LazyLogging {
- def execute(circuit:Circuit, map: AnnotationMap) =
- map get transID match {
- case Some(p) => p get CircuitName(circuit.main) match {
- case Some(InferReadWriteAnnotation(_, _)) => TransformResult((Seq(
- InferReadWritePass,
- CheckInitialization,
- ResolveKinds,
- InferTypes,
- ResolveGenders) foldLeft circuit){ (c, pass) =>
- val x = Utils.time(pass.name)(pass run c)
- logger debug x.serialize
- x
- }, None, Some(map))
- case _ => TransformResult(circuit, None, Some(map))
- }
- case _ => TransformResult(circuit, None, Some(map))
+class InferReadWrite(transID: TransID) extends Transform with SimpleRun {
+ def passSeq = Seq(
+ InferReadWritePass,
+ CheckInitialization,
+ InferTypes,
+ ResolveKinds,
+ ResolveGenders
+ )
+ def execute(c: Circuit, map: AnnotationMap) = map get transID match {
+ case Some(p) => p get CircuitName(c.main) match {
+ case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq)
+ case _ => error("Unexpected annotation for InferReadWrite")
}
+ case _ => TransformResult(c)
+ }
}
diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala
index d2557f8d..1091db5f 100644
--- a/src/main/scala/firrtl/passes/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/MemUtils.scala
@@ -145,11 +145,27 @@ object createMask {
}
}
-object MemPortUtils {
+object createRef {
+ def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind) = WRef(n, t, k, UNKNOWNGENDER)
+}
+
+object createSubField {
+ def apply(exp: Expression, n: String) = WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER)
+}
- import AnalysisUtils._
+object connectFields {
+ def apply(lref: Expression, lname: String, rref: Expression, rname: String) =
+ Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname))
+}
- def flattenType(t: Type) = UIntType(IntWidth(bitWidth(t)))
+object flattenType {
+ def apply(t: Type) = UIntType(IntWidth(bitWidth(t)))
+}
+
+object MemPortUtils {
+ type MemPortMap = collection.mutable.HashMap[String, Expression]
+ type Memories = collection.mutable.ArrayBuffer[DefMemory]
+ type Modules = collection.mutable.ArrayBuffer[DefModule]
def defaultPortSeq(mem: DefMemory) = Seq(
Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))),
@@ -157,64 +173,10 @@ object MemPortUtils {
Field("clk", Default, ClockType)
)
- def getFillWMask(mem: DefMemory) =
- getInfo(mem.info, "maskGran") match {
- case None => false
- case Some(maskGran) => maskGran == 1
- }
-
- def rPortToBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
- def rPortToFlattenBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
-
- def wPortToBundle(mem: DefMemory) = BundleType(
- (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++
- (if (!containsInfo(mem.info, "maskGran")) Nil
- else Seq(Field("mask", Default, createMask(mem.dataType))))
- )
- def wPortToFlattenBundle(mem: DefMemory) = BundleType(
- (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++
- (if (!containsInfo(mem.info, "maskGran")) Nil
- else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType)))
- else Seq(Field("mask", Default, flattenType(createMask(mem.dataType)))))
- )
- // TODO: Don't use createMask???
-
- def rwPortToBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) ++ Seq(
- Field("wmode", Default, BoolType),
- Field("wdata", Default, mem.dataType),
- Field("rdata", Flip, mem.dataType)
- ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
- else Seq(Field("wmask", Default, createMask(mem.dataType)))
- )
- )
-
- def rwPortToFlattenBundle(mem: DefMemory) = BundleType(
- defaultPortSeq(mem) ++ Seq(
- Field("wmode", Default, BoolType),
- Field("wdata", Default, flattenType(mem.dataType)),
- Field("rdata", Flip, flattenType(mem.dataType))
- ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
- else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType)))
- else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType))))
- )
- )
-
- def memToBundle(s: DefMemory) = BundleType(
- s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
-
- def memToFlattenBundle(s: DefMemory) = BundleType(
- s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
-
// Todo: merge it with memToBundle
def memType(mem: DefMemory) = {
- val rType = rPortToBundle(mem)
+ val rType = BundleType(defaultPortSeq(mem) :+
+ Field("data", Flip, mem.dataType))
val wType = BundleType(defaultPortSeq(mem) ++ Seq(
Field("data", Default, mem.dataType),
Field("mask", Default, createMask(mem.dataType))))
diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/ReplSeqMem.scala
index ae842a0b..c2c1b303 100644
--- a/src/main/scala/firrtl/passes/ReplSeqMem.scala
+++ b/src/main/scala/firrtl/passes/ReplSeqMem.scala
@@ -2,12 +2,12 @@
package firrtl.passes
-import com.typesafe.scalalogging.LazyLogging
import firrtl._
import firrtl.ir._
import Annotations._
-import java.io.Writer
import AnalysisUtils._
+import Utils.error
+import java.io.{File, CharArrayWriter, PrintWriter}
sealed trait PassOption
case object InputConfigFileName extends PassOption
@@ -15,11 +15,9 @@ case object OutputConfigFileName extends PassOption
case object PassCircuitName extends PassOption
object PassConfigUtil {
-
+ type PassOptionMap = Map[PassOption, String]
+
def getPassOptions(t: String, usage: String = "") = {
-
- type PassOptionMap = Map[PassOption, String]
-
// can't use space to delimit sub arguments (otherwise, Driver.scala will throw error)
val passArgList = t.split(":").toList
@@ -38,25 +36,24 @@ object PassConfigUtil {
}
nextPassOption(Map[PassOption, String](), passArgList)
}
-
}
class ConfWriter(filename: String) {
- val outputBuffer = new java.io.CharArrayWriter
+ val outputBuffer = new CharArrayWriter
def append(m: DefMemory) = {
// legacy
val maskGran = getInfo(m.info, "maskGran")
- val writers = m.writers map (x => if (maskGran == None) "write" else "mwrite")
val readers = List.fill(m.readers.length)("read")
- val readwriters = m.readwriters map (x => if (maskGran == None) "rw" else "mrw")
- val ports = (writers ++ readers ++ readwriters).mkString(",")
- val maskGranConf = if (maskGran == None) "" else s"mask_gran ${maskGran.get}"
+ val writers = List.fill(m.writers.length)(if (maskGran == None) "write" else "mwrite")
+ val readwriters = List.fill(m.readwriters.length)(if (maskGran == None) "rw" else "mrw")
+ val ports = (writers ++ readers ++ readwriters) mkString ","
+ val maskGranConf = maskGran match { case None => "" case Some(p) => s"mask_gran $p" }
val width = bitWidth(m.dataType)
val conf = s"name ${m.name} depth ${m.depth} width ${width} ports ${ports} ${maskGranConf} \n"
outputBuffer.append(conf)
}
def serialize = {
- val outputFile = new java.io.PrintWriter(filename)
+ val outputFile = new PrintWriter(filename)
outputFile.write(outputBuffer.toString)
outputFile.close()
}
@@ -91,50 +88,35 @@ Optional Arguments:
error("No circuit name specified for ReplSeqMem!" + usage)
)
val target = CircuitName(passCircuit)
- def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit, "-c:"+n.name))
-
+ def duplicate(n: Named) = this copy (t = (t replace (s"-c:$passCircuit", s"-c:${n.name}")))
}
-class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
- def execute(circuit:Circuit, map: AnnotationMap) =
- map get transID match {
- case Some(p) => p get CircuitName(circuit.main) match {
- case Some(ReplSeqMemAnnotation(t, _)) => {
-
- val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "")
- val inConfigFile = {
- if (inputFileName.isEmpty) None
- else if (new java.io.File(inputFileName).exists) Some(new YamlFileReader(inputFileName))
- else error("Input configuration file does not exist!")
- }
-
- val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t).get(OutputConfigFileName).get)
- TransformResult(
- (
- Seq(
- Legalize,
- AnnotateMemMacros,
- UpdateDuplicateMemMacros,
- new AnnotateValidMemConfigs(inConfigFile),
- new ReplaceMemMacros(outConfigFile),
- RemoveEmpty,
- CheckInitialization,
- ResolveKinds, // Must be run for the transform to work!
- InferTypes,
- ResolveGenders
- ) foldLeft circuit
- ) {
- (c, pass) =>
- val x = Utils.time(pass.name)(pass run c)
- logger debug x.serialize
- x
- } ,
- None,
- Some(map)
- )
- }
- case _ => error("Unexpected transform annotation")
- }
- case _ => TransformResult(circuit, None, Some(map))
+class ReplSeqMem(transID: TransID) extends Transform with SimpleRun {
+ def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) =
+ Seq(Legalize,
+ AnnotateMemMacros,
+ UpdateDuplicateMemMacros,
+ new AnnotateValidMemConfigs(inConfigFile),
+ new ReplaceMemMacros(outConfigFile),
+ RemoveEmpty,
+ CheckInitialization,
+ InferTypes,
+ ResolveKinds, // Must be run for the transform to work!
+ ResolveGenders)
+
+ def execute(c: Circuit, map: AnnotationMap) = map get transID match {
+ case Some(p) => p get CircuitName(c.main) match {
+ case Some(ReplSeqMemAnnotation(t, _)) =>
+ val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "")
+ val inConfigFile = {
+ if (inputFileName.isEmpty) None
+ else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName))
+ else error("Input configuration file does not exist!")
+ }
+ val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName))
+ run(c, passSeq(inConfigFile, outConfigFile))
+ case _ => error("Unexpected transform annotation")
}
-} \ No newline at end of file
+ case _ => TransformResult(c)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
index 7bb9c6c4..33a371a0 100644
--- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
@@ -2,91 +2,35 @@
package firrtl.passes
-import scala.collection.mutable
-import firrtl.ir._
-import AnalysisUtils._
-import MemTransformUtils._
import firrtl._
+import firrtl.ir._
import firrtl.Utils._
-import MemPortUtils._
import firrtl.Mappers._
+import MemPortUtils._
+import MemTransformUtils._
+import AnalysisUtils._
class ReplaceMemMacros(writer: ConfWriter) extends Pass {
-
- def name = "Replace memories with black box wrappers (optimizes when write mask isn't needed) + configuration file"
-
- def run(c: Circuit) = {
-
- lazy val moduleNamespace = Namespace(c)
- val memMods = mutable.ArrayBuffer[DefModule]()
- val uniqueMems = mutable.ArrayBuffer[DefMemory]()
-
- def updateMemMods(m: Module) = {
- val memPortMap = mutable.HashMap[String, Expression]()
-
- def updateMemStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") =>
- if(!containsInfo(m.info, "maskGran")) {
- m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression }
- m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression }
- }
- val infoT = getInfo(m.info, "info")
- val info = if (infoT == None) NoInfo else infoT.get match { case i: Info => i }
- val ref = getInfo(m.info, "ref")
-
- // prototype mem
- if (ref == None) {
- val newWrapperName = moduleNamespace.newName(m.name)
- val newMemBBName = moduleNamespace.newName(m.name + "_ext")
- val newMem = m.copy(name = newMemBBName)
- memMods ++= createMemModule(newMem, newWrapperName)
- uniqueMems += newMem
- WDefInstance(info, m.name, newWrapperName, UnknownType)
- }
- else {
- val r = ref.get match { case s: String => s }
- WDefInstance(info, m.name, r, UnknownType)
- }
- case b: Block => b map updateMemStmts
- case s => s
- }
-
- val updatedMems = updateMemStmts(m.body)
- val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap)
- m.copy(body = updatedConns)
- }
-
- val updatedMods = c.modules map {
- case m: Module => updateMemMods(m)
- case m: ExtModule => m
- }
-
- // print conf
- writer.serialize
- c.copy(modules = updatedMods ++ memMods.toSeq)
- }
+ def name = "Replace memories with black box wrappers" +
+ " (optimizes when write mask isn't needed) + configuration file"
// from Albert
def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = {
assert(m.dataType != UnknownType)
- val stmts = mutable.ArrayBuffer[Statement]()
- val wrapperioPorts = MemPortUtils.memToBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
- val bbProto = m.copy(dataType = flattenType(m.dataType))
- val bbioPorts = MemPortUtils.memToFlattenBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
-
- stmts += WDefInstance(NoInfo, m.name, m.name, UnknownType)
- val bbRef = createRef(m.name)
- stmts ++= (m.readers zip bbProto.readers).flatMap {
- case (x, y) => adaptReader(createRef(x), m, createSubField(bbRef, y), bbProto)
- }
- stmts ++= (m.writers zip bbProto.writers).flatMap {
- case (x, y) => adaptWriter(createRef(x), m, createSubField(bbRef, y), bbProto)
- }
- stmts ++= (m.readwriters zip bbProto.readwriters).flatMap {
- case (x, y) => adaptReadWriter(createRef(x), m, createSubField(bbRef, y), bbProto)
- }
- val wrapper = Module(NoInfo, wrapperName, wrapperioPorts, Block(stmts))
- val bb = ExtModule(NoInfo, m.name, bbioPorts)
+ val wrapperIoType = memToBundle(m)
+ val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbIoType = memToFlattenBundle(m)
+ val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbRef = createRef(m.name, bbIoType)
+ val hasMask = containsInfo(m.info, "maskGran")
+ val fillMask = getFillWMask(m)
+ def portRef(p: String) = createRef(p, field_type(wrapperIoType, p))
+ val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
+ (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++
+ (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++
+ (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask)))
+ val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
+ val bb = ExtModule(NoInfo, m.name, bbIoPorts)
// TODO: Annotate? -- use actual annotation map
// add to conf file
@@ -95,75 +39,86 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
}
// TODO: get rid of copy pasta
- def adaptReader(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = Seq(
- connectFields(bbPort, "addr", wrapperPort, "addr"),
- connectFields(bbPort, "en", wrapperPort, "en"),
- connectFields(bbPort, "clk", wrapperPort, "clk"),
- fromBits(
- WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER),
- WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER)
- )
- )
-
- def adaptWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = {
- val defaultSeq = Seq(
- connectFields(bbPort, "addr", wrapperPort, "addr"),
- connectFields(bbPort, "en", wrapperPort, "en"),
- connectFields(bbPort, "clk", wrapperPort, "clk"),
- Connect(
- NoInfo,
- WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER),
- toBits(WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER))
- )
- )
- if (containsInfo(wrapperMem.info, "maskGran")) {
- val wrapperMask = createMask(wrapperMem.dataType)
- val fillWMask = getFillWMask(wrapperMem)
- val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
- val rhs = {
- if (fillWMask) toBitMask(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType)
- else toBits(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER))
- }
- defaultSeq :+ Connect(
+ def defaultConnects(wrapperPort: WRef, bbPort: WSubField) =
+ Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))
+
+ def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean) =
+ if (fillMask) toBitMask(mask, dataType) else toBits(mask)
+
+ def adaptReader(wrapperPort: WRef, bbPort: WSubField) =
+ defaultConnects(wrapperPort, bbPort) :+
+ fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data"))
+
+ def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
+ val wrapperData = createSubField(wrapperPort, "data")
+ val defaultSeq = defaultConnects(wrapperPort, bbPort) :+
+ Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData))
+ hasMask match {
+ case false => defaultSeq
+ case true => defaultSeq :+ Connect(
NoInfo,
- WSubField(bbPort, "mask", bbMask, UNKNOWNGENDER),
- rhs
+ createSubField(bbPort, "mask"),
+ maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
)
}
- else defaultSeq
}
- def adaptReadWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = {
- val defaultSeq = Seq(
- connectFields(bbPort, "addr", wrapperPort, "addr"),
- connectFields(bbPort, "en", wrapperPort, "en"),
- connectFields(bbPort, "clk", wrapperPort, "clk"),
- connectFields(bbPort, "wmode", wrapperPort, "wmode"),
- Connect(
- NoInfo,
- WSubField(bbPort, "wdata", bbMem.dataType, UNKNOWNGENDER),
- toBits(WSubField(wrapperPort, "wdata", wrapperMem.dataType, UNKNOWNGENDER))
- ),
- fromBits(
- WSubField(wrapperPort, "rdata", wrapperMem.dataType, UNKNOWNGENDER),
- WSubField(bbPort, "rdata", bbMem.dataType, UNKNOWNGENDER)
- )
- )
- if (containsInfo(wrapperMem.info, "maskGran")) {
- val wrapperMask = createMask(wrapperMem.dataType)
- val fillWMask = getFillWMask(wrapperMem)
- val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
- val rhs = {
- if (fillWMask) toBitMask(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType)
- else toBits(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER))
- }
- defaultSeq :+ Connect(
+ def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
+ val wrapperWData = createSubField(wrapperPort, "wdata")
+ val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
+ fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")),
+ connectFields(bbPort, "wmode", wrapperPort, "wmode"),
+ Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData)))
+ hasMask match {
+ case false => defaultSeq
+ case true => defaultSeq :+ Connect(
NoInfo,
- WSubField(bbPort, "wmask", bbMask, UNKNOWNGENDER),
- rhs
+ createSubField(bbPort, "wmask"),
+ maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
)
}
- else defaultSeq
}
+ def updateMemStmts(namespace: Namespace,
+ memPortMap: MemPortMap,
+ memMods: Modules)
+ (s: Statement): Statement = s match {
+ case m: DefMemory if containsInfo(m.info, "useMacro") =>
+ if (!containsInfo(m.info, "maskGran")) {
+ m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression }
+ m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression }
+ }
+ val info = getInfo(m.info, "info") match {
+ case None => NoInfo
+ case Some(p: Info) => p
+ }
+ getInfo(m.info, "ref") match {
+ case None =>
+ // prototype mem
+ val newWrapperName = namespace newName m.name
+ val newMemBBName = namespace newName s"${m.name}_ext"
+ val newMem = m copy (name = newMemBBName)
+ memMods ++= createMemModule(newMem, newWrapperName)
+ WDefInstance(info, m.name, newWrapperName, UnknownType)
+ case Some(ref: String) =>
+ WDefInstance(info, m.name, ref, UnknownType)
+ }
+ case s => s map updateMemStmts(namespace, memPortMap, memMods)
+ }
+
+ def updateMemMods(namespace: Namespace, memMods: Modules)(m: DefModule) = {
+ val memPortMap = new MemPortMap
+
+ (m map updateMemStmts(namespace, memPortMap, memMods)
+ map updateStmtRefs(memPortMap))
+ }
+
+ def run(c: Circuit) = {
+ val namespace = Namespace(c)
+ val memMods = new Modules
+ val modules = c.modules map updateMemMods(namespace, memMods)
+ // print conf
+ writer.serialize
+ c copy (modules = modules ++ memMods)
+ }
}
diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
index 0098fa5f..fbff9bd6 100644
--- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
@@ -2,23 +2,73 @@
package firrtl.passes
-import scala.collection.mutable
-import AnalysisUtils._
-import MemTransformUtils._
-import firrtl.ir._
import firrtl._
-import firrtl.Mappers._
+import firrtl.ir._
import firrtl.Utils._
+import firrtl.Mappers._
+import AnalysisUtils._
+import MemPortUtils._
+import MemTransformUtils._
object MemTransformUtils {
+ def getFillWMask(mem: DefMemory) =
+ getInfo(mem.info, "maskGran") match {
+ case None => false
+ case Some(maskGran) => maskGran == 1
+ }
+
+ def rPortToBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ def rPortToFlattenBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+
+ def wPortToBundle(mem: DefMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++
+ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else Seq(Field("mask", Default, createMask(mem.dataType))))
+ )
+ def wPortToFlattenBundle(mem: DefMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++
+ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType)))
+ else Seq(Field("mask", Default, flattenType(createMask(mem.dataType)))))
+ )
+ // TODO: Don't use createMask???
+
+ def rwPortToBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, mem.dataType),
+ Field("rdata", Flip, mem.dataType)
+ ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else Seq(Field("wmask", Default, createMask(mem.dataType)))
+ )
+ )
+
+ def rwPortToFlattenBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, flattenType(mem.dataType)),
+ Field("rdata", Flip, flattenType(mem.dataType))
+ ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType)))
+ else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType))))
+ )
+ )
+
+ def memToBundle(s: DefMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
+
+ def memToFlattenBundle(s: DefMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
- def createRef(n: String) = WRef(n, UnknownType, ExpKind, UNKNOWNGENDER)
- def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER)
- def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) =
- Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname))
def getMemPortMap(m: DefMemory) = {
- val memPortMap = mutable.HashMap[String, Expression]()
+ val memPortMap = new MemPortMap
val defaultFields = Seq("addr", "en", "clk")
val rFields = defaultFields :+ "data"
val wFields = rFields :+ "mask"
@@ -33,35 +83,41 @@ object MemTransformUtils {
updateMemPortMap(m.readers, rFields, "R")
updateMemPortMap(m.writers, wFields, "W")
updateMemPortMap(m.readwriters, rwFields, "RW")
- memPortMap.toMap
+ memPortMap
}
+
def createMemProto(m: DefMemory) = {
val rports = (0 until m.readers.length) map (i => s"R$i")
val wports = (0 until m.writers.length) map (i => s"W$i")
val rwports = (0 until m.readwriters.length) map (i => s"RW$i")
- m.copy(readers = rports, writers = wports, readwriters = rwports)
+ m copy (readers = rports, writers = wports, readwriters = rwports)
}
- def updateStmtRefs(s: Statement, repl: Map[String, Expression]): Statement = {
- def updateRef(e: Expression): Expression = e map updateRef match {
- case e => repl getOrElse (e.serialize, e)
+ def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
+ def updateRef(e: Expression): Expression = {
+ val ex = e map updateRef
+ repl getOrElse (ex.serialize, ex)
}
+
def hasEmptyExpr(stmt: Statement): Boolean = {
var foundEmpty = false
def testEmptyExpr(e: Expression): Expression = {
- e map testEmptyExpr match {
+ e match {
case EmptyExpression => foundEmpty = true
case _ =>
}
- e // map must return; no foreach
+ e map testEmptyExpr // map must return; no foreach
}
stmt map testEmptyExpr
foundEmpty
}
- def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match {
- case c: Connect if hasEmptyExpr(c) => EmptyStmt
- case s => s
- }
+
+ def updateStmtRefs(s: Statement): Statement =
+ s map updateStmtRefs map updateRef match {
+ case c: Connect if hasEmptyExpr(c) => EmptyStmt
+ case s => s
+ }
+
updateStmtRefs(s)
}
@@ -71,37 +127,29 @@ object UpdateDuplicateMemMacros extends Pass {
def name = "Convert memory port names to be more meaningful and tag duplicate memories"
- def run(c: Circuit) = {
- val uniqueMems = mutable.ArrayBuffer[DefMemory]()
-
- def updateMemMods(m: Module) = {
- val memPortMap = mutable.HashMap[String, Expression]()
-
- def updateMemStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") =>
- val updatedMem = createMemProto(m)
- memPortMap ++= getMemPortMap(m)
- val proto = uniqueMems find (x => eqMems(x, updatedMem))
- if (proto == None) {
- uniqueMems += updatedMem
- updatedMem
- }
- else updatedMem.copy(info = appendInfo(updatedMem.info, "ref" -> proto.get.name))
- case b: Block => b map updateMemStmts
- case s => s
+ def updateMemStmts(uniqueMems: Memories,
+ memPortMap: MemPortMap)
+ (s: Statement): Statement = s match {
+ case m: DefMemory if containsInfo(m.info, "useMacro") =>
+ val updatedMem = createMemProto(m)
+ memPortMap ++= getMemPortMap(m)
+ uniqueMems find (x => eqMems(x, updatedMem)) match {
+ case None =>
+ uniqueMems += updatedMem
+ updatedMem
+ case Some(proto) =>
+ updatedMem copy (info = appendInfo(updatedMem.info, "ref" -> proto.name))
}
+ case s => s map updateMemStmts(uniqueMems, memPortMap)
+ }
- val updatedMems = updateMemStmts(m.body)
- val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap)
- m.copy(body = updatedConns)
- }
-
- val updatedMods = c.modules map {
- case m: Module => updateMemMods(m)
- case m: ExtModule => m
- }
- c.copy(modules = updatedMods)
- }
+ def updateMemMods(m: DefModule) = {
+ val uniqueMems = new Memories
+ val memPortMap = new MemPortMap
+ (m map updateMemStmts(uniqueMems, memPortMap)
+ map updateStmtRefs(memPortMap))
+ }
+ def run(c: Circuit) = c copy (modules = (c.modules map updateMemMods))
}
// TODO: Module namespace?
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
index 7e3383b2..3af018bd 100644
--- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -38,7 +38,7 @@ class InferReadWriteSpec extends SimpleTransformSpec {
val name = "Check Infer ReadWrite Ports"
def findReadWrite(s: Statement): Boolean = s match {
case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 =>
- s.name == "mem" && s.readwriters.head == "rw_0"
+ s.name == "mem" && s.readwriters.head == "rw"
case s: Block =>
s.stmts exists findReadWrite
case _ => false
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 7219b1ce..8aeafc9e 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -107,7 +107,7 @@ circuit Top :
val circuit = InferTypes.run(ToWorkingIR.run(parse(input)))
val m = circuit.modules.head.asInstanceOf[ir.Module]
val connects = AnalysisUtils.getConnects(m)
- val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects,"f").serialize
+ val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects)("f").serialize
require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !")
}