diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala | 124 |
1 files changed, 124 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala new file mode 100644 index 00000000..a8ff9fe3 --- /dev/null +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -0,0 +1,124 @@ +// See LICENSE for license details. + +package firrtl.passes +package memlib + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils._ +import firrtl.Mappers._ +import WrappedExpression.weq +import AnalysisUtils._ +import MemTransformUtils._ + +object AnalysisUtils { + type Connects = collection.mutable.HashMap[String, Expression] + + /** Builds a map from named component to assigned value + * Named components are serialized LHS of connections, nodes, invalids + */ + def getConnects(m: DefModule): Connects = { + def getConnects(connects: Connects)(s: Statement): Statement = { + s match { + case Connect(_, loc, expr) => + connects(loc.serialize) = expr + case DefNode(_, name, value) => + connects(name) = value + case IsInvalid(_, value) => + connects(value.serialize) = WInvalid + case _ => // do nothing + } + s map getConnects(connects) + } + val connects = new Connects + m map getConnects(connects) + connects + } + + /** Find a connection LHS's origin from a module's list of node-to-node connections + * regardless of whether constant propagation has been run. + * Will search past trivial primop/mux's which do not affect its origin. + * Limitations: + * - Only works in a module (stops @ module inputs) + * - Only does trivial primop/mux's (is not complete) + * TODO(shunshou): implement more equivalence cases (i.e. a + 0 = a) + */ + def getOrigin(connects: Connects, s: String): Expression = + getOrigin(connects)(WRef(s, UnknownType, ExpKind, UNKNOWNGENDER)) + def getOrigin(connects: Connects)(e: Expression): Expression = e match { + case Mux(cond, tv, fv, _) => + val fvOrigin = getOrigin(connects)(fv) + val tvOrigin = getOrigin(connects)(tv) + val condOrigin = getOrigin(connects)(cond) + if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin + else if (weq(condOrigin, one)) tvOrigin + else if (weq(condOrigin, zero)) fvOrigin + else if (weq(tvOrigin, fvOrigin)) tvOrigin + else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin + else e + case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero + case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => + val extractionWidth = (msb - lsb) + 1 + val nodeWidth = bitWidth(args.head.tpe) + // if you're extracting the full bitwidth, then keep searching for origin + if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e + case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => + getOrigin(connects)(args.head) + case ValidIf(cond, value, ClockType) => getOrigin(connects)(value) + // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => + connects get e.serialize match { + case Some(ex) => getOrigin(connects)(ex) + case None => e + } + case _ => e + } + + /** Checks whether the two memories are equivalent in all respects except name + */ + def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory) = a == b.copy(name = a.name) +} + +/** Determines if a write mask is needed (wmode/en and wmask are equivalent). + * Populates the maskGran field of DefAnnotatedMemory + * Annotations: + * - maskGran = (dataType size) / (number of mask bits) + * - i.e. 1 if bitmask, 8 if bytemask, absent for no mask + * TODO(shunshou): Add floorplan info? + */ +object ResolveMaskGranularity extends Pass { + def name = "Resolve Mask Granularity" + + /** Returns the number of mask bits, if used + */ + def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { + val wenOrigin = getOrigin(connects)(wen) + val wmaskOrigin = connects.keys filter + (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)} + // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) + val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + if (redundantMask) None else Some(wmaskOrigin.size) + } + + /** Only annotate memories that are candidates for memory macro replacements + * i.e. rw, w + r (read, write 1 cycle delay) + */ + def updateStmts(connects: Connects)(s: Statement): Statement = s match { + case m: DefAnnotatedMemory => + val dataBits = bitWidth(m.dataType) + val rwMasks = m.readwriters map (rw => + getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers map (w => + getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val maskGran = (rwMasks ++ wMasks).head match { + case None => None + case Some(maskBits) => Some(dataBits / maskBits) + } + m.copy(maskGran = maskGran) + case sx => sx map updateStmts(connects) + } + + def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m)) + def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) +} |
