diff options
| author | Donggyu Kim | 2016-09-16 14:32:43 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-21 13:19:09 -0700 |
| commit | a142551bfcce6b05e445bc75dd284d994c8e91f2 (patch) | |
| tree | 368a2db73034e411dc89a30d0b137bca3bcd3739 /src | |
| parent | 6fede8c92edd414ba63ed185fbad2cc48fd29d01 (diff) | |
refactor ReplaceMemMacros
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceMemMacros.scala | 228 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala | 7 |
2 files changed, 97 insertions, 138 deletions
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala index 0ca1a32f..78211cca 100644 --- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala @@ -2,91 +2,36 @@ package firrtl.passes -import scala.collection.mutable -import firrtl.ir._ -import AnalysisUtils._ -import MemTransformUtils._ import firrtl._ +import firrtl.ir._ import firrtl.Utils._ -import MemPortUtils._ import firrtl.Mappers._ +import MemPortUtils._ +import MemTransformUtils._ +import AnalysisUtils._ class ReplaceMemMacros(writer: ConfWriter) extends Pass { + def name = "Replace memories with black box wrappers" + + " (optimizes when write mask isn't needed) + configuration file" - def name = "Replace memories with black box wrappers (optimizes when write mask isn't needed) + configuration file" - - def run(c: Circuit) = { - - lazy val moduleNamespace = Namespace(c) - val memMods = mutable.ArrayBuffer[DefModule]() - val uniqueMems = mutable.ArrayBuffer[DefMemory]() - - def updateMemMods(m: Module) = { - val memPortMap = new MemPortMap - - def updateMemStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => - if(!containsInfo(m.info, "maskGran")) { - m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression } - m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression } - } - val infoT = getInfo(m.info, "info") - val info = if (infoT == None) NoInfo else infoT.get match { case i: Info => i } - val ref = getInfo(m.info, "ref") - - // prototype mem - if (ref == None) { - val newWrapperName = moduleNamespace.newName(m.name) - val newMemBBName = moduleNamespace.newName(m.name + "_ext") - val newMem = m.copy(name = newMemBBName) - memMods ++= createMemModule(newMem, newWrapperName) - uniqueMems += newMem - WDefInstance(info, m.name, newWrapperName, UnknownType) - } - else { - val r = ref.get match { case s: String => s } - WDefInstance(info, m.name, r, UnknownType) - } - case b: Block => b map updateMemStmts - case s => s - } - - val updatedMems = updateMemStmts(m.body) - val updatedConns = updateStmtRefs(memPortMap)(updatedMems) - m.copy(body = updatedConns) - } - - val updatedMods = c.modules map { - case m: Module => updateMemMods(m) - case m: ExtModule => m - } - - // print conf - writer.serialize - c.copy(modules = updatedMods ++ memMods.toSeq) - } // from Albert def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = { assert(m.dataType != UnknownType) - val stmts = mutable.ArrayBuffer[Statement]() - val wrapperioPorts = MemPortUtils.memToBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) - val bbProto = m.copy(dataType = flattenType(m.dataType)) - val bbioPorts = MemPortUtils.memToFlattenBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) - - stmts += WDefInstance(NoInfo, m.name, m.name, UnknownType) - val bbRef = createRef(m.name) - stmts ++= (m.readers zip bbProto.readers).flatMap { - case (x, y) => adaptReader(createRef(x), m, createSubField(bbRef, y), bbProto) - } - stmts ++= (m.writers zip bbProto.writers).flatMap { - case (x, y) => adaptWriter(createRef(x), m, createSubField(bbRef, y), bbProto) - } - stmts ++= (m.readwriters zip bbProto.readwriters).flatMap { - case (x, y) => adaptReadWriter(createRef(x), m, createSubField(bbRef, y), bbProto) - } - val wrapper = Module(NoInfo, wrapperName, wrapperioPorts, Block(stmts)) - val bb = ExtModule(NoInfo, m.name, bbioPorts) + val wrapperIoType = MemPortUtils.memToBundle(m) + val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbIoType = MemPortUtils.memToFlattenBundle(m) + val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbRef = createRef(m.name, bbIoType) + val hasMask = containsInfo(m.info, "maskGran") + val fillMask = getFillWMask(m) + def portRef(p: String) = createRef(p, field_type(wrapperIoType, p)) + val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ + (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++ + (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++ + (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask))) + val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts)) + val bb = ExtModule(NoInfo, m.name, bbIoPorts) // TODO: Annotate? -- use actual annotation map // add to conf file @@ -95,75 +40,86 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { } // TODO: get rid of copy pasta - def adaptReader(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = Seq( - connectFields(bbPort, "addr", wrapperPort, "addr"), - connectFields(bbPort, "en", wrapperPort, "en"), - connectFields(bbPort, "clk", wrapperPort, "clk"), - fromBits( - WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER), - WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER) - ) - ) - - def adaptWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = { - val defaultSeq = Seq( - connectFields(bbPort, "addr", wrapperPort, "addr"), - connectFields(bbPort, "en", wrapperPort, "en"), - connectFields(bbPort, "clk", wrapperPort, "clk"), - Connect( - NoInfo, - WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER), - toBits(WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER)) - ) - ) - if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = createMask(wrapperMem.dataType) - val fillWMask = getFillWMask(wrapperMem) - val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) - val rhs = { - if (fillWMask) toBitMask(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType) - else toBits(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER)) - } - defaultSeq :+ Connect( + def defaultConnects(wrapperPort: WRef, bbPort: WSubField) = + Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) + + def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean) = + if (fillMask) toBitMask(mask, dataType) else toBits(mask) + + def adaptReader(wrapperPort: WRef, bbPort: WSubField) = + defaultConnects(wrapperPort, bbPort) :+ + fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data")) + + def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = { + val wrapperData = createSubField(wrapperPort, "data") + val defaultSeq = defaultConnects(wrapperPort, bbPort) :+ + Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData)) + hasMask match { + case false => defaultSeq + case true => defaultSeq :+ Connect( NoInfo, - WSubField(bbPort, "mask", bbMask, UNKNOWNGENDER), - rhs + createSubField(bbPort, "mask"), + maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) ) } - else defaultSeq } - def adaptReadWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = { - val defaultSeq = Seq( - connectFields(bbPort, "addr", wrapperPort, "addr"), - connectFields(bbPort, "en", wrapperPort, "en"), - connectFields(bbPort, "clk", wrapperPort, "clk"), - connectFields(bbPort, "wmode", wrapperPort, "wmode"), - Connect( - NoInfo, - WSubField(bbPort, "wdata", bbMem.dataType, UNKNOWNGENDER), - toBits(WSubField(wrapperPort, "wdata", wrapperMem.dataType, UNKNOWNGENDER)) - ), - fromBits( - WSubField(wrapperPort, "rdata", wrapperMem.dataType, UNKNOWNGENDER), - WSubField(bbPort, "rdata", bbMem.dataType, UNKNOWNGENDER) - ) - ) - if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = createMask(wrapperMem.dataType) - val fillWMask = getFillWMask(wrapperMem) - val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) - val rhs = { - if (fillWMask) toBitMask(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType) - else toBits(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER)) - } - defaultSeq :+ Connect( + def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = { + val wrapperWData = createSubField(wrapperPort, "wdata") + val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( + fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")), + connectFields(bbPort, "wmode", wrapperPort, "wmode"), + Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData))) + hasMask match { + case false => defaultSeq + case true => defaultSeq :+ Connect( NoInfo, - WSubField(bbPort, "wmask", bbMask, UNKNOWNGENDER), - rhs + createSubField(bbPort, "wmask"), + maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) ) } - else defaultSeq } + def updateMemStmts(namespace: Namespace, + memPortMap: MemPortMap, + memMods: Modules) + (s: Statement): Statement = s match { + case m: DefMemory if containsInfo(m.info, "useMacro") => + if (!containsInfo(m.info, "maskGran")) { + m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression } + m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression } + } + val info = getInfo(m.info, "info") match { + case None => NoInfo + case Some(p: Info) => p + } + getInfo(m.info, "ref") match { + case None => + // prototype mem + val newWrapperName = namespace newName m.name + val newMemBBName = namespace newName s"${m.name}_ext" + val newMem = m copy (name = newMemBBName) + memMods ++= createMemModule(newMem, newWrapperName) + WDefInstance(info, m.name, newWrapperName, UnknownType) + case Some(ref: String) => + WDefInstance(info, m.name, ref, UnknownType) + } + case s => s map updateMemStmts(namespace, memPortMap, memMods) + } + + def updateMemMods(namespace: Namespace, memMods: Modules)(m: DefModule) = { + val memPortMap = new MemPortMap + + (m map updateMemStmts(namespace, memPortMap, memMods) + map updateStmtRefs(memPortMap)) + } + + def run(c: Circuit) = { + val namespace = Namespace(c) + val memMods = new Modules + val modules = c.modules map updateMemMods(namespace, memMods) + // print conf + writer.serialize + c copy (modules = modules ++ memMods) + } } diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala index 098d83f0..0a685c3c 100644 --- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala +++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala @@ -13,9 +13,12 @@ object MemTransformUtils { type MemPortMap = collection.mutable.HashMap[String, Expression] type Memories = collection.mutable.ArrayBuffer[DefMemory] + type Modules = collection.mutable.ArrayBuffer[DefModule] - def createRef(n: String) = WRef(n, UnknownType, ExpKind, UNKNOWNGENDER) - def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER) + def createRef(n: String, t: Type = UnknownType, k: Kind = ExpKind) = + WRef(n, t, k, UNKNOWNGENDER) + def createSubField(exp: Expression, n: String) = + WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER) def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) = Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) |
