aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/AnnotateMemMacros.scala')
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala143
1 files changed, 0 insertions, 143 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
deleted file mode 100644
index 21287922..00000000
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ /dev/null
@@ -1,143 +0,0 @@
-// See LICENSE for license details.
-
-package firrtl.passes
-
-import firrtl._
-import firrtl.ir._
-import firrtl.Utils._
-import firrtl.Mappers._
-import WrappedExpression.weq
-import MemPortUtils.memPortField
-import AnalysisUtils._
-
-case class AppendableInfo(fields: Map[String, Any]) extends Info {
- def append(a: Map[String, Any]) = this.copy(fields = fields ++ a)
- def append(a: (String, Any)): AppendableInfo = append(Map(a))
- def get(f: String) = fields.get(f)
- override def equals(b: Any) = b match {
- case i: AppendableInfo => fields - "info" == i.fields - "info"
- case _ => false
- }
-}
-
-object AnalysisUtils {
- type Connects = collection.mutable.HashMap[String, Expression]
- def getConnects(m: DefModule): Connects = {
- def getConnects(connects: Connects)(s: Statement): Statement = {
- s match {
- case Connect(_, loc, expr) =>
- connects(loc.serialize) = expr
- case DefNode(_, name, value) =>
- connects(name) = value
- case _ => // do nothing
- }
- s map getConnects(connects)
- }
- val connects = new Connects
- m map getConnects(connects)
- connects
- }
-
- // takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS.
- // if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation,
- // the function will try to search backwards past the primop/mux.
- // use case: compare if two nodes have the same origin
- // limitation: only works in a module (stops @ module inputs)
- // TODO: more thorough (i.e. a + 0 = a)
- def getConnectOrigin(connects: Connects)(node: String): Expression =
- connects get node match {
- case None => EmptyExpression
- case Some(e) => getOrigin(connects, e)
- }
- def getConnectOrigin(connects: Connects, e: Expression): Expression =
- getConnectOrigin(connects)(e.serialize)
-
- private def getOrigin(connects: Connects, e: Expression): Expression = e match {
- case Mux(cond, tv, fv, _) =>
- val fvOrigin = getOrigin(connects, fv)
- val tvOrigin = getOrigin(connects, tv)
- val condOrigin = getOrigin(connects, cond)
- if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin
- else if (weq(condOrigin, one)) tvOrigin
- else if (weq(condOrigin, zero)) fvOrigin
- else if (weq(tvOrigin, fvOrigin)) tvOrigin
- else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
- else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
- case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
- val extractionWidth = (msb - lsb) + 1
- val nodeWidth = bitWidth(args.head.tpe)
- // if you're extracting the full bitwidth, then keep searching for origin
- if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e
- case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
- getOrigin(connects, args.head)
- // Todo: It's not clear it's ok to call remove validifs before mem passes...
- case ValidIf(cond, value, ClockType) => getOrigin(connects, value)
- // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
- case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
- connects get e.serialize match {
- case Some(ex) => getOrigin(connects, ex)
- case None => e
- }
- case _ => e
- }
-
- def appendInfo[T <: Info](info: T, add: Map[String, Any]) = info match {
- case i: AppendableInfo => i.append(add)
- case _ => AppendableInfo(fields = add + ("info" -> info))
- }
- def appendInfo[T <: Info](info: T, add: (String, Any)): AppendableInfo = appendInfo(info, Map(add))
- def getInfo[T <: Info](info: T, k: String) = info match {
- case i: AppendableInfo => i.get(k)
- case _ => None
- }
- def containsInfo[T <: Info](info: T, k: String) = info match {
- case i: AppendableInfo => i.fields.contains(k)
- case _ => false
- }
-
- // memories equivalent as long as all fields (except name) are the same
- def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name)
-}
-
-object AnnotateMemMacros extends Pass {
- def name = "Analyze sequential memories and tag with info for future passes(useMacro, maskGran)"
-
- // returns # of mask bits if used
- def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
- val wenOrigin = getConnectOrigin(connects, wen)
- val wmaskOrigin = connects.keys filter
- (_ startsWith wmask.serialize) map getConnectOrigin(connects)
- // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
- if (redundantMask) None else Some(wmaskOrigin.size)
- }
-
- def updateStmts(connects: Connects)(s: Statement): Statement = s match {
- // only annotate memories that are candidates for memory macro replacements
- // i.e. rw, w + r (read, write 1 cycle delay)
- case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
- (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
- val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (rw =>
- getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
- val wMasks = m.writers map (w =>
- getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
- val memAnnotations = Map("useMacro" -> true)
- val tempInfo = appendInfo(m.info, memAnnotations)
- (rwMasks ++ wMasks).head match {
- case None =>
- m copy (info = tempInfo)
- case Some(maskBits) =>
- m.copy(info = tempInfo.append("maskGran" -> dataBits / maskBits))
- }
- case sx => sx map updateStmts(connects)
- }
-
- def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
-
- def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
-}
-
-// TODO: Add floorplan info?