aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-10-17 18:53:19 -0700
committerAngie Wang2016-10-17 18:53:19 -0700
commit85baeda249e59c7d9d9f159aaf29ff46d685cf02 (patch)
treecfb5f4a6a0a80f9033275de6e5e36b9d5b96faad /src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
parent7d08b9a1486fef0459481f6e542464a29fbe1db5 (diff)
Reorganized memory blackboxing (#336)
* Reorganized memory blackboxing Moved to new package memlib Added comments Moved utility functions around Removed unused AnnotateValidMemConfigs.scala * Fixed tests to pass * Use DefAnnotatedMemory instead of AppendableInfo * Broke passes up into simpler passes AnnotateMemMacros -> (ToMemIR, ResolveMaskGranularity) UpdateDuplicateMemMacros -> (RenameAnnotatedMemoryPorts, ResolveMemoryReference) * Fixed to make tests run * Minor changes from code review * Removed vim comments and renamed ReplSeqMem
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala')
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala124
1 files changed, 124 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
new file mode 100644
index 00000000..a8ff9fe3
--- /dev/null
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -0,0 +1,124 @@
+// See LICENSE for license details.
+
+package firrtl.passes
+package memlib
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import WrappedExpression.weq
+import AnalysisUtils._
+import MemTransformUtils._
+
+object AnalysisUtils {
+ type Connects = collection.mutable.HashMap[String, Expression]
+
+ /** Builds a map from named component to assigned value
+ * Named components are serialized LHS of connections, nodes, invalids
+ */
+ def getConnects(m: DefModule): Connects = {
+ def getConnects(connects: Connects)(s: Statement): Statement = {
+ s match {
+ case Connect(_, loc, expr) =>
+ connects(loc.serialize) = expr
+ case DefNode(_, name, value) =>
+ connects(name) = value
+ case IsInvalid(_, value) =>
+ connects(value.serialize) = WInvalid
+ case _ => // do nothing
+ }
+ s map getConnects(connects)
+ }
+ val connects = new Connects
+ m map getConnects(connects)
+ connects
+ }
+
+ /** Find a connection LHS's origin from a module's list of node-to-node connections
+ * regardless of whether constant propagation has been run.
+ * Will search past trivial primop/mux's which do not affect its origin.
+ * Limitations:
+ * - Only works in a module (stops @ module inputs)
+ * - Only does trivial primop/mux's (is not complete)
+ * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a)
+ */
+ def getOrigin(connects: Connects, s: String): Expression =
+ getOrigin(connects)(WRef(s, UnknownType, ExpKind, UNKNOWNGENDER))
+ def getOrigin(connects: Connects)(e: Expression): Expression = e match {
+ case Mux(cond, tv, fv, _) =>
+ val fvOrigin = getOrigin(connects)(fv)
+ val tvOrigin = getOrigin(connects)(tv)
+ val condOrigin = getOrigin(connects)(cond)
+ if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin
+ else if (weq(condOrigin, one)) tvOrigin
+ else if (weq(condOrigin, zero)) fvOrigin
+ else if (weq(tvOrigin, fvOrigin)) tvOrigin
+ else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
+ else e
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
+ case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
+ val extractionWidth = (msb - lsb) + 1
+ val nodeWidth = bitWidth(args.head.tpe)
+ // if you're extracting the full bitwidth, then keep searching for origin
+ if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e
+ case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
+ getOrigin(connects)(args.head)
+ case ValidIf(cond, value, ClockType) => getOrigin(connects)(value)
+ // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
+ case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
+ connects get e.serialize match {
+ case Some(ex) => getOrigin(connects)(ex)
+ case None => e
+ }
+ case _ => e
+ }
+
+ /** Checks whether the two memories are equivalent in all respects except name
+ */
+ def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) = a == b.copy(name = a.name)
+}
+
+/** Determines if a write mask is needed (wmode/en and wmask are equivalent).
+ * Populates the maskGran field of DefAnnotatedMemory
+ * Annotations:
+ * - maskGran = (dataType size) / (number of mask bits)
+ * - i.e. 1 if bitmask, 8 if bytemask, absent for no mask
+ * TODO(shunshou): Add floorplan info?
+ */
+object ResolveMaskGranularity extends Pass {
+ def name = "Resolve Mask Granularity"
+
+ /** Returns the number of mask bits, if used
+ */
+ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
+ val wenOrigin = getOrigin(connects)(wen)
+ val wmaskOrigin = connects.keys filter
+ (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)}
+ // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
+ val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ if (redundantMask) None else Some(wmaskOrigin.size)
+ }
+
+ /** Only annotate memories that are candidates for memory macro replacements
+ * i.e. rw, w + r (read, write 1 cycle delay)
+ */
+ def updateStmts(connects: Connects)(s: Statement): Statement = s match {
+ case m: DefAnnotatedMemory =>
+ val dataBits = bitWidth(m.dataType)
+ val rwMasks = m.readwriters map (rw =>
+ getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers map (w =>
+ getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val maskGran = (rwMasks ++ wMasks).head match {
+ case None => None
+ case Some(maskBits) => Some(dataBits / maskBits)
+ }
+ m.copy(maskGran = maskGran)
+ case sx => sx map updateStmts(connects)
+ }
+
+ def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
+ def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+}