diff options
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateMemMacros.scala | 121 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 2 |
2 files changed, 59 insertions, 64 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala index 7da290b7..7ced7a99 100644 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala @@ -2,13 +2,13 @@ package firrtl.passes -import scala.collection.mutable -import AnalysisUtils._ -import firrtl.WrappedExpression._ -import firrtl.ir._ import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import WrappedExpression.weq +import MemPortUtils.memPortField +import AnalysisUtils._ case class AppendableInfo(fields: Map[String, Any]) extends Info { def append(a: Map[String, Any]) = this.copy(fields = fields ++ a) @@ -44,35 +44,40 @@ object AnalysisUtils { // use case: compare if two nodes have the same origin // limitation: only works in a module (stops @ module inputs) // TODO: more thorough (i.e. a + 0 = a) - def getConnectOrigin(connects: Connects, node: String): Expression = { - if (connects contains node) getOrigin(connects, connects(node)) - else EmptyExpression - } + def getConnectOrigin(connects: Connects)(node: String): Expression = + connects get node match { + case None => EmptyExpression + case Some(e) => getOrigin(connects, e) + } + def getConnectOrigin(connects: Connects, e: Expression): Expression = + getConnectOrigin(connects)(e.serialize) private def getOrigin(connects: Connects, e: Expression): Expression = e match { case Mux(cond, tv, fv, _) => val fvOrigin = getOrigin(connects, fv) val tvOrigin = getOrigin(connects, tv) val condOrigin = getOrigin(connects, cond) - if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin - else if (we(condOrigin) == we(one)) tvOrigin - else if (we(condOrigin) == we(zero)) fvOrigin - else if (we(tvOrigin) == we(fvOrigin)) tvOrigin - else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin + if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin + else if (weq(condOrigin, one)) tvOrigin + else if (weq(condOrigin, zero)) fvOrigin + else if (weq(tvOrigin, fvOrigin)) tvOrigin + else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin else e - case DoPrim(PrimOps.Or, args, consts, tpe) if args.contains(one) => one - case DoPrim(PrimOps.And, args, consts, tpe) if args.contains(zero) => zero + case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => - val extractionWidth = (msb-lsb)+1 + val extractionWidth = (msb - lsb) + 1 val nodeWidth = bitWidth(args.head.tpe) // if you're extracting the full bitwidth, then keep searching for origin - if (nodeWidth == extractionWidth) getOrigin(connects, args.head) - else e + if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => getOrigin(connects, args.head) // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) - case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects.contains(e.serialize) && kind(e) != RegKind => - getConnectOrigin(connects, e.serialize) + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => + connects get e.serialize match { + case Some(ex) => getOrigin(connects, ex) + case None => e + } case _ => e } @@ -92,55 +97,45 @@ object AnalysisUtils { // memories equivalent as long as all fields (except name) are the same def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name) - } object AnnotateMemMacros extends Pass { + def name = "Analyze sequential memories and tag with info for future passes(useMacro, maskGran)" + + // returns # of mask bits if used + def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { + val wenOrigin = getConnectOrigin(connects, wen) + val wmaskOrigin = connects.keys filter + (_ startsWith wmask.serialize) map getConnectOrigin(connects) + // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) + val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + if (redundantMask) None else Some(wmaskOrigin.size) + } - def name = "Analyze sequential memories and tag with info for future passes (useMacro,maskGran)" - - def run(c: Circuit) = { - - def annotateModMems(m: Module) = { - val connects = getConnects(m) - - // returns # of mask bits if used - def getMaskBits(wen: String, wmask: String): Option[Int] = { - val wenOrigin = we(getConnectOrigin(connects, wen)) - val one1 = we(one) - val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects, x))) - // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) - val redundantMask = wmaskOrigin.map( x => (x == wenOrigin) || (x == one1) ).foldLeft(true)(_ && _) - if (redundantMask) None else Some(wmaskOrigin.length) - } - - def updateStmts(s: Statement): Statement = s match { - // only annotate memories that are candidates for memory macro replacements - // i.e. rw, w + r (read, write 1 cycle delay) - case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && - (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => - val dataBits = bitWidth(m.dataType) - val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode", s"${m.name}.$w.wmask")) - val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en", s"${m.name}.$w.mask")) - val maskBits = (rwMasks ++ wMasks).head - val memAnnotations = Map("useMacro" -> true) - val tempInfo = appendInfo(m.info, memAnnotations) - if (maskBits == None) m.copy(info = tempInfo) - else m.copy(info = tempInfo.append("maskGran" -> dataBits/maskBits.get)) - case b: Block => b map updateStmts - case s => s + def updateStmts(connects: Connects)(s: Statement): Statement = s match { + // only annotate memories that are candidates for memory macro replacements + // i.e. rw, w + r (read, write 1 cycle delay) + case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && + (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => + val dataBits = bitWidth(m.dataType) + val rwMasks = m.readwriters map (rw => + getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers map (w => + getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val memAnnotations = Map("useMacro" -> true) + val tempInfo = appendInfo(m.info, memAnnotations) + (rwMasks ++ wMasks).head match { + case None => + m copy (info = tempInfo) + case Some(maskBits) => + m.copy(info = tempInfo.append("maskGran" -> dataBits / maskBits)) } - m.copy(body=updateStmts(m.body)) - } - - val updatedMods = c.modules map { - case m: Module => annotateModMems(m) - case m: ExtModule => m - } - c.copy(modules = updatedMods) + case s => s map updateStmts(connects) + } - } + def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m)) + def run(c: Circuit) = c copy (modules = (c.modules map annotateModMems)) } -// TODO: Add floorplan info?
\ No newline at end of file +// TODO: Add floorplan info? diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 7219b1ce..8aeafc9e 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -107,7 +107,7 @@ circuit Top : val circuit = InferTypes.run(ToWorkingIR.run(parse(input))) val m = circuit.modules.head.asInstanceOf[ir.Module] val connects = AnalysisUtils.getConnects(m) - val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects,"f").serialize + val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects)("f").serialize require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !") } |
