aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala')
-rw-r--r--src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala148
1 files changed, 98 insertions, 50 deletions
diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
index 0098fa5f..fbff9bd6 100644
--- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
@@ -2,23 +2,73 @@
package firrtl.passes
-import scala.collection.mutable
-import AnalysisUtils._
-import MemTransformUtils._
-import firrtl.ir._
import firrtl._
-import firrtl.Mappers._
+import firrtl.ir._
import firrtl.Utils._
+import firrtl.Mappers._
+import AnalysisUtils._
+import MemPortUtils._
+import MemTransformUtils._
object MemTransformUtils {
+ def getFillWMask(mem: DefMemory) =
+ getInfo(mem.info, "maskGran") match {
+ case None => false
+ case Some(maskGran) => maskGran == 1
+ }
+
+ def rPortToBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ def rPortToFlattenBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+
+ def wPortToBundle(mem: DefMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++
+ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else Seq(Field("mask", Default, createMask(mem.dataType))))
+ )
+ def wPortToFlattenBundle(mem: DefMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++
+ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType)))
+ else Seq(Field("mask", Default, flattenType(createMask(mem.dataType)))))
+ )
+ // TODO: Don't use createMask???
+
+ def rwPortToBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, mem.dataType),
+ Field("rdata", Flip, mem.dataType)
+ ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else Seq(Field("wmask", Default, createMask(mem.dataType)))
+ )
+ )
+
+ def rwPortToFlattenBundle(mem: DefMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, flattenType(mem.dataType)),
+ Field("rdata", Flip, flattenType(mem.dataType))
+ ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil
+ else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType)))
+ else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType))))
+ )
+ )
+
+ def memToBundle(s: DefMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
+
+ def memToFlattenBundle(s: DefMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
- def createRef(n: String) = WRef(n, UnknownType, ExpKind, UNKNOWNGENDER)
- def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER)
- def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) =
- Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname))
def getMemPortMap(m: DefMemory) = {
- val memPortMap = mutable.HashMap[String, Expression]()
+ val memPortMap = new MemPortMap
val defaultFields = Seq("addr", "en", "clk")
val rFields = defaultFields :+ "data"
val wFields = rFields :+ "mask"
@@ -33,35 +83,41 @@ object MemTransformUtils {
updateMemPortMap(m.readers, rFields, "R")
updateMemPortMap(m.writers, wFields, "W")
updateMemPortMap(m.readwriters, rwFields, "RW")
- memPortMap.toMap
+ memPortMap
}
+
def createMemProto(m: DefMemory) = {
val rports = (0 until m.readers.length) map (i => s"R$i")
val wports = (0 until m.writers.length) map (i => s"W$i")
val rwports = (0 until m.readwriters.length) map (i => s"RW$i")
- m.copy(readers = rports, writers = wports, readwriters = rwports)
+ m copy (readers = rports, writers = wports, readwriters = rwports)
}
- def updateStmtRefs(s: Statement, repl: Map[String, Expression]): Statement = {
- def updateRef(e: Expression): Expression = e map updateRef match {
- case e => repl getOrElse (e.serialize, e)
+ def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
+ def updateRef(e: Expression): Expression = {
+ val ex = e map updateRef
+ repl getOrElse (ex.serialize, ex)
}
+
def hasEmptyExpr(stmt: Statement): Boolean = {
var foundEmpty = false
def testEmptyExpr(e: Expression): Expression = {
- e map testEmptyExpr match {
+ e match {
case EmptyExpression => foundEmpty = true
case _ =>
}
- e // map must return; no foreach
+ e map testEmptyExpr // map must return; no foreach
}
stmt map testEmptyExpr
foundEmpty
}
- def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match {
- case c: Connect if hasEmptyExpr(c) => EmptyStmt
- case s => s
- }
+
+ def updateStmtRefs(s: Statement): Statement =
+ s map updateStmtRefs map updateRef match {
+ case c: Connect if hasEmptyExpr(c) => EmptyStmt
+ case s => s
+ }
+
updateStmtRefs(s)
}
@@ -71,37 +127,29 @@ object UpdateDuplicateMemMacros extends Pass {
def name = "Convert memory port names to be more meaningful and tag duplicate memories"
- def run(c: Circuit) = {
- val uniqueMems = mutable.ArrayBuffer[DefMemory]()
-
- def updateMemMods(m: Module) = {
- val memPortMap = mutable.HashMap[String, Expression]()
-
- def updateMemStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") =>
- val updatedMem = createMemProto(m)
- memPortMap ++= getMemPortMap(m)
- val proto = uniqueMems find (x => eqMems(x, updatedMem))
- if (proto == None) {
- uniqueMems += updatedMem
- updatedMem
- }
- else updatedMem.copy(info = appendInfo(updatedMem.info, "ref" -> proto.get.name))
- case b: Block => b map updateMemStmts
- case s => s
+ def updateMemStmts(uniqueMems: Memories,
+ memPortMap: MemPortMap)
+ (s: Statement): Statement = s match {
+ case m: DefMemory if containsInfo(m.info, "useMacro") =>
+ val updatedMem = createMemProto(m)
+ memPortMap ++= getMemPortMap(m)
+ uniqueMems find (x => eqMems(x, updatedMem)) match {
+ case None =>
+ uniqueMems += updatedMem
+ updatedMem
+ case Some(proto) =>
+ updatedMem copy (info = appendInfo(updatedMem.info, "ref" -> proto.name))
}
+ case s => s map updateMemStmts(uniqueMems, memPortMap)
+ }
- val updatedMems = updateMemStmts(m.body)
- val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap)
- m.copy(body = updatedConns)
- }
-
- val updatedMods = c.modules map {
- case m: Module => updateMemMods(m)
- case m: ExtModule => m
- }
- c.copy(modules = updatedMods)
- }
+ def updateMemMods(m: DefModule) = {
+ val uniqueMems = new Memories
+ val memPortMap = new MemPortMap
+ (m map updateMemStmts(uniqueMems, memPortMap)
+ map updateStmtRefs(memPortMap))
+ }
+ def run(c: Circuit) = c copy (modules = (c.modules map updateMemMods))
}
// TODO: Module namespace?