aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib')
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala33
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala47
-rw-r--r--src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala77
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala188
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala126
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala124
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala38
-rw-r--r--src/main/scala/firrtl/passes/memlib/ToMemIR.scala41
-rw-r--r--src/main/scala/firrtl/passes/memlib/YamlUtils.scala36
9 files changed, 710 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
new file mode 100644
index 00000000..6dca5961
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -0,0 +1,33 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import Utils.indent
+
+case class DefAnnotatedMemory(
+ info: Info,
+ name: String,
+ dataType: Type,
+ depth: Int,
+ writeLatency: Int,
+ readLatency: Int,
+ readers: Seq[String],
+ writers: Seq[String],
+ readwriters: Seq[String],
+ readUnderWrite: Option[String],
+ maskGran: Option[BigInt],
+ memRef: Option[String]
+ //pins: Seq[Pin],
+ ) extends Statement with IsDeclaration {
+ def serialize: String = this.toMem.serialize
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def toMem = DefMemory(info, name, dataType, depth,
+ writeLatency, readLatency, readers, writers,
+ readwriters, readUnderWrite)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
new file mode 100644
index 00000000..78a386b2
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
@@ -0,0 +1,47 @@
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import AnalysisUtils._
+import MemPortUtils.{MemPortMap}
+
+object MemTransformUtils {
+
+ /** Replaces references to old memory port names with new memory port names
+ */
+ def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
+ //TODO(izraelevitz): check speed
+ def updateRef(e: Expression): Expression = {
+ val ex = e map updateRef
+ repl getOrElse (ex.serialize, ex)
+ }
+
+ def hasEmptyExpr(stmt: Statement): Boolean = {
+ var foundEmpty = false
+ def testEmptyExpr(e: Expression): Expression = {
+ e match {
+ case EmptyExpression => foundEmpty = true
+ case _ =>
+ }
+ e map testEmptyExpr // map must return; no foreach
+ }
+ stmt map testEmptyExpr
+ foundEmpty
+ }
+
+ def updateStmtRefs(s: Statement): Statement =
+ s map updateStmtRefs map updateRef match {
+ case c: Connect if hasEmptyExpr(c) => EmptyStmt
+ case s => s
+ }
+
+ updateStmtRefs(s)
+ }
+
+ def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem)
+ def memPortField(s: DefAnnotatedMemory, p: String, f: String): Expression =
+ MemPortUtils.memPortField(s.toMem, p, f)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
new file mode 100644
index 00000000..168a6a48
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
@@ -0,0 +1,77 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import AnalysisUtils._
+import MemPortUtils._
+import MemTransformUtils._
+
+
+/** Changes memory port names to standard port names (i.e. RW0 instead T_408)
+ */
+object RenameAnnotatedMemoryPorts extends Pass {
+
+ def name = "Rename Annotated Memory Ports"
+
+ /** Renames memory ports to a standard naming scheme:
+ * - R0, R1, ... for each read port
+ * - W0, W1, ... for each write port
+ * - RW0, RW1, ... for each readwrite port
+ */
+ def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = {
+ val rports = m.readers.indices map (i => s"R$i")
+ val wports = m.writers.indices map (i => s"W$i")
+ val rwports = m.readwriters.indices map (i => s"RW$i")
+ m copy (readers = rports, writers = wports, readwriters = rwports)
+ }
+
+ /** Maps the serialized form of all memory port field names to the
+ * corresponding new memory port field Expression.
+ * E.g.:
+ * - ("m.read.addr") becomes (m.R0.addr)
+ */
+ def getMemPortMap(m: DefAnnotatedMemory): MemPortMap = {
+ val memPortMap = new MemPortMap
+ val defaultFields = Seq("addr", "en", "clk")
+ val rFields = defaultFields :+ "data"
+ val wFields = rFields :+ "mask"
+ val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")
+
+ def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit =
+ for ((p, i) <- ports.zipWithIndex; f <- fields) {
+ val newPort = createSubField(createRef(m.name), newPortKind + i)
+ val field = createSubField(newPort, f)
+ memPortMap(s"${m.name}.$p.$f") = field
+ }
+ updateMemPortMap(m.readers, rFields, "R")
+ updateMemPortMap(m.writers, wFields, "W")
+ updateMemPortMap(m.readwriters, rwFields, "RW")
+ memPortMap
+ }
+
+ /** Replaces candidate memories with memories with standard port names
+ * Does not update the references (this is done via updateStmtRefs)
+ */
+ def updateMemStmts(memPortMap: MemPortMap)(s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ val updatedMem = createMemProto(m)
+ memPortMap ++= getMemPortMap(m)
+ updatedMem
+ case s => s map updateMemStmts(memPortMap)
+ }
+
+ /** Replaces candidate memories and their references with standard port names
+ */
+ def updateMemMods(m: DefModule) = {
+ val memPortMap = new MemPortMap
+ (m map updateMemStmts(memPortMap)
+ map updateStmtRefs(memPortMap))
+ }
+
+ def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
new file mode 100644
index 00000000..3139ef21
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -0,0 +1,188 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import MemPortUtils.{MemPortMap, Modules}
+import MemTransformUtils._
+import AnalysisUtils._
+
+/** Replace DefAnnotatedMemory with memory blackbox + wrapper + conf file.
+ * This will not generate wmask ports if not needed.
+ * Creates the minimum # of black boxes needed by the design.
+ */
+class ReplaceMemMacros(writer: ConfWriter) extends Pass {
+ def name = "Replace Memory Macros"
+
+ /** Return true if mask granularity is per bit, false if per byte or unspecified
+ */
+ private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match {
+ case None => false
+ case Some(v) => v == 1
+ }
+
+ private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+
+ private def wPortToBundle(mem: DefAnnotatedMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) => Seq(Field("mask", Default, createMask(mem.dataType)))
+ })
+ )
+ private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
+ (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType)))
+ case Some(_) => Seq(Field("mask", Default, flattenType(createMask(mem.dataType))))
+ })
+ )
+ // TODO(shunshou): Don't use createMask???
+
+ private def rwPortToBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, mem.dataType),
+ Field("rdata", Flip, mem.dataType)
+ ) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) => Seq(Field("wmask", Default, createMask(mem.dataType)))
+ })
+ )
+ private def rwPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, flattenType(mem.dataType)),
+ Field("rdata", Flip, flattenType(mem.dataType))
+ ) ++ (mem.maskGran match {
+ case None => Nil
+ case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType)))
+ case Some(_) => Seq(Field("wmask", Default, flattenType(createMask(mem.dataType))))
+ })
+ )
+
+ def memToBundle(s: DefAnnotatedMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
+ def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType(
+ s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
+ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
+
+ /** Creates a wrapper module and external module to replace a candidate memory
+ * The wrapper module has the same type as the memory it replaces
+ * The external module
+ */
+ def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = {
+ assert(m.dataType != UnknownType)
+ val wrapperIoType = memToBundle(m)
+ val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ // Creates a type with the write/readwrite masks omitted if necessary
+ val bbIoType = memToFlattenBundle(m)
+ val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbRef = createRef(m.name, bbIoType)
+ val hasMask = m.maskGran.isDefined
+ val fillMask = getFillWMask(m)
+ def portRef(p: String) = createRef(p, field_type(wrapperIoType, p))
+ val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
+ (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++
+ (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++
+ (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask)))
+ val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
+ val bb = ExtModule(NoInfo, m.name, bbIoPorts)
+ // TODO: Annotate? -- use actual annotation map
+
+ // add to conf file
+ writer.append(m)
+ Seq(bb, wrapper)
+ }
+
+ // TODO(shunshou): get rid of copy pasta
+ // Connects the clk, en, and addr fields from the wrapperPort to the bbPort
+ def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] =
+ Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))
+
+ // Connects the clk, en, and addr fields from the wrapperPort to the bbPort
+ def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression =
+ if (fillMask) toBitMask(mask, dataType) else toBits(mask)
+
+ def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
+ defaultConnects(wrapperPort, bbPort) :+
+ fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data"))
+
+ def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
+ val wrapperData = createSubField(wrapperPort, "data")
+ val defaultSeq = defaultConnects(wrapperPort, bbPort) :+
+ Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData))
+ hasMask match {
+ case false => defaultSeq
+ case true => defaultSeq :+ Connect(
+ NoInfo,
+ createSubField(bbPort, "mask"),
+ maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
+ )
+ }
+ }
+
+ def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
+ val wrapperWData = createSubField(wrapperPort, "wdata")
+ val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
+ fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")),
+ connectFields(bbPort, "wmode", wrapperPort, "wmode"),
+ Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData)))
+ hasMask match {
+ case false => defaultSeq
+ case true => defaultSeq :+ Connect(
+ NoInfo,
+ createSubField(bbPort, "wmask"),
+ maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
+ )
+ }
+ }
+
+ def updateMemStmts(namespace: Namespace,
+ memPortMap: MemPortMap,
+ memMods: Modules)
+ (s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ if (m.maskGran.isEmpty) {
+ m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
+ m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
+ }
+ m.memRef match {
+ case None =>
+ // prototype mem
+ val newWrapperName = namespace newName m.name
+ val newMemBBName = namespace newName s"${m.name}_ext"
+ val newMem = m copy (name = newMemBBName)
+ memMods ++= createMemModule(newMem, newWrapperName)
+ WDefInstance(m.info, m.name, newWrapperName, UnknownType)
+ case Some(ref: String) =>
+ WDefInstance(m.info, m.name, ref, UnknownType)
+ }
+ case sx => sx map updateMemStmts(namespace, memPortMap, memMods)
+ }
+
+ def updateMemMods(namespace: Namespace, memMods: Modules)(m: DefModule) = {
+ val memPortMap = new MemPortMap
+
+ (m map updateMemStmts(namespace, memPortMap, memMods)
+ map updateStmtRefs(memPortMap))
+ }
+
+ def run(c: Circuit) = {
+ val namespace = Namespace(c)
+ val memMods = new Modules
+ val modules = c.modules map updateMemMods(namespace, memMods)
+ // print conf
+ writer.serialize()
+ c copy (modules = modules ++ memMods)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
new file mode 100644
index 00000000..dfa828c9
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -0,0 +1,126 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import Annotations._
+import AnalysisUtils._
+import Utils.error
+import java.io.{File, CharArrayWriter, PrintWriter}
+
+sealed trait PassOption
+case object InputConfigFileName extends PassOption
+case object OutputConfigFileName extends PassOption
+case object PassCircuitName extends PassOption
+
+object PassConfigUtil {
+ type PassOptionMap = Map[PassOption, String]
+
+ def getPassOptions(t: String, usage: String = "") = {
+ // can't use space to delimit sub arguments (otherwise, Driver.scala will throw error)
+ val passArgList = t.split(":").toList
+
+ def nextPassOption(map: PassOptionMap, list: List[String]): PassOptionMap = {
+ list match {
+ case Nil => map
+ case "-i" :: value :: tail =>
+ nextPassOption(map + (InputConfigFileName -> value), tail)
+ case "-o" :: value :: tail =>
+ nextPassOption(map + (OutputConfigFileName -> value), tail)
+ case "-c" :: value :: tail =>
+ nextPassOption(map + (PassCircuitName -> value), tail)
+ case option :: tail =>
+ error("Unknown option " + option + usage)
+ }
+ }
+ nextPassOption(Map[PassOption, String](), passArgList)
+ }
+}
+
+class ConfWriter(filename: String) {
+ val outputBuffer = new CharArrayWriter
+ def append(m: DefAnnotatedMemory) = {
+ // legacy
+ val maskGran = m.maskGran
+ val readers = List.fill(m.readers.length)("read")
+ val writers = List.fill(m.writers.length)(if (maskGran.isEmpty) "write" else "mwrite")
+ val readwriters = List.fill(m.readwriters.length)(if (maskGran.isEmpty) "rw" else "mrw")
+ val ports = (writers ++ readers ++ readwriters) mkString ","
+ val maskGranConf = maskGran match { case None => "" case Some(p) => s"mask_gran $p" }
+ val width = bitWidth(m.dataType)
+ val conf = s"name ${m.name} depth ${m.depth} width $width ports $ports $maskGranConf \n"
+ outputBuffer.append(conf)
+ }
+ def serialize() = {
+ val outputFile = new PrintWriter(filename)
+ outputFile.write(outputBuffer.toString)
+ outputFile.close()
+ }
+}
+
+case class ReplSeqMemAnnotation(t: String, tID: TransID)
+ extends Annotation with Loose with Unstable {
+
+ val usage = """
+[Optional] ReplSeqMem
+ Pass to replace sequential memories with blackboxes + configuration file
+
+Usage:
+ --replSeqMem -c:<circuit>:-i:<filename>:-o:<filename>
+ *** Note: sub-arguments to --replSeqMem should be delimited by : and not white space!
+
+Required Arguments:
+ -o<filename> Specify the output configuration file
+ -c<compiler> Specify the target circuit
+
+Optional Arguments:
+ -i<filename> Specify the input configuration file (for additional optimizations)
+"""
+
+ val passOptions = PassConfigUtil.getPassOptions(t, usage)
+ val outputConfig = passOptions.getOrElse(
+ OutputConfigFileName,
+ error("No output config file provided for ReplSeqMem!" + usage)
+ )
+ val passCircuit = passOptions.getOrElse(
+ PassCircuitName,
+ error("No circuit name specified for ReplSeqMem!" + usage)
+ )
+ val target = CircuitName(passCircuit)
+ def duplicate(n: Named) = this copy (t = t.replace(s"-c:$passCircuit", s"-c:${n.name}"))
+}
+
+class ReplSeqMem(transID: TransID) extends Transform with SimpleRun {
+ def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) =
+ Seq(Legalize,
+ ToMemIR,
+ ResolveMaskGranularity,
+ RenameAnnotatedMemoryPorts,
+ ResolveMemoryReference,
+ //new AnnotateValidMemConfigs(inConfigFile),
+ new ReplaceMemMacros(outConfigFile),
+ RemoveEmpty,
+ CheckInitialization,
+ InferTypes,
+ Uniquify,
+ ResolveKinds, // Must be run for the transform to work!
+ ResolveGenders)
+
+ def execute(c: Circuit, map: AnnotationMap) = map get transID match {
+ case Some(p) => p get CircuitName(c.main) match {
+ case Some(ReplSeqMemAnnotation(t, _)) =>
+ val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "")
+ val inConfigFile = {
+ if (inputFileName.isEmpty) None
+ else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName))
+ else error("Input configuration file does not exist!")
+ }
+ val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName))
+ run(c, passSeq(inConfigFile, outConfigFile))
+ case _ => error("Unexpected transform annotation")
+ }
+ case _ => TransformResult(c)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
new file mode 100644
index 00000000..a8ff9fe3
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -0,0 +1,124 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import WrappedExpression.weq
+import AnalysisUtils._
+import MemTransformUtils._
+
+object AnalysisUtils {
+ type Connects = collection.mutable.HashMap[String, Expression]
+
+ /** Builds a map from named component to assigned value
+ * Named components are serialized LHS of connections, nodes, invalids
+ */
+ def getConnects(m: DefModule): Connects = {
+ def getConnects(connects: Connects)(s: Statement): Statement = {
+ s match {
+ case Connect(_, loc, expr) =>
+ connects(loc.serialize) = expr
+ case DefNode(_, name, value) =>
+ connects(name) = value
+ case IsInvalid(_, value) =>
+ connects(value.serialize) = WInvalid
+ case _ => // do nothing
+ }
+ s map getConnects(connects)
+ }
+ val connects = new Connects
+ m map getConnects(connects)
+ connects
+ }
+
+ /** Find a connection LHS's origin from a module's list of node-to-node connections
+ * regardless of whether constant propagation has been run.
+ * Will search past trivial primop/mux's which do not affect its origin.
+ * Limitations:
+ * - Only works in a module (stops @ module inputs)
+ * - Only does trivial primop/mux's (is not complete)
+ * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a)
+ */
+ def getOrigin(connects: Connects, s: String): Expression =
+ getOrigin(connects)(WRef(s, UnknownType, ExpKind, UNKNOWNGENDER))
+ def getOrigin(connects: Connects)(e: Expression): Expression = e match {
+ case Mux(cond, tv, fv, _) =>
+ val fvOrigin = getOrigin(connects)(fv)
+ val tvOrigin = getOrigin(connects)(tv)
+ val condOrigin = getOrigin(connects)(cond)
+ if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin
+ else if (weq(condOrigin, one)) tvOrigin
+ else if (weq(condOrigin, zero)) fvOrigin
+ else if (weq(tvOrigin, fvOrigin)) tvOrigin
+ else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
+ else e
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
+ case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
+ val extractionWidth = (msb - lsb) + 1
+ val nodeWidth = bitWidth(args.head.tpe)
+ // if you're extracting the full bitwidth, then keep searching for origin
+ if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e
+ case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
+ getOrigin(connects)(args.head)
+ case ValidIf(cond, value, ClockType) => getOrigin(connects)(value)
+ // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
+ case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
+ connects get e.serialize match {
+ case Some(ex) => getOrigin(connects)(ex)
+ case None => e
+ }
+ case _ => e
+ }
+
+ /** Checks whether the two memories are equivalent in all respects except name
+ */
+ def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) = a == b.copy(name = a.name)
+}
+
+/** Determines if a write mask is needed (wmode/en and wmask are equivalent).
+ * Populates the maskGran field of DefAnnotatedMemory
+ * Annotations:
+ * - maskGran = (dataType size) / (number of mask bits)
+ * - i.e. 1 if bitmask, 8 if bytemask, absent for no mask
+ * TODO(shunshou): Add floorplan info?
+ */
+object ResolveMaskGranularity extends Pass {
+ def name = "Resolve Mask Granularity"
+
+ /** Returns the number of mask bits, if used
+ */
+ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
+ val wenOrigin = getOrigin(connects)(wen)
+ val wmaskOrigin = connects.keys filter
+ (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)}
+ // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
+ val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ if (redundantMask) None else Some(wmaskOrigin.size)
+ }
+
+ /** Only annotate memories that are candidates for memory macro replacements
+ * i.e. rw, w + r (read, write 1 cycle delay)
+ */
+ def updateStmts(connects: Connects)(s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ val dataBits = bitWidth(m.dataType)
+ val rwMasks = m.readwriters map (rw =>
+ getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers map (w =>
+ getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val maskGran = (rwMasks ++ wMasks).head match {
+ case None => None
+ case Some(maskBits) => Some(dataBits / maskBits)
+ }
+ m.copy(maskGran = maskGran)
+ case sx => sx map updateStmts(connects)
+ }
+
+ def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
+ def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
new file mode 100644
index 00000000..783c179f
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
@@ -0,0 +1,38 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+import firrtl.ir._
+import AnalysisUtils.eqMems
+import firrtl.Mappers._
+
+
+/** Resolves annotation ref to memories that exactly match (except name) another memory
+ */
+object ResolveMemoryReference extends Pass {
+
+ def name = "Resolve Memory Reference"
+
+ type AnnotatedMemories = collection.mutable.ArrayBuffer[DefAnnotatedMemory]
+
+ /** If a candidate memory is identical except for name to another, add an
+ * annotation that references the name of the other memory.
+ */
+ def updateMemStmts(uniqueMems: AnnotatedMemories)(s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ uniqueMems find (x => eqMems(x, m)) match {
+ case None =>
+ uniqueMems += m
+ m
+ case Some(proto) => m copy (memRef = Some(proto.name))
+ }
+ case s => s map updateMemStmts(uniqueMems)
+ }
+
+ def updateMemMods(m: DefModule) = {
+ val uniqueMems = new AnnotatedMemories
+ (m map updateMemStmts(uniqueMems))
+ }
+
+ def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
new file mode 100644
index 00000000..741ea5ef
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
@@ -0,0 +1,41 @@
+package firrtl.passes
+package memlib
+
+import firrtl.Mappers._
+import firrtl.ir._
+
+/** Annotates sequential memories that are candidates for macro replacement.
+ * Requirements for macro replacement:
+ * - read latency and write latency of one
+ * - only one readwrite port or write port
+ * - zero or one read port
+ */
+object ToMemIR extends Pass {
+ def name = "To Memory IR"
+
+ /** Only annotate memories that are candidates for memory macro replacements
+ * i.e. rw, w + r (read, write 1 cycle delay)
+ */
+ def updateStmts(s: Statement): Statement = s match {
+ case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
+ (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
+ DefAnnotatedMemory(
+ m.info,
+ m.name,
+ m.dataType,
+ m.depth,
+ m.writeLatency,
+ m.readLatency,
+ m.readers,
+ m.writers,
+ m.readwriters,
+ m.readUnderWrite,
+ None, // mask granularity annotation
+ None // No reference yet to another memory
+ )
+ case sx => sx map updateStmts
+ }
+
+ def annotateModMems(m: DefModule) = m map updateStmts
+ def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+}
diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
new file mode 100644
index 00000000..a1088300
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
@@ -0,0 +1,36 @@
+package firrtl.passes
+package memlib
+import net.jcazevedo.moultingyaml._
+import java.io.{File, CharArrayWriter, PrintWriter}
+
+object CustomYAMLProtocol extends DefaultYamlProtocol {
+ // bottom depends on top
+}
+
+class YamlFileReader(file: String) {
+ import CustomYAMLProtocol._
+ def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
+ if (new File(file).exists) {
+ val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n")
+ yamlString.parseYamls flatMap (x =>
+ try Some(reader read x)
+ catch { case e: Exception => None }
+ )
+ }
+ else error("Yaml file doesn't exist!")
+ }
+}
+
+class YamlFileWriter(file: String) {
+ import CustomYAMLProtocol._
+ val outputBuffer = new CharArrayWriter
+ val separator = "--- \n"
+ def append(in: YamlValue) {
+ outputBuffer append s"$separator${in.prettyPrint}"
+ }
+ def dump() {
+ val outputFile = new PrintWriter(file)
+ outputFile write outputBuffer.toString
+ outputFile.close()
+ }
+}