aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala')
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala36
1 files changed, 17 insertions, 19 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 41c47dce..434c7602 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -28,10 +28,10 @@ object AnalysisUtils {
connects(value.serialize) = WInvalid
case _ => // do nothing
}
- s map getConnects(connects)
+ s.map(getConnects(connects))
}
val connects = new Connects
- m map getConnects(connects)
+ m.map(getConnects(connects))
connects
}
@@ -56,8 +56,8 @@ object AnalysisUtils {
else if (weq(tvOrigin, fvOrigin)) tvOrigin
else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args.exists(weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args.exists(weq(_, zero)) => zero
case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
val extractionWidth = (msb - lsb) + 1
val nodeWidth = bitWidth(args.head.tpe)
@@ -69,10 +69,10 @@ object AnalysisUtils {
case ValidIf(cond, value, _) => getOrigin(connects)(value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
- connects get e.serialize match {
- case Some(ex) => getOrigin(connects)(ex)
- case None => e
- }
+ connects.get(e.serialize) match {
+ case Some(ex) => getOrigin(connects)(ex)
+ case None => e
+ }
case _ => e
}
}
@@ -90,10 +90,9 @@ object ResolveMaskGranularity extends Pass {
*/
def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
val wenOrigin = getOrigin(connects)(wen)
- val wmaskOrigin = connects.keys filter
- (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)}
+ val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) }
// all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one))
if (redundantMask) None else Some(wmaskOrigin.size)
}
@@ -103,18 +102,17 @@ object ResolveMaskGranularity extends Pass {
def updateStmts(connects: Connects)(s: Statement): Statement = s match {
case m: DefAnnotatedMemory =>
val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (rw =>
- getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
- val wMasks = m.writers map (w =>
- getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val rwMasks =
+ m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
val maskGran = (rwMasks ++ wMasks).head match {
- case None => None
+ case None => None
case Some(maskBits) => Some(dataBits / maskBits)
}
m.copy(maskGran = maskGran)
- case sx => sx map updateStmts(connects)
+ case sx => sx.map(updateStmts(connects))
}
- def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m))
- def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems)
+ def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m)))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems))
}