aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala121
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala2
2 files changed, 59 insertions, 64 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
index 7da290b7..7ced7a99 100644
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
@@ -2,13 +2,13 @@
package firrtl.passes
-import scala.collection.mutable
-import AnalysisUtils._
-import firrtl.WrappedExpression._
-import firrtl.ir._
import firrtl._
+import firrtl.ir._
import firrtl.Utils._
import firrtl.Mappers._
+import WrappedExpression.weq
+import MemPortUtils.memPortField
+import AnalysisUtils._
case class AppendableInfo(fields: Map[String, Any]) extends Info {
def append(a: Map[String, Any]) = this.copy(fields = fields ++ a)
@@ -44,35 +44,40 @@ object AnalysisUtils {
// use case: compare if two nodes have the same origin
// limitation: only works in a module (stops @ module inputs)
// TODO: more thorough (i.e. a + 0 = a)
- def getConnectOrigin(connects: Connects, node: String): Expression = {
- if (connects contains node) getOrigin(connects, connects(node))
- else EmptyExpression
- }
+ def getConnectOrigin(connects: Connects)(node: String): Expression =
+ connects get node match {
+ case None => EmptyExpression
+ case Some(e) => getOrigin(connects, e)
+ }
+ def getConnectOrigin(connects: Connects, e: Expression): Expression =
+ getConnectOrigin(connects)(e.serialize)
private def getOrigin(connects: Connects, e: Expression): Expression = e match {
case Mux(cond, tv, fv, _) =>
val fvOrigin = getOrigin(connects, fv)
val tvOrigin = getOrigin(connects, tv)
val condOrigin = getOrigin(connects, cond)
- if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin
- else if (we(condOrigin) == we(one)) tvOrigin
- else if (we(condOrigin) == we(zero)) fvOrigin
- else if (we(tvOrigin) == we(fvOrigin)) tvOrigin
- else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin
+ if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin
+ else if (weq(condOrigin, one)) tvOrigin
+ else if (weq(condOrigin, zero)) fvOrigin
+ else if (weq(tvOrigin, fvOrigin)) tvOrigin
+ else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args.contains(one) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args.contains(zero) => zero
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
- val extractionWidth = (msb-lsb)+1
+ val extractionWidth = (msb - lsb) + 1
val nodeWidth = bitWidth(args.head.tpe)
// if you're extracting the full bitwidth, then keep searching for origin
- if (nodeWidth == extractionWidth) getOrigin(connects, args.head)
- else e
+ if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e
case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
getOrigin(connects, args.head)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
- case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects.contains(e.serialize) && kind(e) != RegKind =>
- getConnectOrigin(connects, e.serialize)
+ case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
+ connects get e.serialize match {
+ case Some(ex) => getOrigin(connects, ex)
+ case None => e
+ }
case _ => e
}
@@ -92,55 +97,45 @@ object AnalysisUtils {
// memories equivalent as long as all fields (except name) are the same
def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name)
-
}
object AnnotateMemMacros extends Pass {
+ def name = "Analyze sequential memories and tag with info for future passes(useMacro, maskGran)"
+
+ // returns # of mask bits if used
+ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
+ val wenOrigin = getConnectOrigin(connects, wen)
+ val wmaskOrigin = connects.keys filter
+ (_ startsWith wmask.serialize) map getConnectOrigin(connects)
+ // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
+ val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ if (redundantMask) None else Some(wmaskOrigin.size)
+ }
- def name = "Analyze sequential memories and tag with info for future passes (useMacro,maskGran)"
-
- def run(c: Circuit) = {
-
- def annotateModMems(m: Module) = {
- val connects = getConnects(m)
-
- // returns # of mask bits if used
- def getMaskBits(wen: String, wmask: String): Option[Int] = {
- val wenOrigin = we(getConnectOrigin(connects, wen))
- val one1 = we(one)
- val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects, x)))
- // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin.map( x => (x == wenOrigin) || (x == one1) ).foldLeft(true)(_ && _)
- if (redundantMask) None else Some(wmaskOrigin.length)
- }
-
- def updateStmts(s: Statement): Statement = s match {
- // only annotate memories that are candidates for memory macro replacements
- // i.e. rw, w + r (read, write 1 cycle delay)
- case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
- (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
- val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode", s"${m.name}.$w.wmask"))
- val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en", s"${m.name}.$w.mask"))
- val maskBits = (rwMasks ++ wMasks).head
- val memAnnotations = Map("useMacro" -> true)
- val tempInfo = appendInfo(m.info, memAnnotations)
- if (maskBits == None) m.copy(info = tempInfo)
- else m.copy(info = tempInfo.append("maskGran" -> dataBits/maskBits.get))
- case b: Block => b map updateStmts
- case s => s
+ def updateStmts(connects: Connects)(s: Statement): Statement = s match {
+ // only annotate memories that are candidates for memory macro replacements
+ // i.e. rw, w + r (read, write 1 cycle delay)
+ case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
+ (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
+ val dataBits = bitWidth(m.dataType)
+ val rwMasks = m.readwriters map (rw =>
+ getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers map (w =>
+ getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val memAnnotations = Map("useMacro" -> true)
+ val tempInfo = appendInfo(m.info, memAnnotations)
+ (rwMasks ++ wMasks).head match {
+ case None =>
+ m copy (info = tempInfo)
+ case Some(maskBits) =>
+ m.copy(info = tempInfo.append("maskGran" -> dataBits / maskBits))
}
- m.copy(body=updateStmts(m.body))
- }
-
- val updatedMods = c.modules map {
- case m: Module => annotateModMems(m)
- case m: ExtModule => m
- }
- c.copy(modules = updatedMods)
+ case s => s map updateStmts(connects)
+ }
- }
+ def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m))
+ def run(c: Circuit) = c copy (modules = (c.modules map annotateModMems))
}
-// TODO: Add floorplan info? \ No newline at end of file
+// TODO: Add floorplan info?
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 7219b1ce..8aeafc9e 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -107,7 +107,7 @@ circuit Top :
val circuit = InferTypes.run(ToWorkingIR.run(parse(input)))
val m = circuit.modules.head.asInstanceOf[ir.Module]
val connects = AnalysisUtils.getConnects(m)
- val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects,"f").serialize
+ val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects)("f").serialize
require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !")
}