diff options
| author | Adam Izraelevitz | 2016-10-17 18:53:19 -0700 |
|---|---|---|
| committer | Angie Wang | 2016-10-17 18:53:19 -0700 |
| commit | 85baeda249e59c7d9d9f159aaf29ff46d685cf02 (patch) | |
| tree | cfb5f4a6a0a80f9033275de6e5e36b9d5b96faad /src/main/scala/firrtl/passes/memlib | |
| parent | 7d08b9a1486fef0459481f6e542464a29fbe1db5 (diff) | |
Reorganized memory blackboxing (#336)
* Reorganized memory blackboxing
Moved to new package memlib
Added comments
Moved utility functions around
Removed unused AnnotateValidMemConfigs.scala
* Fixed tests to pass
* Use DefAnnotatedMemory instead of AppendableInfo
* Broke passes up into simpler passes
AnnotateMemMacros ->
(ToMemIR, ResolveMaskGranularity)
UpdateDuplicateMemMacros ->
(RenameAnnotatedMemoryPorts, ResolveMemoryReference)
* Fixed to make tests run
* Minor changes from code review
* Removed vim comments and renamed ReplSeqMem
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib')
9 files changed, 710 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala new file mode 100644 index 00000000..6dca5961 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import Utils.indent + +case class DefAnnotatedMemory( + info: Info, + name: String, + dataType: Type, + depth: Int, + writeLatency: Int, + readLatency: Int, + readers: Seq[String], + writers: Seq[String], + readwriters: Seq[String], + readUnderWrite: Option[String], + maskGran: Option[BigInt], + memRef: Option[String] + //pins: Seq[Pin], + ) extends Statement with IsDeclaration { + def serialize: String = this.toMem.serialize + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def toMem = DefMemory(info, name, dataType, depth, + writeLatency, readLatency, readers, writers, + readwriters, readUnderWrite) +} diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala new file mode 100644 index 00000000..78a386b2 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala @@ -0,0 +1,47 @@ +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import AnalysisUtils._ +import MemPortUtils.{MemPortMap} + +object MemTransformUtils { + + /** Replaces references to old memory port names with new memory port names + */ + def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { + //TODO(izraelevitz): check speed + def updateRef(e: Expression): Expression = { + val ex = e map updateRef + repl getOrElse (ex.serialize, ex) + } + + def hasEmptyExpr(stmt: Statement): Boolean = { + var foundEmpty = false + def testEmptyExpr(e: Expression): Expression = { + e match { + case EmptyExpression => foundEmpty = true + case _ => + } + e map testEmptyExpr // map must return; no foreach + } + stmt map testEmptyExpr + foundEmpty + } + + def updateStmtRefs(s: Statement): Statement = + s map updateStmtRefs map updateRef match { + case c: Connect if hasEmptyExpr(c) => EmptyStmt + case s => s + } + + updateStmtRefs(s) + } + + def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem) + def memPortField(s: DefAnnotatedMemory, p: String, f: String): Expression = + MemPortUtils.memPortField(s.toMem, p, f) +} diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala new file mode 100644 index 00000000..168a6a48 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -0,0 +1,77 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import AnalysisUtils._ +import MemPortUtils._ +import MemTransformUtils._ + + +/** Changes memory port names to standard port names (i.e. RW0 instead T_408) + */ +object RenameAnnotatedMemoryPorts extends Pass { + + def name = "Rename Annotated Memory Ports" + + /** Renames memory ports to a standard naming scheme: + * - R0, R1, ... for each read port + * - W0, W1, ... for each write port + * - RW0, RW1, ... for each readwrite port + */ + def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = { + val rports = m.readers.indices map (i => s"R$i") + val wports = m.writers.indices map (i => s"W$i") + val rwports = m.readwriters.indices map (i => s"RW$i") + m copy (readers = rports, writers = wports, readwriters = rwports) + } + + /** Maps the serialized form of all memory port field names to the + * corresponding new memory port field Expression. + * E.g.: + * - ("m.read.addr") becomes (m.R0.addr) + */ + def getMemPortMap(m: DefAnnotatedMemory): MemPortMap = { + val memPortMap = new MemPortMap + val defaultFields = Seq("addr", "en", "clk") + val rFields = defaultFields :+ "data" + val wFields = rFields :+ "mask" + val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask") + + def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit = + for ((p, i) <- ports.zipWithIndex; f <- fields) { + val newPort = createSubField(createRef(m.name), newPortKind + i) + val field = createSubField(newPort, f) + memPortMap(s"${m.name}.$p.$f") = field + } + updateMemPortMap(m.readers, rFields, "R") + updateMemPortMap(m.writers, wFields, "W") + updateMemPortMap(m.readwriters, rwFields, "RW") + memPortMap + } + + /** Replaces candidate memories with memories with standard port names + * Does not update the references (this is done via updateStmtRefs) + */ + def updateMemStmts(memPortMap: MemPortMap)(s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + val updatedMem = createMemProto(m) + memPortMap ++= getMemPortMap(m) + updatedMem + case s => s map updateMemStmts(memPortMap) + } + + /** Replaces candidate memories and their references with standard port names + */ + def updateMemMods(m: DefModule) = { + val memPortMap = new MemPortMap + (m map updateMemStmts(memPortMap) + map updateStmtRefs(memPortMap)) + } + + def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) +} diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala new file mode 100644 index 00000000..3139ef21 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -0,0 +1,188 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import MemPortUtils.{MemPortMap, Modules} +import MemTransformUtils._ +import AnalysisUtils._ + +/** Replace DefAnnotatedMemory with memory blackbox + wrapper + conf file. + * This will not generate wmask ports if not needed. + * Creates the minimum # of black boxes needed by the design. + */ +class ReplaceMemMacros(writer: ConfWriter) extends Pass { + def name = "Replace Memory Macros" + + /** Return true if mask granularity is per bit, false if per byte or unspecified + */ + private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match { + case None => false + case Some(v) => v == 1 + } + + private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + + private def wPortToBundle(mem: DefAnnotatedMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ (mem.maskGran match { + case None => Nil + case Some(_) => Seq(Field("mask", Default, createMask(mem.dataType))) + }) + ) + private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match { + case None => Nil + case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType))) + case Some(_) => Seq(Field("mask", Default, flattenType(createMask(mem.dataType)))) + }) + ) + // TODO(shunshou): Don't use createMask??? + + private def rwPortToBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), + Field("wdata", Default, mem.dataType), + Field("rdata", Flip, mem.dataType) + ) ++ (mem.maskGran match { + case None => Nil + case Some(_) => Seq(Field("wmask", Default, createMask(mem.dataType))) + }) + ) + private def rwPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), + Field("wdata", Default, flattenType(mem.dataType)), + Field("rdata", Flip, flattenType(mem.dataType)) + ) ++ (mem.maskGran match { + case None => Nil + case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType))) + case Some(_) => Seq(Field("wmask", Default, flattenType(createMask(mem.dataType)))) + }) + ) + + def memToBundle(s: DefAnnotatedMemory) = BundleType( + s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) + def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType( + s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) + + /** Creates a wrapper module and external module to replace a candidate memory + * The wrapper module has the same type as the memory it replaces + * The external module + */ + def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = { + assert(m.dataType != UnknownType) + val wrapperIoType = memToBundle(m) + val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + // Creates a type with the write/readwrite masks omitted if necessary + val bbIoType = memToFlattenBundle(m) + val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbRef = createRef(m.name, bbIoType) + val hasMask = m.maskGran.isDefined + val fillMask = getFillWMask(m) + def portRef(p: String) = createRef(p, field_type(wrapperIoType, p)) + val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ + (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++ + (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++ + (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask))) + val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts)) + val bb = ExtModule(NoInfo, m.name, bbIoPorts) + // TODO: Annotate? -- use actual annotation map + + // add to conf file + writer.append(m) + Seq(bb, wrapper) + } + + // TODO(shunshou): get rid of copy pasta + // Connects the clk, en, and addr fields from the wrapperPort to the bbPort + def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = + Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) + + // Connects the clk, en, and addr fields from the wrapperPort to the bbPort + def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = + if (fillMask) toBitMask(mask, dataType) else toBits(mask) + + def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = + defaultConnects(wrapperPort, bbPort) :+ + fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data")) + + def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { + val wrapperData = createSubField(wrapperPort, "data") + val defaultSeq = defaultConnects(wrapperPort, bbPort) :+ + Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData)) + hasMask match { + case false => defaultSeq + case true => defaultSeq :+ Connect( + NoInfo, + createSubField(bbPort, "mask"), + maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) + ) + } + } + + def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { + val wrapperWData = createSubField(wrapperPort, "wdata") + val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( + fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")), + connectFields(bbPort, "wmode", wrapperPort, "wmode"), + Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData))) + hasMask match { + case false => defaultSeq + case true => defaultSeq :+ Connect( + NoInfo, + createSubField(bbPort, "wmask"), + maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) + ) + } + } + + def updateMemStmts(namespace: Namespace, + memPortMap: MemPortMap, + memMods: Modules) + (s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + if (m.maskGran.isEmpty) { + m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } + m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } + } + m.memRef match { + case None => + // prototype mem + val newWrapperName = namespace newName m.name + val newMemBBName = namespace newName s"${m.name}_ext" + val newMem = m copy (name = newMemBBName) + memMods ++= createMemModule(newMem, newWrapperName) + WDefInstance(m.info, m.name, newWrapperName, UnknownType) + case Some(ref: String) => + WDefInstance(m.info, m.name, ref, UnknownType) + } + case sx => sx map updateMemStmts(namespace, memPortMap, memMods) + } + + def updateMemMods(namespace: Namespace, memMods: Modules)(m: DefModule) = { + val memPortMap = new MemPortMap + + (m map updateMemStmts(namespace, memPortMap, memMods) + map updateStmtRefs(memPortMap)) + } + + def run(c: Circuit) = { + val namespace = Namespace(c) + val memMods = new Modules + val modules = c.modules map updateMemMods(namespace, memMods) + // print conf + writer.serialize() + c copy (modules = modules ++ memMods) + } +} diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala new file mode 100644 index 00000000..dfa828c9 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -0,0 +1,126 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import Annotations._ +import AnalysisUtils._ +import Utils.error +import java.io.{File, CharArrayWriter, PrintWriter} + +sealed trait PassOption +case object InputConfigFileName extends PassOption +case object OutputConfigFileName extends PassOption +case object PassCircuitName extends PassOption + +object PassConfigUtil { + type PassOptionMap = Map[PassOption, String] + + def getPassOptions(t: String, usage: String = "") = { + // can't use space to delimit sub arguments (otherwise, Driver.scala will throw error) + val passArgList = t.split(":").toList + + def nextPassOption(map: PassOptionMap, list: List[String]): PassOptionMap = { + list match { + case Nil => map + case "-i" :: value :: tail => + nextPassOption(map + (InputConfigFileName -> value), tail) + case "-o" :: value :: tail => + nextPassOption(map + (OutputConfigFileName -> value), tail) + case "-c" :: value :: tail => + nextPassOption(map + (PassCircuitName -> value), tail) + case option :: tail => + error("Unknown option " + option + usage) + } + } + nextPassOption(Map[PassOption, String](), passArgList) + } +} + +class ConfWriter(filename: String) { + val outputBuffer = new CharArrayWriter + def append(m: DefAnnotatedMemory) = { + // legacy + val maskGran = m.maskGran + val readers = List.fill(m.readers.length)("read") + val writers = List.fill(m.writers.length)(if (maskGran.isEmpty) "write" else "mwrite") + val readwriters = List.fill(m.readwriters.length)(if (maskGran.isEmpty) "rw" else "mrw") + val ports = (writers ++ readers ++ readwriters) mkString "," + val maskGranConf = maskGran match { case None => "" case Some(p) => s"mask_gran $p" } + val width = bitWidth(m.dataType) + val conf = s"name ${m.name} depth ${m.depth} width $width ports $ports $maskGranConf \n" + outputBuffer.append(conf) + } + def serialize() = { + val outputFile = new PrintWriter(filename) + outputFile.write(outputBuffer.toString) + outputFile.close() + } +} + +case class ReplSeqMemAnnotation(t: String, tID: TransID) + extends Annotation with Loose with Unstable { + + val usage = """ +[Optional] ReplSeqMem + Pass to replace sequential memories with blackboxes + configuration file + +Usage: + --replSeqMem -c:<circuit>:-i:<filename>:-o:<filename> + *** Note: sub-arguments to --replSeqMem should be delimited by : and not white space! + +Required Arguments: + -o<filename> Specify the output configuration file + -c<compiler> Specify the target circuit + +Optional Arguments: + -i<filename> Specify the input configuration file (for additional optimizations) +""" + + val passOptions = PassConfigUtil.getPassOptions(t, usage) + val outputConfig = passOptions.getOrElse( + OutputConfigFileName, + error("No output config file provided for ReplSeqMem!" + usage) + ) + val passCircuit = passOptions.getOrElse( + PassCircuitName, + error("No circuit name specified for ReplSeqMem!" + usage) + ) + val target = CircuitName(passCircuit) + def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}")) +} + +class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { + def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) = + Seq(Legalize, + ToMemIR, + ResolveMaskGranularity, + RenameAnnotatedMemoryPorts, + ResolveMemoryReference, + //new AnnotateValidMemConfigs(inConfigFile), + new ReplaceMemMacros(outConfigFile), + RemoveEmpty, + CheckInitialization, + InferTypes, + Uniquify, + ResolveKinds, // Must be run for the transform to work! + ResolveGenders) + + def execute(c: Circuit, map: AnnotationMap) = map get transID match { + case Some(p) => p get CircuitName(c.main) match { + case Some(ReplSeqMemAnnotation(t, _)) => + val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) + run(c, passSeq(inConfigFile, outConfigFile)) + case _ => error("Unexpected transform annotation") + } + case _ => TransformResult(c) + } +} diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala new file mode 100644 index 00000000..a8ff9fe3 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -0,0 +1,124 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import WrappedExpression.weq +import AnalysisUtils._ +import MemTransformUtils._ + +object AnalysisUtils { + type Connects = collection.mutable.HashMap[String, Expression] + + /** Builds a map from named component to assigned value + * Named components are serialized LHS of connections, nodes, invalids + */ + def getConnects(m: DefModule): Connects = { + def getConnects(connects: Connects)(s: Statement): Statement = { + s match { + case Connect(_, loc, expr) => + connects(loc.serialize) = expr + case DefNode(_, name, value) => + connects(name) = value + case IsInvalid(_, value) => + connects(value.serialize) = WInvalid + case _ => // do nothing + } + s map getConnects(connects) + } + val connects = new Connects + m map getConnects(connects) + connects + } + + /** Find a connection LHS's origin from a module's list of node-to-node connections + * regardless of whether constant propagation has been run. + * Will search past trivial primop/mux's which do not affect its origin. + * Limitations: + * - Only works in a module (stops @ module inputs) + * - Only does trivial primop/mux's (is not complete) + * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a) + */ + def getOrigin(connects: Connects, s: String): Expression = + getOrigin(connects)(WRef(s, UnknownType, ExpKind, UNKNOWNGENDER)) + def getOrigin(connects: Connects)(e: Expression): Expression = e match { + case Mux(cond, tv, fv, _) => + val fvOrigin = getOrigin(connects)(fv) + val tvOrigin = getOrigin(connects)(tv) + val condOrigin = getOrigin(connects)(cond) + if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin + else if (weq(condOrigin, one)) tvOrigin + else if (weq(condOrigin, zero)) fvOrigin + else if (weq(tvOrigin, fvOrigin)) tvOrigin + else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin + else e + case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero + case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => + val extractionWidth = (msb - lsb) + 1 + val nodeWidth = bitWidth(args.head.tpe) + // if you're extracting the full bitwidth, then keep searching for origin + if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e + case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => + getOrigin(connects)(args.head) + case ValidIf(cond, value, ClockType) => getOrigin(connects)(value) + // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => + connects get e.serialize match { + case Some(ex) => getOrigin(connects)(ex) + case None => e + } + case _ => e + } + + /** Checks whether the two memories are equivalent in all respects except name + */ + def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) = a == b.copy(name = a.name) +} + +/** Determines if a write mask is needed (wmode/en and wmask are equivalent). + * Populates the maskGran field of DefAnnotatedMemory + * Annotations: + * - maskGran = (dataType size) / (number of mask bits) + * - i.e. 1 if bitmask, 8 if bytemask, absent for no mask + * TODO(shunshou): Add floorplan info? + */ +object ResolveMaskGranularity extends Pass { + def name = "Resolve Mask Granularity" + + /** Returns the number of mask bits, if used + */ + def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { + val wenOrigin = getOrigin(connects)(wen) + val wmaskOrigin = connects.keys filter + (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)} + // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) + val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + if (redundantMask) None else Some(wmaskOrigin.size) + } + + /** Only annotate memories that are candidates for memory macro replacements + * i.e. rw, w + r (read, write 1 cycle delay) + */ + def updateStmts(connects: Connects)(s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + val dataBits = bitWidth(m.dataType) + val rwMasks = m.readwriters map (rw => + getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers map (w => + getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val maskGran = (rwMasks ++ wMasks).head match { + case None => None + case Some(maskBits) => Some(dataBits / maskBits) + } + m.copy(maskGran = maskGran) + case sx => sx map updateStmts(connects) + } + + def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m)) + def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) +} diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala new file mode 100644 index 00000000..783c179f --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -0,0 +1,38 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib +import firrtl.ir._ +import AnalysisUtils.eqMems +import firrtl.Mappers._ + + +/** Resolves annotation ref to memories that exactly match (except name) another memory + */ +object ResolveMemoryReference extends Pass { + + def name = "Resolve Memory Reference" + + type AnnotatedMemories = collection.mutable.ArrayBuffer[DefAnnotatedMemory] + + /** If a candidate memory is identical except for name to another, add an + * annotation that references the name of the other memory. + */ + def updateMemStmts(uniqueMems: AnnotatedMemories)(s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + uniqueMems find (x => eqMems(x, m)) match { + case None => + uniqueMems += m + m + case Some(proto) => m copy (memRef = Some(proto.name)) + } + case s => s map updateMemStmts(uniqueMems) + } + + def updateMemMods(m: DefModule) = { + val uniqueMems = new AnnotatedMemories + (m map updateMemStmts(uniqueMems)) + } + + def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) +} diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala new file mode 100644 index 00000000..741ea5ef --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -0,0 +1,41 @@ +package firrtl.passes +package memlib + +import firrtl.Mappers._ +import firrtl.ir._ + +/** Annotates sequential memories that are candidates for macro replacement. + * Requirements for macro replacement: + * - read latency and write latency of one + * - only one readwrite port or write port + * - zero or one read port + */ +object ToMemIR extends Pass { + def name = "To Memory IR" + + /** Only annotate memories that are candidates for memory macro replacements + * i.e. rw, w + r (read, write 1 cycle delay) + */ + def updateStmts(s: Statement): Statement = s match { + case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && + (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => + DefAnnotatedMemory( + m.info, + m.name, + m.dataType, + m.depth, + m.writeLatency, + m.readLatency, + m.readers, + m.writers, + m.readwriters, + m.readUnderWrite, + None, // mask granularity annotation + None // No reference yet to another memory + ) + case sx => sx map updateStmts + } + + def annotateModMems(m: DefModule) = m map updateStmts + def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) +} diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala new file mode 100644 index 00000000..a1088300 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala @@ -0,0 +1,36 @@ +package firrtl.passes +package memlib +import net.jcazevedo.moultingyaml._ +import java.io.{File, CharArrayWriter, PrintWriter} + +object CustomYAMLProtocol extends DefaultYamlProtocol { + // bottom depends on top +} + +class YamlFileReader(file: String) { + import CustomYAMLProtocol._ + def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { + if (new File(file).exists) { + val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n") + yamlString.parseYamls flatMap (x => + try Some(reader read x) + catch { case e: Exception => None } + ) + } + else error("Yaml file doesn't exist!") + } +} + +class YamlFileWriter(file: String) { + import CustomYAMLProtocol._ + val outputBuffer = new CharArrayWriter + val separator = "--- \n" + def append(in: YamlValue) { + outputBuffer append s"$separator${in.prettyPrint}" + } + def dump() { + val outputFile = new PrintWriter(file) + outputFile write outputBuffer.toString + outputFile.close() + } +} |
