aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu Kim2016-09-16 14:32:43 -0700
committerDonggyu Kim2016-09-21 13:19:09 -0700
commita142551bfcce6b05e445bc75dd284d994c8e91f2 (patch)
tree368a2db73034e411dc89a30d0b137bca3bcd3739 /src
parent6fede8c92edd414ba63ed185fbad2cc48fd29d01 (diff)
refactor ReplaceMemMacros
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/ReplaceMemMacros.scala228
-rw-r--r--src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala7
2 files changed, 97 insertions, 138 deletions
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
index 0ca1a32f..78211cca 100644
--- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
@@ -2,91 +2,36 @@
package firrtl.passes
-import scala.collection.mutable
-import firrtl.ir._
-import AnalysisUtils._
-import MemTransformUtils._
import firrtl._
+import firrtl.ir._
import firrtl.Utils._
-import MemPortUtils._
import firrtl.Mappers._
+import MemPortUtils._
+import MemTransformUtils._
+import AnalysisUtils._
class ReplaceMemMacros(writer: ConfWriter) extends Pass {
+ def name = "Replace memories with black box wrappers" +
+ " (optimizes when write mask isn't needed) + configuration file"
- def name = "Replace memories with black box wrappers (optimizes when write mask isn't needed) + configuration file"
-
- def run(c: Circuit) = {
-
- lazy val moduleNamespace = Namespace(c)
- val memMods = mutable.ArrayBuffer[DefModule]()
- val uniqueMems = mutable.ArrayBuffer[DefMemory]()
-
- def updateMemMods(m: Module) = {
- val memPortMap = new MemPortMap
-
- def updateMemStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info, "useMacro") =>
- if(!containsInfo(m.info, "maskGran")) {
- m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression }
- m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression }
- }
- val infoT = getInfo(m.info, "info")
- val info = if (infoT == None) NoInfo else infoT.get match { case i: Info => i }
- val ref = getInfo(m.info, "ref")
-
- // prototype mem
- if (ref == None) {
- val newWrapperName = moduleNamespace.newName(m.name)
- val newMemBBName = moduleNamespace.newName(m.name + "_ext")
- val newMem = m.copy(name = newMemBBName)
- memMods ++= createMemModule(newMem, newWrapperName)
- uniqueMems += newMem
- WDefInstance(info, m.name, newWrapperName, UnknownType)
- }
- else {
- val r = ref.get match { case s: String => s }
- WDefInstance(info, m.name, r, UnknownType)
- }
- case b: Block => b map updateMemStmts
- case s => s
- }
-
- val updatedMems = updateMemStmts(m.body)
- val updatedConns = updateStmtRefs(memPortMap)(updatedMems)
- m.copy(body = updatedConns)
- }
-
- val updatedMods = c.modules map {
- case m: Module => updateMemMods(m)
- case m: ExtModule => m
- }
-
- // print conf
- writer.serialize
- c.copy(modules = updatedMods ++ memMods.toSeq)
- }
// from Albert
def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = {
assert(m.dataType != UnknownType)
- val stmts = mutable.ArrayBuffer[Statement]()
- val wrapperioPorts = MemPortUtils.memToBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
- val bbProto = m.copy(dataType = flattenType(m.dataType))
- val bbioPorts = MemPortUtils.memToFlattenBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
-
- stmts += WDefInstance(NoInfo, m.name, m.name, UnknownType)
- val bbRef = createRef(m.name)
- stmts ++= (m.readers zip bbProto.readers).flatMap {
- case (x, y) => adaptReader(createRef(x), m, createSubField(bbRef, y), bbProto)
- }
- stmts ++= (m.writers zip bbProto.writers).flatMap {
- case (x, y) => adaptWriter(createRef(x), m, createSubField(bbRef, y), bbProto)
- }
- stmts ++= (m.readwriters zip bbProto.readwriters).flatMap {
- case (x, y) => adaptReadWriter(createRef(x), m, createSubField(bbRef, y), bbProto)
- }
- val wrapper = Module(NoInfo, wrapperName, wrapperioPorts, Block(stmts))
- val bb = ExtModule(NoInfo, m.name, bbioPorts)
+ val wrapperIoType = MemPortUtils.memToBundle(m)
+ val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbIoType = MemPortUtils.memToFlattenBundle(m)
+ val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbRef = createRef(m.name, bbIoType)
+ val hasMask = containsInfo(m.info, "maskGran")
+ val fillMask = getFillWMask(m)
+ def portRef(p: String) = createRef(p, field_type(wrapperIoType, p))
+ val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
+ (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++
+ (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++
+ (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask)))
+ val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
+ val bb = ExtModule(NoInfo, m.name, bbIoPorts)
// TODO: Annotate? -- use actual annotation map
// add to conf file
@@ -95,75 +40,86 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
}
// TODO: get rid of copy pasta
- def adaptReader(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = Seq(
- connectFields(bbPort, "addr", wrapperPort, "addr"),
- connectFields(bbPort, "en", wrapperPort, "en"),
- connectFields(bbPort, "clk", wrapperPort, "clk"),
- fromBits(
- WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER),
- WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER)
- )
- )
-
- def adaptWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = {
- val defaultSeq = Seq(
- connectFields(bbPort, "addr", wrapperPort, "addr"),
- connectFields(bbPort, "en", wrapperPort, "en"),
- connectFields(bbPort, "clk", wrapperPort, "clk"),
- Connect(
- NoInfo,
- WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER),
- toBits(WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER))
- )
- )
- if (containsInfo(wrapperMem.info, "maskGran")) {
- val wrapperMask = createMask(wrapperMem.dataType)
- val fillWMask = getFillWMask(wrapperMem)
- val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
- val rhs = {
- if (fillWMask) toBitMask(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType)
- else toBits(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER))
- }
- defaultSeq :+ Connect(
+ def defaultConnects(wrapperPort: WRef, bbPort: WSubField) =
+ Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))
+
+ def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean) =
+ if (fillMask) toBitMask(mask, dataType) else toBits(mask)
+
+ def adaptReader(wrapperPort: WRef, bbPort: WSubField) =
+ defaultConnects(wrapperPort, bbPort) :+
+ fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data"))
+
+ def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
+ val wrapperData = createSubField(wrapperPort, "data")
+ val defaultSeq = defaultConnects(wrapperPort, bbPort) :+
+ Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData))
+ hasMask match {
+ case false => defaultSeq
+ case true => defaultSeq :+ Connect(
NoInfo,
- WSubField(bbPort, "mask", bbMask, UNKNOWNGENDER),
- rhs
+ createSubField(bbPort, "mask"),
+ maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
)
}
- else defaultSeq
}
- def adaptReadWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = {
- val defaultSeq = Seq(
- connectFields(bbPort, "addr", wrapperPort, "addr"),
- connectFields(bbPort, "en", wrapperPort, "en"),
- connectFields(bbPort, "clk", wrapperPort, "clk"),
- connectFields(bbPort, "wmode", wrapperPort, "wmode"),
- Connect(
- NoInfo,
- WSubField(bbPort, "wdata", bbMem.dataType, UNKNOWNGENDER),
- toBits(WSubField(wrapperPort, "wdata", wrapperMem.dataType, UNKNOWNGENDER))
- ),
- fromBits(
- WSubField(wrapperPort, "rdata", wrapperMem.dataType, UNKNOWNGENDER),
- WSubField(bbPort, "rdata", bbMem.dataType, UNKNOWNGENDER)
- )
- )
- if (containsInfo(wrapperMem.info, "maskGran")) {
- val wrapperMask = createMask(wrapperMem.dataType)
- val fillWMask = getFillWMask(wrapperMem)
- val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
- val rhs = {
- if (fillWMask) toBitMask(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType)
- else toBits(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER))
- }
- defaultSeq :+ Connect(
+ def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = {
+ val wrapperWData = createSubField(wrapperPort, "wdata")
+ val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
+ fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")),
+ connectFields(bbPort, "wmode", wrapperPort, "wmode"),
+ Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData)))
+ hasMask match {
+ case false => defaultSeq
+ case true => defaultSeq :+ Connect(
NoInfo,
- WSubField(bbPort, "wmask", bbMask, UNKNOWNGENDER),
- rhs
+ createSubField(bbPort, "wmask"),
+ maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
)
}
- else defaultSeq
}
+ def updateMemStmts(namespace: Namespace,
+ memPortMap: MemPortMap,
+ memMods: Modules)
+ (s: Statement): Statement = s match {
+ case m: DefMemory if containsInfo(m.info, "useMacro") =>
+ if (!containsInfo(m.info, "maskGran")) {
+ m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression }
+ m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression }
+ }
+ val info = getInfo(m.info, "info") match {
+ case None => NoInfo
+ case Some(p: Info) => p
+ }
+ getInfo(m.info, "ref") match {
+ case None =>
+ // prototype mem
+ val newWrapperName = namespace newName m.name
+ val newMemBBName = namespace newName s"${m.name}_ext"
+ val newMem = m copy (name = newMemBBName)
+ memMods ++= createMemModule(newMem, newWrapperName)
+ WDefInstance(info, m.name, newWrapperName, UnknownType)
+ case Some(ref: String) =>
+ WDefInstance(info, m.name, ref, UnknownType)
+ }
+ case s => s map updateMemStmts(namespace, memPortMap, memMods)
+ }
+
+ def updateMemMods(namespace: Namespace, memMods: Modules)(m: DefModule) = {
+ val memPortMap = new MemPortMap
+
+ (m map updateMemStmts(namespace, memPortMap, memMods)
+ map updateStmtRefs(memPortMap))
+ }
+
+ def run(c: Circuit) = {
+ val namespace = Namespace(c)
+ val memMods = new Modules
+ val modules = c.modules map updateMemMods(namespace, memMods)
+ // print conf
+ writer.serialize
+ c copy (modules = modules ++ memMods)
+ }
}
diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
index 098d83f0..0a685c3c 100644
--- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
@@ -13,9 +13,12 @@ object MemTransformUtils {
type MemPortMap = collection.mutable.HashMap[String, Expression]
type Memories = collection.mutable.ArrayBuffer[DefMemory]
+ type Modules = collection.mutable.ArrayBuffer[DefModule]
- def createRef(n: String) = WRef(n, UnknownType, ExpKind, UNKNOWNGENDER)
- def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER)
+ def createRef(n: String, t: Type = UnknownType, k: Kind = ExpKind) =
+ WRef(n, t, k, UNKNOWNGENDER)
+ def createSubField(exp: Expression, n: String) =
+ WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER)
def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) =
Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname))