diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Namespace.scala | 15 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateMemMacros.scala | 141 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala | 174 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferReadWrite.scala | 249 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/MemUtils.scala | 80 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplSeqMem.scala | 96 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceMemMacros.scala | 229 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala | 148 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/InferReadWriteSpec.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 2 |
10 files changed, 530 insertions, 606 deletions
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index 952670cf..1e922673 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -57,8 +57,6 @@ class Namespace private { } object Namespace { - def apply(): Namespace = new Namespace - // Initializes a namespace from a Module def apply(m: DefModule): Namespace = { val namespace = new Namespace @@ -69,7 +67,7 @@ object Namespace { case s: Block => s.stmts flatMap buildNamespaceStmt case _ => Nil } - namespace.namespace ++= (m.ports collect { case dec: IsDeclaration => dec.name }) + namespace.namespace ++= m.ports map (_.name) m match { case in: Module => namespace.namespace ++= buildNamespaceStmt(in.body) @@ -82,9 +80,14 @@ object Namespace { /** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */ def apply(c: Circuit): Namespace = { val namespace = new Namespace - c.modules foreach { m => - namespace.namespace += m.name - } + namespace.namespace ++= c.modules map (_.name) + namespace + } + + /** Initializes a [[Namespace]] from arbitrary strings **/ + def apply(names: Seq[String] = Nil): Namespace = { + val namespace = new Namespace + namespace.namespace ++= names namespace } } diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala index 58e10a66..7ced7a99 100644 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala @@ -2,13 +2,13 @@ package firrtl.passes -import scala.collection.mutable -import AnalysisUtils._ -import firrtl.WrappedExpression._ -import firrtl.ir._ import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import WrappedExpression.weq +import MemPortUtils.memPortField +import AnalysisUtils._ case class AppendableInfo(fields: Map[String, Any]) extends Info { def append(a: Map[String, Any]) = this.copy(fields = fields ++ a) @@ -21,22 +21,22 @@ case class AppendableInfo(fields: Map[String, Any]) extends Info { } object AnalysisUtils { - - def getConnects(m: Module) = { - val connects = mutable.HashMap[String, Expression]() - def getConnects(s: Statement): Statement = { - s map getConnects match { + type Connects = collection.mutable.HashMap[String, Expression] + def getConnects(m: DefModule): Connects = { + def getConnects(connects: Connects)(s: Statement): Statement = { + s match { case Connect(_, loc, expr) => connects(loc.serialize) = expr case DefNode(_, name, value) => connects(name) = value case _ => // do nothing } - s // return because we only have map and not foreach + s map getConnects(connects) } - getConnects(m.body) - connects.toMap - } + val connects = new Connects + m map getConnects(connects) + connects + } // takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS. // if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation, @@ -44,35 +44,40 @@ object AnalysisUtils { // use case: compare if two nodes have the same origin // limitation: only works in a module (stops @ module inputs) // TODO: more thorough (i.e. a + 0 = a) - def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = { - if (connects contains node) getOrigin(connects, connects(node)) - else EmptyExpression - } + def getConnectOrigin(connects: Connects)(node: String): Expression = + connects get node match { + case None => EmptyExpression + case Some(e) => getOrigin(connects, e) + } + def getConnectOrigin(connects: Connects, e: Expression): Expression = + getConnectOrigin(connects)(e.serialize) - private def getOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { + private def getOrigin(connects: Connects, e: Expression): Expression = e match { case Mux(cond, tv, fv, _) => val fvOrigin = getOrigin(connects, fv) val tvOrigin = getOrigin(connects, tv) val condOrigin = getOrigin(connects, cond) - if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin - else if (we(condOrigin) == we(one)) tvOrigin - else if (we(condOrigin) == we(zero)) fvOrigin - else if (we(tvOrigin) == we(fvOrigin)) tvOrigin - else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin + if (weq(tvOrigin, one) && weq(fvOrigin, zero)) condOrigin + else if (weq(condOrigin, one)) tvOrigin + else if (weq(condOrigin, zero)) fvOrigin + else if (weq(tvOrigin, fvOrigin)) tvOrigin + else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin else e - case DoPrim(PrimOps.Or, args, consts, tpe) if args.contains(one) => one - case DoPrim(PrimOps.And, args, consts, tpe) if args.contains(zero) => zero + case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => - val extractionWidth = (msb-lsb)+1 + val extractionWidth = (msb - lsb) + 1 val nodeWidth = bitWidth(args.head.tpe) // if you're extracting the full bitwidth, then keep searching for origin - if (nodeWidth == extractionWidth) getOrigin(connects, args.head) - else e + if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => getOrigin(connects, args.head) // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) - case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects.contains(e.serialize) && kind(e) != RegKind => - getConnectOrigin(connects, e.serialize) + case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => + connects get e.serialize match { + case Some(ex) => getOrigin(connects, ex) + case None => e + } case _ => e } @@ -92,55 +97,45 @@ object AnalysisUtils { // memories equivalent as long as all fields (except name) are the same def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name) - } object AnnotateMemMacros extends Pass { + def name = "Analyze sequential memories and tag with info for future passes(useMacro, maskGran)" + + // returns # of mask bits if used + def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { + val wenOrigin = getConnectOrigin(connects, wen) + val wmaskOrigin = connects.keys filter + (_ startsWith wmask.serialize) map getConnectOrigin(connects) + // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) + val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + if (redundantMask) None else Some(wmaskOrigin.size) + } - def name = "Analyze sequential memories and tag with info for future passes (useMacro,maskGran)" - - def run(c: Circuit) = { - - def annotateModMems(m: Module) = { - val connects = getConnects(m) - - // returns # of mask bits if used - def getMaskBits(wen: String, wmask: String): Option[Int] = { - val wenOrigin = we(getConnectOrigin(connects, wen)) - val one1 = we(one) - val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects, x))) - // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) - val redundantMask = wmaskOrigin.map( x => (x == wenOrigin) || (x == one1) ).foldLeft(true)(_ && _) - if (redundantMask) None else Some(wmaskOrigin.length) - } - - def updateStmts(s: Statement): Statement = s match { - // only annotate memories that are candidates for memory macro replacements - // i.e. rw, w + r (read, write 1 cycle delay) - case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && - (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => - val dataBits = bitWidth(m.dataType) - val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode", s"${m.name}.$w.wmask")) - val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en", s"${m.name}.$w.mask")) - val maskBits = (rwMasks ++ wMasks).head - val memAnnotations = Map("useMacro" -> true) - val tempInfo = appendInfo(m.info, memAnnotations) - if (maskBits == None) m.copy(info = tempInfo) - else m.copy(info = tempInfo.append("maskGran" -> dataBits/maskBits.get)) - case b: Block => b map updateStmts - case s => s + def updateStmts(connects: Connects)(s: Statement): Statement = s match { + // only annotate memories that are candidates for memory macro replacements + // i.e. rw, w + r (read, write 1 cycle delay) + case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && + (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => + val dataBits = bitWidth(m.dataType) + val rwMasks = m.readwriters map (rw => + getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers map (w => + getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val memAnnotations = Map("useMacro" -> true) + val tempInfo = appendInfo(m.info, memAnnotations) + (rwMasks ++ wMasks).head match { + case None => + m copy (info = tempInfo) + case Some(maskBits) => + m.copy(info = tempInfo.append("maskGran" -> dataBits / maskBits)) } - m.copy(body=updateStmts(m.body)) - } - - val updatedMods = c.modules map { - case m: Module => annotateModMems(m) - case m: ExtModule => m - } - c.copy(modules = updatedMods) + case s => s map updateStmts(connects) + } - } + def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m)) + def run(c: Circuit) = c copy (modules = (c.modules map annotateModMems)) } -// TODO: Add floorplan info?
\ No newline at end of file +// TODO: Add floorplan info? diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala index 816a179b..f80d4a0c 100644 --- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala +++ b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala @@ -2,13 +2,16 @@ package firrtl.passes -import firrtl.ir._ import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import Utils.error +import AnalysisUtils._ + import net.jcazevedo.moultingyaml._ import net.jcazevedo.moultingyaml.DefaultYamlProtocol._ -import AnalysisUtils._ import scala.collection.mutable -import firrtl.Mappers._ +import java.io.{File, CharArrayWriter, PrintWriter} object CustomYAMLProtocol extends DefaultYamlProtocol { // bottom depends on top @@ -36,7 +39,7 @@ case class MemDimension( rules: Option[DimensionRules], set: Option[List[Int]]) { require ( - if(rules == None) set != None else set == None, + if (rules == None) set != None else set == None, "Should specify either rules or a list of valid options, but not both" ) def getValid = set.getOrElse(rules.get.getValid).sorted @@ -128,15 +131,12 @@ case class SRAMCompiler( rules.find(r => r.ymux._1 == x && r.ybank._1 == y).get } - private val maskConfigOutputBuffer = new java.io.CharArrayWriter - private val noMaskConfigOutputBuffer = new java.io.CharArrayWriter + private val maskConfigOutputBuffer = new CharArrayWriter + private val noMaskConfigOutputBuffer = new CharArrayWriter def append(m: DefMemory) : DefMemory = { - val validCombos = mutable.ArrayBuffer[SRAMConfig]() - defaultSearchOrdering foreach { r => - val config = r.getValidConfig(m) - if (config != None) validCombos += config.get - } + val validCombos = (defaultSearchOrdering map (_ getValidConfig m) + collect { case Some(config) => config }) // non empty if successfully found compiler option that supports depth/width // TODO: don't just take first option val usedConfig = { @@ -144,18 +144,19 @@ case class SRAMCompiler( else getBestAlternative(m) } val usesMaskGran = containsInfo(m.info, "maskGran") - if (configPattern != None) { - val newConfig = usedConfig.serialize(configPattern.get) + "\n" - val currentBuff = { - if (usesMaskGran) maskConfigOutputBuffer - else noMaskConfigOutputBuffer - } - if (!currentBuff.toString.contains(newConfig)) - currentBuff.append(newConfig) + configPattern match { + case None => + case Some(p) => + val newConfig = usedConfig.serialize(p) + "\n" + val currentBuff = { + if (usesMaskGran) maskConfigOutputBuffer + else noMaskConfigOutputBuffer + } + if (!currentBuff.toString.contains(newConfig)) currentBuff append newConfig } val temp = appendInfo(m.info, "sramConfig" -> usedConfig) - val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp - m.copy(info = newInfo) + val newInfo = if (usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp + m copy (info = newInfo) } // TODO: Should you really be splitting in 2 if, say, depth is 1 more than allowed? should be thresholded and @@ -164,47 +165,47 @@ case class SRAMCompiler( private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = { val validXRange = mutable.ArrayBuffer[SRAMRules]() val validYRange = mutable.ArrayBuffer[SRAMRules]() - defaultSearchOrdering foreach { r => + defaultSearchOrdering foreach { r => if (m.width <= r.getValidWidths.max) validXRange += r if (m.depth <= r.getValidDepths.max) validYRange += r } - - if (validXRange.isEmpty && validYRange.isEmpty) - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2)) - else if (validXRange.isEmpty && validYRange.nonEmpty) - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) - else if (validXRange.nonEmpty && validYRange.isEmpty) - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) - else if (validXRange.intersect(validYRange).nonEmpty) - Seq(m) - else - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++ + (validXRange.isEmpty, validYRange.isEmpty) match { + case (true, true) => + getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2)) + case (true, false) => + getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) + case (false, true) => + getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) + case (false, false) if validXRange.intersect(validYRange).nonEmpty => + Seq(m) + case (false, false) => + getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) + } } private def getBestAlternative(m: DefMemory): SRAMConfig = { val validConfigs = getInRange(SRAMConfig(width = bitWidth(m.dataType).intValue, depth = m.depth)) - val minNum = validConfigs.map(x => x.num).min + val minNum = validConfigs.map(_.num).min val validMinConfigs = validConfigs.filter(_.num == minNum) - val validMinConfigsSquareness = validMinConfigs.map(x => math.abs(x.width.toDouble/x.depth - 1) -> x).toMap - val squarestAspectRatio = validMinConfigsSquareness.map { case (aspectRatioDiff, _) => aspectRatioDiff } min + val validMinConfigsSquareness = validMinConfigs.map( + x => math.abs(x.width.toDouble / x.depth - 1) -> x).toMap + val squarestAspectRatio = validMinConfigsSquareness.unzip._1.min val validConfig = validMinConfigsSquareness(squarestAspectRatio) - val validRules = mutable.ArrayBuffer[SRAMRules]() - defaultSearchOrdering foreach { r => - if (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) validRules += r - } + val validRules = defaultSearchOrdering filter (r => + (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max)) // TODO: don't just take first option // TODO: More optimal split if particular value is in range but not supported // TODO: Support up to 2 read ports, 2 write ports; should be power of 2? val bestRule = validRules.head val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get - bestRule.getValidConfig(width = memWidth, depth = memDepth).get.copy(xsplit = validConfig.xsplit, ysplit = validConfig.ysplit) + (bestRule.getValidConfig(width = memWidth, depth = memDepth).get + copy (xsplit = validConfig.xsplit, ysplit = validConfig.ysplit)) } // TODO def serialize = ??? - } // TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true? @@ -212,11 +213,11 @@ case class SRAMCompiler( class YamlFileReader(file: String) { import CustomYAMLProtocol._ def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { - if (new java.io.File(file).exists) { + if (new File(file).exists) { val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n") - yamlString.parseYamls.flatMap(x => - try Some(reader.read(x)) - catch { case e: Exception => None } + yamlString.parseYamls flatMap (x => + try Some(reader read x) + catch { case e: Exception => None } ) } else error("Yaml file doesn't exist!") @@ -225,22 +226,20 @@ class YamlFileReader(file: String) { class YamlFileWriter(file: String) { import CustomYAMLProtocol._ - val outputBuffer = new java.io.CharArrayWriter + val outputBuffer = new CharArrayWriter val separator = "--- \n" - def append(in: YamlValue) = { - outputBuffer.append(separator + in.prettyPrint) + def append(in: YamlValue) { + outputBuffer append s"$separator${in.prettyPrint}" } - def dump = { - val outputFile = new java.io.PrintWriter(file) - outputFile.write(outputBuffer.toString) - outputFile.close() + def dump { + val outputFile = new PrintWriter(file) + outputFile write outputBuffer.toString + outputFile.close } } class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { - import CustomYAMLProtocol._ - def name = "Annotate memories with valid split depths, widths, #\'s" // TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist @@ -249,43 +248,42 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { sp: Option[SRAMCompiler] = None, dp: Option[SRAMCompiler] = None) { def serialize = { - if (sp != None) sp.get.serialize - if (dp != None) dp.get.serialize + sp match { + case None => + case Some(p) => p.serialize + } + dp match { + case None => + case Some(p) => p.serialize + } } } - val sramCompilers = { - if (reader == None) None - else { - val compilers = reader.get.parse[SRAMCompiler] - val sp = compilers.find(_.portType == "RW") - val dp = compilers.find(_.portType == "R,W") + + val sramCompilers = reader match { + case None => None + case Some(r) => + val compilers = r.parse[SRAMCompiler] + val sp = compilers find (_.portType == "RW") + val dp = compilers find (_.portType == "R,W") Some(SRAMCompilerSet(sp = sp, dp = dp)) - } } - def run(c: Circuit) = { - def annotateModMems(m: Module) = { - def updateStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => { - if (sramCompilers == None) m - else { - if (m.readwriters.length == 1) - if (sramCompilers.get.sp == None) error("Design needs RW port memory compiler!") - else sramCompilers.get.sp.get.append(m) - else - if (sramCompilers.get.dp == None) error("Design needs R,W port memory compiler!") - else sramCompilers.get.dp.get.append(m) - } - } - case b: Block => b map updateStmts - case s => s - } - m.copy(body=updateStmts(m.body)) - } - val updatedMods = c.modules map { - case m: Module => annotateModMems(m) - case m: ExtModule => m + def updateStmts(s: Statement): Statement = s match { + case m: DefMemory if containsInfo(m.info, "useMacro") => sramCompilers match { + case None => m + case Some(compiler) if (m.readwriters.length == 1) => + compiler.sp match { + case None => error("Design needs RW port memory compiler!") + case Some(p) => p append m + } + case Some(compiler) => + compiler.dp match { + case None => error("Design needs R,W port memory compiler!") + case Some(p) => p append m + } } - c.copy(modules = updatedMods) - } + case s => s map updateStmts + } + + def run(c: Circuit) = c copy (modules = (c.modules map (_ map updateStmts))) } diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala index 9fbd6ab3..ec996fdb 100644 --- a/src/main/scala/firrtl/passes/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/InferReadWrite.scala @@ -27,13 +27,14 @@ MODIFICATIONS. package firrtl.passes -import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap} -import com.typesafe.scalalogging.LazyLogging - import firrtl._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps._ +import firrtl.Utils.{one, zero, BoolType} +import MemPortUtils.memPortField +import AnalysisUtils.{Connects, getConnects} +import WrappedExpression.weq import Annotations._ case class InferReadWriteAnnotation(t: String, tID: TransID) @@ -50,153 +51,133 @@ case class InferReadWriteAnnotation(t: String, tID: TransID) object InferReadWritePass extends Pass { def name = "Infer ReadWrite Ports" - def inferReadWrite(m: Module) = { - import WrappedExpression.we - val connects = HashMap[String, Expression]() - val repl = HashMap[String, Expression]() - val stmts = ArrayBuffer[Statement]() - val zero = we(UIntLiteral(0, IntWidth(1))) - val one = we(UIntLiteral(1, IntWidth(1))) + type Netlist = collection.mutable.HashMap[String, Expression] + type Statements = collection.mutable.ArrayBuffer[Statement] + type PortSet = collection.mutable.HashSet[String] + + private implicit def toString(e: Expression) = e.serialize - // find all wire connections - def analyze(s: Statement): Unit = s match { - case s: Connect => - connects(s.loc.serialize) = s.expr - case s: PartialConnect => - connects(s.loc.serialize) = s.expr - case s: DefNode => - connects(s.name) = s.value - case s: Block => - s.stmts foreach analyze - case _ => + def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match { + // No ConstProp yet... + case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) => + getProductTerms(connects)(cond) + // Visit each term of AND operation + case DoPrim(op, args, consts, tpe) if op == And => + e +: (args flatMap getProductTerms(connects)) + // Visit connected nodes to references + case _: WRef | _: WSubField | _: WSubIndex => connects get e match { + case None => Seq(e) + case Some(ex) => e +: getProductTerms(connects)(ex) } + // Otherwise just return itself + case _ => Seq(e) + } - def getProductTermsFromExp(e: Expression): Seq[Expression] = - e match { - // No ConstProp yet... - case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero => - cond +: getProductTerms(cond.serialize) - // Visit each term of AND operation - case DoPrim(op, args, consts, tpe) if op == And => - e +: (args flatMap getProductTermsFromExp) - // Visit connected nodes to references - case _: WRef | _: SubField | _: SubIndex | _: SubAccess => - e +: getProductTerms(e.serialize) - // Otherwise just return itselt - case _ => - List(e) - } + def checkComplement(a: Expression, b: Expression) = (a, b) match { + // b ?= Not(a) + case (_, DoPrim(Not, args, _, _)) => weq(args.head, a) + // a ?= Not(b) + case (DoPrim(Not, args, _, _), _) => weq(args.head, b) + // b ?= Eq(a, 0) or b ?= Eq(0, a) + case (_, DoPrim(Eq, args, _, _)) => + weq(args(0), a) && weq(args(1), zero) || + weq(args(1), a) && weq(args(0), zero) + // a ?= Eq(b, 0) or b ?= Eq(0, a) + case (DoPrim(Eq, args, _, _), _) => + weq(args(0), b) && weq(args(1), zero) || + weq(args(1), b) && weq(args(0), zero) + case _ => false + } - def getProductTerms(node: String): Seq[Expression] = - if (connects contains node) getProductTermsFromExp(connects(node)) else Nil - def checkComplement(a: Expression, b: Expression) = (a, b) match { - // b ?= Not(a) - case (_, DoPrim(op, args, _, _)) if op == Not => - args.head.serialize == a.serialize - // a ?= Not(b) - case (DoPrim(op, args, _, _), _) if op == Not => - args.head.serialize == b.serialize - // b ?= Eq(a, 0) or b ?= Eq(0, a) - case (_, DoPrim(op, args, _, _)) if op == Eq => - args(0).serialize == a.serialize && we(args(1)) == zero || - args(1).serialize == a.serialize && we(args(0)) == zero - // a ?= Eq(b, 0) or b ?= Eq(0, a) - case (DoPrim(op, args, _, _), _) if op == Eq => - args(0).serialize == b.serialize && we(args(1)) == zero || - args(1).serialize == b.serialize && we(args(0)) == zero - case _ => false + def replaceExp(repl: Netlist)(e: Expression): Expression = + e map replaceExp(repl) match { + case e: WSubField => repl getOrElse (e.serialize, e) + case e => e } - def inferReadWrite(s: Statement): Statement = s map inferReadWrite match { - // infer readwrite ports only for non combinational memories - case mem: DefMemory if mem.readLatency > 0 => - val bt = UIntType(IntWidth(1)) - val ut = UnknownType - val ug = UNKNOWNGENDER - val readers = HashSet[String]() - val writers = HashSet[String]() - val readwriters = ArrayBuffer[String]() - for (w <- mem.writers ; r <- mem.readers) { - val wp = getProductTerms(s"${mem.name}.$w.en") - val rp = getProductTerms(s"${mem.name}.$r.en") - if (wp exists (a => rp exists (b => checkComplement(a, b)))) { - val allPorts = (mem.readers ++ mem.writers ++ mem.readwriters ++ readwriters).toSet - // Uniquify names by examining all ports of the memory - var rw = (for { - idx <- Stream from 0 - newName = s"rw_$idx" - if !allPorts(newName) - } yield newName).head - val rw_exp = WSubField(WRef(mem.name, ut, MemKind, ug), rw, ut, ug) - readwriters += rw - readers += r - writers += w - repl(s"${mem.name}.$r.en") = EmptyExpression - repl(s"${mem.name}.$r.clk") = EmptyExpression - repl(s"${mem.name}.$r.addr") = EmptyExpression - repl(s"${mem.name}.$r.data") = WSubField(rw_exp, "rdata", mem.dataType, MALE) - repl(s"${mem.name}.$w.en") = WSubField(rw_exp, "wmode", bt, FEMALE) - repl(s"${mem.name}.$w.clk") = EmptyExpression - repl(s"${mem.name}.$w.addr") = EmptyExpression - repl(s"${mem.name}.$w.data") = WSubField(rw_exp, "wdata", mem.dataType, FEMALE) - repl(s"${mem.name}.$w.mask") = WSubField(rw_exp, "wmask", ut, FEMALE) - stmts += Connect(NoInfo, WSubField(rw_exp, "clk", ClockType, FEMALE), - WRef("clk", ClockType, NodeKind, MALE)) - stmts += Connect(NoInfo, WSubField(rw_exp, "en", bt, FEMALE), - DoPrim(Or, List(connects(s"${mem.name}.$r.en"), connects(s"${mem.name}.$w.en")), Nil, bt)) - stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE), - Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"), - connects(s"${mem.name}.$r.addr"), ut)) - } - } - if (readwriters.isEmpty) mem else DefMemory(mem.info, - mem.name, mem.dataType, mem.depth, mem.writeLatency, mem.readLatency, - mem.readers filterNot readers, mem.writers filterNot writers, - mem.readwriters ++ readwriters) + def replaceStmt(repl: Netlist)(s: Statement): Statement = + s map replaceStmt(repl) map replaceExp(repl) match { + case Connect(_, EmptyExpression, _) => EmptyStmt case s => s } - - def replaceExp(e: Expression): Expression = - e map replaceExp match { - case e: WSubField => repl getOrElse (e.serialize, e) - case e => e + + def inferReadWriteStmt(connects: Connects, + repl: Netlist, + stmts: Statements) + (s: Statement): Statement = s match { + // infer readwrite ports only for non combinational memories + case mem: DefMemory if mem.readLatency > 0 => + val ut = UnknownType + val ug = UNKNOWNGENDER + val readers = new PortSet + val writers = new PortSet + val readwriters = collection.mutable.ArrayBuffer[String]() + val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) + for (w <- mem.writers ; r <- mem.readers) { + val wp = getProductTerms(connects)(memPortField(mem, w, "en")) + val rp = getProductTerms(connects)(memPortField(mem, r, "en")) + if (wp exists (a => rp exists (b => checkComplement(a, b)))) { + val rw = namespace newName "rw" + val rwExp = createSubField(createRef(mem.name), rw) + readwriters += rw + readers += r + writers += w + repl(memPortField(mem, r, "clk")) = EmptyExpression + repl(memPortField(mem, r, "en")) = EmptyExpression + repl(memPortField(mem, r, "addr")) = EmptyExpression + repl(memPortField(mem, r, "data")) = createSubField(rwExp, "rdata") + repl(memPortField(mem, w, "clk")) = EmptyExpression + repl(memPortField(mem, w, "en")) = createSubField(rwExp, "wmode") + repl(memPortField(mem, w, "addr")) = EmptyExpression + repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata") + repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask") + stmts += Connect(NoInfo, createSubField(rwExp, "clk"), createRef("clk")) // TODO: fix it + stmts += Connect(NoInfo, createSubField(rwExp, "en"), + DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), + connects(memPortField(mem, w, "en"))), Nil, BoolType)) + stmts += Connect(NoInfo, createSubField(rwExp, "addr"), + Mux(connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), UnknownType)) + } } + if (readwriters.isEmpty) mem else mem copy ( + readers = mem.readers filterNot readers, + writers = mem.writers filterNot writers, + readwriters = mem.readwriters ++ readwriters) + case s => s map inferReadWriteStmt(connects, repl, stmts) + } - def replaceStmt(s: Statement): Statement = - s map replaceStmt map replaceExp match { - case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt - case s => s - } - - analyze(m.body) - Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq)) + def inferReadWrite(m: DefModule) = { + val connects = getConnects(m) + val repl = new Netlist + val stmts = new Statements + (m map inferReadWriteStmt(connects, repl, stmts) + map replaceStmt(repl)) match { + case m: ExtModule => m + case m: Module => m copy (body = Block(m.body +: stmts)) + } } - def run (c:Circuit) = Circuit(c.info, c.modules map { - case m: Module => inferReadWrite(m) - case m: ExtModule => m - }, c.main) + def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) } // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite(transID: TransID) extends Transform with LazyLogging { - def execute(circuit:Circuit, map: AnnotationMap) = - map get transID match { - case Some(p) => p get CircuitName(circuit.main) match { - case Some(InferReadWriteAnnotation(_, _)) => TransformResult((Seq( - InferReadWritePass, - CheckInitialization, - ResolveKinds, - InferTypes, - ResolveGenders) foldLeft circuit){ (c, pass) => - val x = Utils.time(pass.name)(pass run c) - logger debug x.serialize - x - }, None, Some(map)) - case _ => TransformResult(circuit, None, Some(map)) - } - case _ => TransformResult(circuit, None, Some(map)) +class InferReadWrite(transID: TransID) extends Transform with SimpleRun { + def passSeq = Seq( + InferReadWritePass, + CheckInitialization, + InferTypes, + ResolveKinds, + ResolveGenders + ) + def execute(c: Circuit, map: AnnotationMap) = map get transID match { + case Some(p) => p get CircuitName(c.main) match { + case Some(InferReadWriteAnnotation(_, _)) => run(c, passSeq) + case _ => error("Unexpected annotation for InferReadWrite") } + case _ => TransformResult(c) + } } diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala index d2557f8d..1091db5f 100644 --- a/src/main/scala/firrtl/passes/MemUtils.scala +++ b/src/main/scala/firrtl/passes/MemUtils.scala @@ -145,11 +145,27 @@ object createMask { } } -object MemPortUtils { +object createRef { + def apply(n: String, t: Type = UnknownType, k: Kind = ExpKind) = WRef(n, t, k, UNKNOWNGENDER) +} + +object createSubField { + def apply(exp: Expression, n: String) = WSubField(exp, n, field_type(exp.tpe, n), UNKNOWNGENDER) +} - import AnalysisUtils._ +object connectFields { + def apply(lref: Expression, lname: String, rref: Expression, rname: String) = + Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) +} - def flattenType(t: Type) = UIntType(IntWidth(bitWidth(t))) +object flattenType { + def apply(t: Type) = UIntType(IntWidth(bitWidth(t))) +} + +object MemPortUtils { + type MemPortMap = collection.mutable.HashMap[String, Expression] + type Memories = collection.mutable.ArrayBuffer[DefMemory] + type Modules = collection.mutable.ArrayBuffer[DefModule] def defaultPortSeq(mem: DefMemory) = Seq( Field("addr", Default, UIntType(IntWidth(ceilLog2(mem.depth) max 1))), @@ -157,64 +173,10 @@ object MemPortUtils { Field("clk", Default, ClockType) ) - def getFillWMask(mem: DefMemory) = - getInfo(mem.info, "maskGran") match { - case None => false - case Some(maskGran) => maskGran == 1 - } - - def rPortToBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) - def rPortToFlattenBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) - - def wPortToBundle(mem: DefMemory) = BundleType( - (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ - (if (!containsInfo(mem.info, "maskGran")) Nil - else Seq(Field("mask", Default, createMask(mem.dataType)))) - ) - def wPortToFlattenBundle(mem: DefMemory) = BundleType( - (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ - (if (!containsInfo(mem.info, "maskGran")) Nil - else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType))) - else Seq(Field("mask", Default, flattenType(createMask(mem.dataType))))) - ) - // TODO: Don't use createMask??? - - def rwPortToBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) ++ Seq( - Field("wmode", Default, BoolType), - Field("wdata", Default, mem.dataType), - Field("rdata", Flip, mem.dataType) - ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil - else Seq(Field("wmask", Default, createMask(mem.dataType))) - ) - ) - - def rwPortToFlattenBundle(mem: DefMemory) = BundleType( - defaultPortSeq(mem) ++ Seq( - Field("wmode", Default, BoolType), - Field("wdata", Default, flattenType(mem.dataType)), - Field("rdata", Flip, flattenType(mem.dataType)) - ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil - else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType))) - else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType)))) - ) - ) - - def memToBundle(s: DefMemory) = BundleType( - s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) - - def memToFlattenBundle(s: DefMemory) = BundleType( - s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) - // Todo: merge it with memToBundle def memType(mem: DefMemory) = { - val rType = rPortToBundle(mem) + val rType = BundleType(defaultPortSeq(mem) :+ + Field("data", Flip, mem.dataType)) val wType = BundleType(defaultPortSeq(mem) ++ Seq( Field("data", Default, mem.dataType), Field("mask", Default, createMask(mem.dataType)))) diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/ReplSeqMem.scala index ae842a0b..c2c1b303 100644 --- a/src/main/scala/firrtl/passes/ReplSeqMem.scala +++ b/src/main/scala/firrtl/passes/ReplSeqMem.scala @@ -2,12 +2,12 @@ package firrtl.passes -import com.typesafe.scalalogging.LazyLogging import firrtl._ import firrtl.ir._ import Annotations._ -import java.io.Writer import AnalysisUtils._ +import Utils.error +import java.io.{File, CharArrayWriter, PrintWriter} sealed trait PassOption case object InputConfigFileName extends PassOption @@ -15,11 +15,9 @@ case object OutputConfigFileName extends PassOption case object PassCircuitName extends PassOption object PassConfigUtil { - + type PassOptionMap = Map[PassOption, String] + def getPassOptions(t: String, usage: String = "") = { - - type PassOptionMap = Map[PassOption, String] - // can't use space to delimit sub arguments (otherwise, Driver.scala will throw error) val passArgList = t.split(":").toList @@ -38,25 +36,24 @@ object PassConfigUtil { } nextPassOption(Map[PassOption, String](), passArgList) } - } class ConfWriter(filename: String) { - val outputBuffer = new java.io.CharArrayWriter + val outputBuffer = new CharArrayWriter def append(m: DefMemory) = { // legacy val maskGran = getInfo(m.info, "maskGran") - val writers = m.writers map (x => if (maskGran == None) "write" else "mwrite") val readers = List.fill(m.readers.length)("read") - val readwriters = m.readwriters map (x => if (maskGran == None) "rw" else "mrw") - val ports = (writers ++ readers ++ readwriters).mkString(",") - val maskGranConf = if (maskGran == None) "" else s"mask_gran ${maskGran.get}" + val writers = List.fill(m.writers.length)(if (maskGran == None) "write" else "mwrite") + val readwriters = List.fill(m.readwriters.length)(if (maskGran == None) "rw" else "mrw") + val ports = (writers ++ readers ++ readwriters) mkString "," + val maskGranConf = maskGran match { case None => "" case Some(p) => s"mask_gran $p" } val width = bitWidth(m.dataType) val conf = s"name ${m.name} depth ${m.depth} width ${width} ports ${ports} ${maskGranConf} \n" outputBuffer.append(conf) } def serialize = { - val outputFile = new java.io.PrintWriter(filename) + val outputFile = new PrintWriter(filename) outputFile.write(outputBuffer.toString) outputFile.close() } @@ -91,50 +88,35 @@ Optional Arguments: error("No circuit name specified for ReplSeqMem!" + usage) ) val target = CircuitName(passCircuit) - def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit, "-c:"+n.name)) - + def duplicate(n: Named) = this copy (t = (t replace (s"-c:$passCircuit", s"-c:${n.name}"))) } -class ReplSeqMem(transID: TransID) extends Transform with LazyLogging { - def execute(circuit:Circuit, map: AnnotationMap) = - map get transID match { - case Some(p) => p get CircuitName(circuit.main) match { - case Some(ReplSeqMemAnnotation(t, _)) => { - - val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") - val inConfigFile = { - if (inputFileName.isEmpty) None - else if (new java.io.File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else error("Input configuration file does not exist!") - } - - val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t).get(OutputConfigFileName).get) - TransformResult( - ( - Seq( - Legalize, - AnnotateMemMacros, - UpdateDuplicateMemMacros, - new AnnotateValidMemConfigs(inConfigFile), - new ReplaceMemMacros(outConfigFile), - RemoveEmpty, - CheckInitialization, - ResolveKinds, // Must be run for the transform to work! - InferTypes, - ResolveGenders - ) foldLeft circuit - ) { - (c, pass) => - val x = Utils.time(pass.name)(pass run c) - logger debug x.serialize - x - } , - None, - Some(map) - ) - } - case _ => error("Unexpected transform annotation") - } - case _ => TransformResult(circuit, None, Some(map)) +class ReplSeqMem(transID: TransID) extends Transform with SimpleRun { + def passSeq(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter) = + Seq(Legalize, + AnnotateMemMacros, + UpdateDuplicateMemMacros, + new AnnotateValidMemConfigs(inConfigFile), + new ReplaceMemMacros(outConfigFile), + RemoveEmpty, + CheckInitialization, + InferTypes, + ResolveKinds, // Must be run for the transform to work! + ResolveGenders) + + def execute(c: Circuit, map: AnnotationMap) = map get transID match { + case Some(p) => p get CircuitName(c.main) match { + case Some(ReplSeqMemAnnotation(t, _)) => + val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") + val inConfigFile = { + if (inputFileName.isEmpty) None + else if (new File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) + else error("Input configuration file does not exist!") + } + val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t)(OutputConfigFileName)) + run(c, passSeq(inConfigFile, outConfigFile)) + case _ => error("Unexpected transform annotation") } -}
\ No newline at end of file + case _ => TransformResult(c) + } +} diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala index 7bb9c6c4..33a371a0 100644 --- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala @@ -2,91 +2,35 @@ package firrtl.passes -import scala.collection.mutable -import firrtl.ir._ -import AnalysisUtils._ -import MemTransformUtils._ import firrtl._ +import firrtl.ir._ import firrtl.Utils._ -import MemPortUtils._ import firrtl.Mappers._ +import MemPortUtils._ +import MemTransformUtils._ +import AnalysisUtils._ class ReplaceMemMacros(writer: ConfWriter) extends Pass { - - def name = "Replace memories with black box wrappers (optimizes when write mask isn't needed) + configuration file" - - def run(c: Circuit) = { - - lazy val moduleNamespace = Namespace(c) - val memMods = mutable.ArrayBuffer[DefModule]() - val uniqueMems = mutable.ArrayBuffer[DefMemory]() - - def updateMemMods(m: Module) = { - val memPortMap = mutable.HashMap[String, Expression]() - - def updateMemStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => - if(!containsInfo(m.info, "maskGran")) { - m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression } - m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression } - } - val infoT = getInfo(m.info, "info") - val info = if (infoT == None) NoInfo else infoT.get match { case i: Info => i } - val ref = getInfo(m.info, "ref") - - // prototype mem - if (ref == None) { - val newWrapperName = moduleNamespace.newName(m.name) - val newMemBBName = moduleNamespace.newName(m.name + "_ext") - val newMem = m.copy(name = newMemBBName) - memMods ++= createMemModule(newMem, newWrapperName) - uniqueMems += newMem - WDefInstance(info, m.name, newWrapperName, UnknownType) - } - else { - val r = ref.get match { case s: String => s } - WDefInstance(info, m.name, r, UnknownType) - } - case b: Block => b map updateMemStmts - case s => s - } - - val updatedMems = updateMemStmts(m.body) - val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap) - m.copy(body = updatedConns) - } - - val updatedMods = c.modules map { - case m: Module => updateMemMods(m) - case m: ExtModule => m - } - - // print conf - writer.serialize - c.copy(modules = updatedMods ++ memMods.toSeq) - } + def name = "Replace memories with black box wrappers" + + " (optimizes when write mask isn't needed) + configuration file" // from Albert def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = { assert(m.dataType != UnknownType) - val stmts = mutable.ArrayBuffer[Statement]() - val wrapperioPorts = MemPortUtils.memToBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) - val bbProto = m.copy(dataType = flattenType(m.dataType)) - val bbioPorts = MemPortUtils.memToFlattenBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) - - stmts += WDefInstance(NoInfo, m.name, m.name, UnknownType) - val bbRef = createRef(m.name) - stmts ++= (m.readers zip bbProto.readers).flatMap { - case (x, y) => adaptReader(createRef(x), m, createSubField(bbRef, y), bbProto) - } - stmts ++= (m.writers zip bbProto.writers).flatMap { - case (x, y) => adaptWriter(createRef(x), m, createSubField(bbRef, y), bbProto) - } - stmts ++= (m.readwriters zip bbProto.readwriters).flatMap { - case (x, y) => adaptReadWriter(createRef(x), m, createSubField(bbRef, y), bbProto) - } - val wrapper = Module(NoInfo, wrapperName, wrapperioPorts, Block(stmts)) - val bb = ExtModule(NoInfo, m.name, bbioPorts) + val wrapperIoType = memToBundle(m) + val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbIoType = memToFlattenBundle(m) + val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbRef = createRef(m.name, bbIoType) + val hasMask = containsInfo(m.info, "maskGran") + val fillMask = getFillWMask(m) + def portRef(p: String) = createRef(p, field_type(wrapperIoType, p)) + val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ + (m.readers flatMap (r => adaptReader(portRef(r), createSubField(bbRef, r)))) ++ + (m.writers flatMap (w => adaptWriter(portRef(w), createSubField(bbRef, w), hasMask, fillMask))) ++ + (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), createSubField(bbRef, rw), hasMask, fillMask))) + val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts)) + val bb = ExtModule(NoInfo, m.name, bbIoPorts) // TODO: Annotate? -- use actual annotation map // add to conf file @@ -95,75 +39,86 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { } // TODO: get rid of copy pasta - def adaptReader(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = Seq( - connectFields(bbPort, "addr", wrapperPort, "addr"), - connectFields(bbPort, "en", wrapperPort, "en"), - connectFields(bbPort, "clk", wrapperPort, "clk"), - fromBits( - WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER), - WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER) - ) - ) - - def adaptWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = { - val defaultSeq = Seq( - connectFields(bbPort, "addr", wrapperPort, "addr"), - connectFields(bbPort, "en", wrapperPort, "en"), - connectFields(bbPort, "clk", wrapperPort, "clk"), - Connect( - NoInfo, - WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER), - toBits(WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER)) - ) - ) - if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = createMask(wrapperMem.dataType) - val fillWMask = getFillWMask(wrapperMem) - val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) - val rhs = { - if (fillWMask) toBitMask(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType) - else toBits(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER)) - } - defaultSeq :+ Connect( + def defaultConnects(wrapperPort: WRef, bbPort: WSubField) = + Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) + + def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean) = + if (fillMask) toBitMask(mask, dataType) else toBits(mask) + + def adaptReader(wrapperPort: WRef, bbPort: WSubField) = + defaultConnects(wrapperPort, bbPort) :+ + fromBits(createSubField(wrapperPort, "data"), createSubField(bbPort, "data")) + + def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = { + val wrapperData = createSubField(wrapperPort, "data") + val defaultSeq = defaultConnects(wrapperPort, bbPort) :+ + Connect(NoInfo, createSubField(bbPort, "data"), toBits(wrapperData)) + hasMask match { + case false => defaultSeq + case true => defaultSeq :+ Connect( NoInfo, - WSubField(bbPort, "mask", bbMask, UNKNOWNGENDER), - rhs + createSubField(bbPort, "mask"), + maskBits(createSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) ) } - else defaultSeq } - def adaptReadWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = { - val defaultSeq = Seq( - connectFields(bbPort, "addr", wrapperPort, "addr"), - connectFields(bbPort, "en", wrapperPort, "en"), - connectFields(bbPort, "clk", wrapperPort, "clk"), - connectFields(bbPort, "wmode", wrapperPort, "wmode"), - Connect( - NoInfo, - WSubField(bbPort, "wdata", bbMem.dataType, UNKNOWNGENDER), - toBits(WSubField(wrapperPort, "wdata", wrapperMem.dataType, UNKNOWNGENDER)) - ), - fromBits( - WSubField(wrapperPort, "rdata", wrapperMem.dataType, UNKNOWNGENDER), - WSubField(bbPort, "rdata", bbMem.dataType, UNKNOWNGENDER) - ) - ) - if (containsInfo(wrapperMem.info, "maskGran")) { - val wrapperMask = createMask(wrapperMem.dataType) - val fillWMask = getFillWMask(wrapperMem) - val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) - val rhs = { - if (fillWMask) toBitMask(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType) - else toBits(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER)) - } - defaultSeq :+ Connect( + def adaptReadWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean) = { + val wrapperWData = createSubField(wrapperPort, "wdata") + val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( + fromBits(createSubField(wrapperPort, "rdata"), createSubField(bbPort, "rdata")), + connectFields(bbPort, "wmode", wrapperPort, "wmode"), + Connect(NoInfo, createSubField(bbPort, "wdata"), toBits(wrapperWData))) + hasMask match { + case false => defaultSeq + case true => defaultSeq :+ Connect( NoInfo, - WSubField(bbPort, "wmask", bbMask, UNKNOWNGENDER), - rhs + createSubField(bbPort, "wmask"), + maskBits(createSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) ) } - else defaultSeq } + def updateMemStmts(namespace: Namespace, + memPortMap: MemPortMap, + memMods: Modules) + (s: Statement): Statement = s match { + case m: DefMemory if containsInfo(m.info, "useMacro") => + if (!containsInfo(m.info, "maskGran")) { + m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression } + m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression } + } + val info = getInfo(m.info, "info") match { + case None => NoInfo + case Some(p: Info) => p + } + getInfo(m.info, "ref") match { + case None => + // prototype mem + val newWrapperName = namespace newName m.name + val newMemBBName = namespace newName s"${m.name}_ext" + val newMem = m copy (name = newMemBBName) + memMods ++= createMemModule(newMem, newWrapperName) + WDefInstance(info, m.name, newWrapperName, UnknownType) + case Some(ref: String) => + WDefInstance(info, m.name, ref, UnknownType) + } + case s => s map updateMemStmts(namespace, memPortMap, memMods) + } + + def updateMemMods(namespace: Namespace, memMods: Modules)(m: DefModule) = { + val memPortMap = new MemPortMap + + (m map updateMemStmts(namespace, memPortMap, memMods) + map updateStmtRefs(memPortMap)) + } + + def run(c: Circuit) = { + val namespace = Namespace(c) + val memMods = new Modules + val modules = c.modules map updateMemMods(namespace, memMods) + // print conf + writer.serialize + c copy (modules = modules ++ memMods) + } } diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala index 0098fa5f..fbff9bd6 100644 --- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala +++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala @@ -2,23 +2,73 @@ package firrtl.passes -import scala.collection.mutable -import AnalysisUtils._ -import MemTransformUtils._ -import firrtl.ir._ import firrtl._ -import firrtl.Mappers._ +import firrtl.ir._ import firrtl.Utils._ +import firrtl.Mappers._ +import AnalysisUtils._ +import MemPortUtils._ +import MemTransformUtils._ object MemTransformUtils { + def getFillWMask(mem: DefMemory) = + getInfo(mem.info, "maskGran") match { + case None => false + case Some(maskGran) => maskGran == 1 + } + + def rPortToBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + def rPortToFlattenBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + + def wPortToBundle(mem: DefMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)) ++ + (if (!containsInfo(mem.info, "maskGran")) Nil + else Seq(Field("mask", Default, createMask(mem.dataType)))) + ) + def wPortToFlattenBundle(mem: DefMemory) = BundleType( + (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ + (if (!containsInfo(mem.info, "maskGran")) Nil + else if (getFillWMask(mem)) Seq(Field("mask", Default, flattenType(mem.dataType))) + else Seq(Field("mask", Default, flattenType(createMask(mem.dataType))))) + ) + // TODO: Don't use createMask??? + + def rwPortToBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), + Field("wdata", Default, mem.dataType), + Field("rdata", Flip, mem.dataType) + ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil + else Seq(Field("wmask", Default, createMask(mem.dataType))) + ) + ) + + def rwPortToFlattenBundle(mem: DefMemory) = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("wmode", Default, BoolType), + Field("wdata", Default, flattenType(mem.dataType)), + Field("rdata", Flip, flattenType(mem.dataType)) + ) ++ (if (!containsInfo(mem.info, "maskGran")) Nil + else if (getFillWMask(mem)) Seq(Field("wmask", Default, flattenType(mem.dataType))) + else Seq(Field("wmask", Default, flattenType(createMask(mem.dataType)))) + ) + ) + + def memToBundle(s: DefMemory) = BundleType( + s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) + + def memToFlattenBundle(s: DefMemory) = BundleType( + s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ + s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) - def createRef(n: String) = WRef(n, UnknownType, ExpKind, UNKNOWNGENDER) - def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER) - def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) = - Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) def getMemPortMap(m: DefMemory) = { - val memPortMap = mutable.HashMap[String, Expression]() + val memPortMap = new MemPortMap val defaultFields = Seq("addr", "en", "clk") val rFields = defaultFields :+ "data" val wFields = rFields :+ "mask" @@ -33,35 +83,41 @@ object MemTransformUtils { updateMemPortMap(m.readers, rFields, "R") updateMemPortMap(m.writers, wFields, "W") updateMemPortMap(m.readwriters, rwFields, "RW") - memPortMap.toMap + memPortMap } + def createMemProto(m: DefMemory) = { val rports = (0 until m.readers.length) map (i => s"R$i") val wports = (0 until m.writers.length) map (i => s"W$i") val rwports = (0 until m.readwriters.length) map (i => s"RW$i") - m.copy(readers = rports, writers = wports, readwriters = rwports) + m copy (readers = rports, writers = wports, readwriters = rwports) } - def updateStmtRefs(s: Statement, repl: Map[String, Expression]): Statement = { - def updateRef(e: Expression): Expression = e map updateRef match { - case e => repl getOrElse (e.serialize, e) + def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { + def updateRef(e: Expression): Expression = { + val ex = e map updateRef + repl getOrElse (ex.serialize, ex) } + def hasEmptyExpr(stmt: Statement): Boolean = { var foundEmpty = false def testEmptyExpr(e: Expression): Expression = { - e map testEmptyExpr match { + e match { case EmptyExpression => foundEmpty = true case _ => } - e // map must return; no foreach + e map testEmptyExpr // map must return; no foreach } stmt map testEmptyExpr foundEmpty } - def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match { - case c: Connect if hasEmptyExpr(c) => EmptyStmt - case s => s - } + + def updateStmtRefs(s: Statement): Statement = + s map updateStmtRefs map updateRef match { + case c: Connect if hasEmptyExpr(c) => EmptyStmt + case s => s + } + updateStmtRefs(s) } @@ -71,37 +127,29 @@ object UpdateDuplicateMemMacros extends Pass { def name = "Convert memory port names to be more meaningful and tag duplicate memories" - def run(c: Circuit) = { - val uniqueMems = mutable.ArrayBuffer[DefMemory]() - - def updateMemMods(m: Module) = { - val memPortMap = mutable.HashMap[String, Expression]() - - def updateMemStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info, "useMacro") => - val updatedMem = createMemProto(m) - memPortMap ++= getMemPortMap(m) - val proto = uniqueMems find (x => eqMems(x, updatedMem)) - if (proto == None) { - uniqueMems += updatedMem - updatedMem - } - else updatedMem.copy(info = appendInfo(updatedMem.info, "ref" -> proto.get.name)) - case b: Block => b map updateMemStmts - case s => s + def updateMemStmts(uniqueMems: Memories, + memPortMap: MemPortMap) + (s: Statement): Statement = s match { + case m: DefMemory if containsInfo(m.info, "useMacro") => + val updatedMem = createMemProto(m) + memPortMap ++= getMemPortMap(m) + uniqueMems find (x => eqMems(x, updatedMem)) match { + case None => + uniqueMems += updatedMem + updatedMem + case Some(proto) => + updatedMem copy (info = appendInfo(updatedMem.info, "ref" -> proto.name)) } + case s => s map updateMemStmts(uniqueMems, memPortMap) + } - val updatedMems = updateMemStmts(m.body) - val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap) - m.copy(body = updatedConns) - } - - val updatedMods = c.modules map { - case m: Module => updateMemMods(m) - case m: ExtModule => m - } - c.copy(modules = updatedMods) - } + def updateMemMods(m: DefModule) = { + val uniqueMems = new Memories + val memPortMap = new MemPortMap + (m map updateMemStmts(uniqueMems, memPortMap) + map updateStmtRefs(memPortMap)) + } + def run(c: Circuit) = c copy (modules = (c.modules map updateMemMods)) } // TODO: Module namespace? diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index 7e3383b2..3af018bd 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -38,7 +38,7 @@ class InferReadWriteSpec extends SimpleTransformSpec { val name = "Check Infer ReadWrite Ports" def findReadWrite(s: Statement): Boolean = s match { case s: DefMemory if s.readLatency > 0 && s.readwriters.size == 1 => - s.name == "mem" && s.readwriters.head == "rw_0" + s.name == "mem" && s.readwriters.head == "rw" case s: Block => s.stmts exists findReadWrite case _ => false diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 7219b1ce..8aeafc9e 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -107,7 +107,7 @@ circuit Top : val circuit = InferTypes.run(ToWorkingIR.run(parse(input))) val m = circuit.modules.head.asInstanceOf[ir.Module] val connects = AnalysisUtils.getConnects(m) - val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects,"f").serialize + val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects)("f").serialize require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !") } |
