diff options
| author | Angie | 2016-08-30 00:27:30 -0700 |
|---|---|---|
| committer | jackkoenig | 2016-09-06 00:17:18 -0700 |
| commit | 6a05468ed0ece1ace3019666b16f2ae83ef76ef9 (patch) | |
| tree | 5d4e4244c61845334184a45f4df960c2d7ccb313 /src | |
| parent | a82f30d90940fd3c0386dee6f1ef21850c3c91c9 (diff) | |
Address style feedback and add tests for getConnectOrigin utility
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Namespace.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateMemMacros.scala | 115 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala | 138 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/MemUtils.scala | 56 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplSeqMem.scala | 34 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ReplaceMemMacros.scala | 113 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala | 54 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ReplSeqMemTests.scala | 47 |
8 files changed, 289 insertions, 270 deletions
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index 93e0ec76..e7a1cd10 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -85,7 +85,7 @@ object Namespace { namespace } - /** Initializes a [[Namespace]] for [[Module]] names in a [[Circuit]] */ + /** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */ def apply(c: Circuit): Namespace = { val namespace = new Namespace c.modules foreach { m => diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala index a59ca4af..af58c7c5 100644 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala @@ -1,15 +1,18 @@ +// See LICENSE for license details. + package firrtl.passes -import scala.collection.mutable.{ArrayBuffer,HashMap} +import scala.collection.mutable import AnalysisUtils._ import firrtl.WrappedExpression._ import firrtl.ir._ import firrtl._ import firrtl.Utils._ +import firrtl.Mappers._ -case class AppendableInfo(fields: Map[String,Any]) extends Info { - def append(a: Map[String,Any]) = this.copy(fields = fields ++ a) - def append(a: (String,Any)): AppendableInfo = append(Map(a)) +case class AppendableInfo(fields: Map[String, Any]) extends Info { + def append(a: Map[String, Any]) = this.copy(fields = fields ++ a) + def append(a: (String, Any)): AppendableInfo = append(Map(a)) def get(f: String) = fields.get(f) override def equals(b: Any) = b match { case i: AppendableInfo => fields - "info" == i.fields - "info" @@ -20,77 +23,64 @@ case class AppendableInfo(fields: Map[String,Any]) extends Info { object AnalysisUtils { def getConnects(m: Module) = { - val connects = HashMap[String, Expression]() - def getConnects(s: Statement): Unit = s match { - case s: Connect => - connects(s.loc.serialize) = s.expr - case s: PartialConnect => - connects(s.loc.serialize) = s.expr - case s: DefNode => - connects(s.name) = s.value - case s: Block => - s.stmts foreach getConnects - case _ => + val connects = mutable.HashMap[String, Expression]() + def getConnects(s: Statement): Statement = { + s map getConnects match { + case Connect(_, loc, expr) => + connects(loc.serialize) = expr + case DefNode(_, name, value) => + connects(name) = value + case _ => // do nothing + } + s // return because we only have map and not foreach } getConnects(m.body) connects.toMap } - // only works in a module + // takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS. + // if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation, + // the function will try to search backwards past the primop/mux. + // use case: compare if two nodes have the same origin + // limitation: only works in a module (stops @ module inputs) + // TODO: more thorough (i.e. a + 0 = a) def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = { - if (connects contains node) getConnectOrigin(connects,connects(node)) + if (connects contains node) getOrigin(connects, connects(node)) else EmptyExpression } - def checkLit(e: Expression) = e match { - case l : Literal => true - case _ => false - } - - def getOrigin(connects: Map[String, Expression], e: Expression) = e match { - case DoPrim(_,_,_,_) => getConnectOrigin(connects,e) - case l if (checkLit(l)) => e - case _ => getConnectOrigin(connects,e.serialize) - } - - // backward searches until PrimOp, Lit or non-trivial Mux appears - // technically, you should keep searching through PrimOp, because a node + 0 is still itself, - // a node shifted by 0 is still itself, etc. - // TODO: handle validif???, more thorough - private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { + private def getOrigin(connects: Map[String, Expression], e: Expression): Expression = e match { case Mux(cond, tv, fv, _) => - val fvOrigin = getOrigin(connects,fv) - val tvOrigin = getOrigin(connects,tv) - val condOrigin = getOrigin(connects,cond) + val fvOrigin = getOrigin(connects, fv) + val tvOrigin = getOrigin(connects, tv) + val condOrigin = getOrigin(connects, cond) if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin else if (we(condOrigin) == we(one)) tvOrigin else if (we(condOrigin) == we(zero)) fvOrigin else if (we(tvOrigin) == we(fvOrigin)) tvOrigin else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin else e - case DoPrim(op, args, consts, tpe) if op == PrimOps.Or && args.contains(one) => one - case DoPrim(op, args, consts, tpe) if op == PrimOps.And && args.contains(zero) => zero - case DoPrim(op, args, consts, tpe) if op == PrimOps.Bits => - val msb = consts(0) - val lsb = consts(1) + case DoPrim(PrimOps.Or, args, consts, tpe) if args.contains(one) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args.contains(zero) => zero + case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => val extractionWidth = (msb-lsb)+1 val nodeWidth = bitWidth(args.head.tpe) // if you're extracting the full bitwidth, then keep searching for origin - if (nodeWidth == extractionWidth) getOrigin(connects,args.head) + if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e - case DoPrim(op, args, _, _) if (op == PrimOps.AsUInt || op == PrimOps.AsSInt || op == PrimOps.AsClock) => - getOrigin(connects,args.head) + case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => + getOrigin(connects, args.head) case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects contains e.serialize => - getConnectOrigin(connects,e.serialize) + getConnectOrigin(connects, e.serialize) case _ => e } - def appendInfo[T <: Info](info: T, add: Map[String,Any]) = info match { + def appendInfo[T <: Info](info: T, add: Map[String, Any]) = info match { case i: AppendableInfo => i.append(add) case _ => AppendableInfo(fields = add + ("info" -> info)) } - def appendInfo[T <: Info](info: T, add: (String,Any)): AppendableInfo = appendInfo(info,Map(add)) - def getInfo[T <: Info](info: T, k: String) = info match{ + def appendInfo[T <: Info](info: T, add: (String, Any)): AppendableInfo = appendInfo(info, Map(add)) + def getInfo[T <: Info](info: T, k: String) = info match { case i: AppendableInfo => i.get(k) case _ => None } @@ -99,17 +89,8 @@ object AnalysisUtils { case _ => false } - def eqMems(a: DefMemory, b: DefMemory) = { - a.info == b.info && - a.dataType == b.dataType && - a.depth == b.depth && - a.writeLatency == b.writeLatency && - a.readLatency == b.readLatency && - a.readers == b.readers && - a.writers == b.writers && - a.readwriters == b.readwriters && - a.readUnderWrite == b.readUnderWrite - } + // memories equivalent as long as all fields (except name) are the same + def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name) } @@ -124,9 +105,9 @@ object AnnotateMemMacros extends Pass { // returns # of mask bits if used def getMaskBits(wen: String, wmask: String): Option[Int] = { - val wenOrigin = we(getConnectOrigin(connects,wen)) + val wenOrigin = we(getConnectOrigin(connects, wen)) val one1 = we(one) - val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects,x))) + val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects, x))) // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) val redundantMask = wmaskOrigin.map( x => (x == wenOrigin) || (x == one1) ).foldLeft(true)(_ && _) if (redundantMask) None else Some(wmaskOrigin.length) @@ -134,18 +115,18 @@ object AnnotateMemMacros extends Pass { def updateStmts(s: Statement): Statement = s match { // only annotate memories that are candidates for memory macro replacements - // i.e. rw, w + r (read,write 1 cycle delay) + // i.e. rw, w + r (read, write 1 cycle delay) case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 && (m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 => val dataBits = bitWidth(m.dataType) - val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode",s"${m.name}.$w.wmask")) - val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en",s"${m.name}.$w.mask")) + val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode", s"${m.name}.$w.wmask")) + val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en", s"${m.name}.$w.mask")) val maskBits = (rwMasks ++ wMasks).head val memAnnotations = Map("useMacro" -> true) - val tempInfo = appendInfo(m.info,memAnnotations) + val tempInfo = appendInfo(m.info, memAnnotations) if (maskBits == None) m.copy(info = tempInfo) else m.copy(info = tempInfo.append("maskGran" -> dataBits/maskBits.get)) - case b: Block => Block(b.stmts map updateStmts) + case b: Block => b map updateStmts case s => s } m.copy(body=updateStmts(m.body)) @@ -161,4 +142,4 @@ object AnnotateMemMacros extends Pass { } -// TODO: Add floorplan info?
\ No newline at end of file +// TODO: Add floorplan info? diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala index a7f7703b..816a179b 100644 --- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala +++ b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala @@ -1,3 +1,5 @@ +// See LICENSE for license details. + package firrtl.passes import firrtl.ir._ @@ -5,7 +7,8 @@ import firrtl._ import net.jcazevedo.moultingyaml._ import net.jcazevedo.moultingyaml.DefaultYamlProtocol._ import AnalysisUtils._ -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable +import firrtl.Mappers._ object CustomYAMLProtocol extends DefaultYamlProtocol { // bottom depends on top @@ -17,13 +20,12 @@ object CustomYAMLProtocol extends DefaultYamlProtocol { } case class DimensionRules( - min: Int, - // step size - inc: Int, - max: Int, - // these values should not be used, regardless of min,inc,max - illegal: Option[List[Int]] -){ + min: Int, + // step size + inc: Int, + max: Int, + // these values should not be used, regardless of min,inc,max + illegal: Option[List[Int]]) { def getValid = { val range = (min to max by inc).toList range.filterNot(illegal.getOrElse(List[Int]()).toSet) @@ -31,9 +33,8 @@ case class DimensionRules( } case class MemDimension( - rules: Option[DimensionRules], - set: Option[List[Int]] -){ + rules: Option[DimensionRules], + set: Option[List[Int]]) { require ( if(rules == None) set != None else set == None, "Should specify either rules or a list of valid options, but not both" @@ -42,44 +43,42 @@ case class MemDimension( } case class SRAMConfig( - ymux: String = "", - ybank: String = "", - width: Int, - depth: Int, - xsplit: Int = 1, - ysplit: Int = 1 -){ + ymux: String = "", + ybank: String = "", + width: Int, + depth: Int, + xsplit: Int = 1, + ysplit: Int = 1) { // how many duplicate copies of this SRAM are needed def num = xsplit * ysplit def serialize(pattern: String): String = { - val fieldMap = getClass.getDeclaredFields.map{f => + val fieldMap = getClass.getDeclaredFields.map { f => f.setAccessible(true) f.getName -> f.get(this) - }.toMap + } toMap val fieldDelimiter = """\[.*?\]""".r val configOptions = fieldDelimiter.findAllIn(pattern).toList - configOptions.foldLeft(pattern)((b,a) => { + configOptions.foldLeft(pattern)((b, a) => { // Expects the contents of [] are valid configuration fields (otherwise key match error) val fieldVal = { - try fieldMap(a.substring(1,a.length-1)) - catch { case e: Exception => Error("**SRAM config field incorrect**") } + try fieldMap(a.substring(1, a.length-1)) + catch { case e: Exception => error("**SRAM config field incorrect**") } } - b.replace(a,fieldVal.toString) - }) + b.replace(a, fieldVal.toString) + } ) } } // Ex: https://www.ece.cmu.edu/~ece548/hw/hw5/meml80.pdf case class SRAMRules( - // column mux parameter (for adjusting aspect ratio) - ymux: (Int,String), - // vertical segmentation (banking -- tradeoff performance / area) - ybank: (Int,String), - width: MemDimension, - depth: MemDimension -){ + // column mux parameter (for adjusting aspect ratio) + ymux: (Int, String), + // vertical segmentation (banking -- tradeoff performance / area) + ybank: (Int, String), + width: MemDimension, + depth: MemDimension) { def getValidWidths = width.getValid def getValidDepths = depth.getValid def getValidConfig(width: Int, depth: Int): Option[SRAMConfig] = { @@ -88,13 +87,12 @@ case class SRAMRules( else None } - def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue,m.depth) + def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue, m.depth) } case class WMaskArg( - t: String, - f: String -) + t: String, + f: String) // vendor-specific compilers case class SRAMCompiler( @@ -116,8 +114,7 @@ case class SRAMCompiler( defaultArgs: Option[String], // default behavior (if not used) is to have wmask port width = datawidth/maskgran // if true: wmask port width pre-filled to datawidth - fillWMask: Boolean -){ + fillWMask: Boolean) { require(portType == "RW" || portType == "R,W", "Memory must be single port RW or dual port R,W") require( (configFile != None && configPattern != None && wMaskArg != None) || configFile == None, @@ -135,7 +132,7 @@ case class SRAMCompiler( private val noMaskConfigOutputBuffer = new java.io.CharArrayWriter def append(m: DefMemory) : DefMemory = { - val validCombos = ArrayBuffer[SRAMConfig]() + val validCombos = mutable.ArrayBuffer[SRAMConfig]() defaultSearchOrdering foreach { r => val config = r.getValidConfig(m) if (config != None) validCombos += config.get @@ -146,7 +143,7 @@ case class SRAMCompiler( if (validCombos.nonEmpty) validCombos.head else getBestAlternative(m) } - val usesMaskGran = containsInfo(m.info,"maskGran") + val usesMaskGran = containsInfo(m.info, "maskGran") if (configPattern != None) { val newConfig = usedConfig.serialize(configPattern.get) + "\n" val currentBuff = { @@ -156,8 +153,8 @@ case class SRAMCompiler( if (!currentBuff.toString.contains(newConfig)) currentBuff.append(newConfig) } - val temp = appendInfo(m.info,"sramConfig" -> usedConfig) - val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp,"maskGran" -> 1) else temp + val temp = appendInfo(m.info, "sramConfig" -> usedConfig) + val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp m.copy(info = newInfo) } @@ -165,24 +162,24 @@ case class SRAMCompiler( // handled w/ a separate set of registers ? // split memory until width, depth achievable via given memory compiler private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = { - val validXRange = ArrayBuffer[SRAMRules]() - val validYRange = ArrayBuffer[SRAMRules]() + val validXRange = mutable.ArrayBuffer[SRAMRules]() + val validYRange = mutable.ArrayBuffer[SRAMRules]() defaultSearchOrdering foreach { r => if (m.width <= r.getValidWidths.max) validXRange += r if (m.depth <= r.getValidDepths.max) validYRange += r } if (validXRange.isEmpty && validYRange.isEmpty) - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2,depth = m.depth/2)) + getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2)) else if (validXRange.isEmpty && validYRange.nonEmpty) - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2,depth = m.depth)) + getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) else if (validXRange.nonEmpty && validYRange.isEmpty) - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width,depth = m.depth/2)) + getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) else if (validXRange.intersect(validYRange).nonEmpty) Seq(m) else - getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width,depth = m.depth/2)) ++ - getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2,depth = m.depth)) + getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++ + getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth)) } private def getBestAlternative(m: DefMemory): SRAMConfig = { @@ -190,39 +187,39 @@ case class SRAMCompiler( val minNum = validConfigs.map(x => x.num).min val validMinConfigs = validConfigs.filter(_.num == minNum) val validMinConfigsSquareness = validMinConfigs.map(x => math.abs(x.width.toDouble/x.depth - 1) -> x).toMap - val squarestAspectRatio = validMinConfigsSquareness.map(x => x._1).min + val squarestAspectRatio = validMinConfigsSquareness.map { case (aspectRatioDiff, _) => aspectRatioDiff } min val validConfig = validMinConfigsSquareness(squarestAspectRatio) - val validRules = ArrayBuffer[SRAMRules]() + val validRules = mutable.ArrayBuffer[SRAMRules]() defaultSearchOrdering foreach { r => if (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) validRules += r } // TODO: don't just take first option + // TODO: More optimal split if particular value is in range but not supported + // TODO: Support up to 2 read ports, 2 write ports; should be power of 2? val bestRule = validRules.head val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get bestRule.getValidConfig(width = memWidth, depth = memDepth).get.copy(xsplit = validConfig.xsplit, ysplit = validConfig.ysplit) } - def serialize() = { - // TODO - } + // TODO + def serialize = ??? } // TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true? // Or is this where module-level transforms (rather than circuit-level) make sense? -class YamlFileReader(file: String){ +class YamlFileReader(file: String) { import CustomYAMLProtocol._ def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { if (new java.io.File(file).exists) { val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n") - val optionOut = yamlString.parseYamls.map(x => + yamlString.parseYamls.flatMap(x => try Some(reader.read(x)) - catch {case e: Exception => None} + catch { case e: Exception => None } ) - optionOut.filter(_ != None).map(_.get) } - else Error("Yaml file doesn't exist!") + else error("Yaml file doesn't exist!") } } @@ -233,7 +230,7 @@ class YamlFileWriter(file: String) { def append(in: YamlValue) = { outputBuffer.append(separator + in.prettyPrint) } - def serialize = { + def dump = { val outputFile = new java.io.PrintWriter(file) outputFile.write(outputBuffer.toString) outputFile.close() @@ -249,12 +246,11 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { // TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist // TODO: Don't get first available? case class SRAMCompilerSet( - sp: Option[SRAMCompiler] = None, - dp: Option[SRAMCompiler] = None - ){ - def serialize() = { - if (sp != None) sp.get.serialize() - if (dp != None) dp.get.serialize() + sp: Option[SRAMCompiler] = None, + dp: Option[SRAMCompiler] = None) { + def serialize = { + if (sp != None) sp.get.serialize + if (dp != None) dp.get.serialize } } val sramCompilers = { @@ -270,18 +266,18 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { def run(c: Circuit) = { def annotateModMems(m: Module) = { def updateStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info,"useMacro") => { + case m: DefMemory if containsInfo(m.info, "useMacro") => { if (sramCompilers == None) m else { if (m.readwriters.length == 1) - if (sramCompilers.get.sp == None) Error("Design needs RW port memory compiler!") + if (sramCompilers.get.sp == None) error("Design needs RW port memory compiler!") else sramCompilers.get.sp.get.append(m) else - if (sramCompilers.get.dp == None) Error("Design needs R,W port memory compiler!") + if (sramCompilers.get.dp == None) error("Design needs R,W port memory compiler!") else sramCompilers.get.dp.get.append(m) } } - case b: Block => Block(b.stmts map updateStmts) + case b: Block => b map updateStmts case s => s } m.copy(body=updateStmts(m.body)) @@ -292,4 +288,4 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass { } c.copy(modules = updatedMods) } -}
\ No newline at end of file +} diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala index 97915194..b235213a 100644 --- a/src/main/scala/firrtl/passes/MemUtils.scala +++ b/src/main/scala/firrtl/passes/MemUtils.scala @@ -41,22 +41,22 @@ object seqCat { case 2 => DoPrim(PrimOps.Cat, args, Seq.empty[BigInt], UIntType(UnknownWidth)) case _ => { val seqs = args.splitAt(args.length/2) - DoPrim(PrimOps.Cat, Seq(seqCat(seqs._1),seqCat(seqs._2)), Seq.empty[BigInt], UIntType(UnknownWidth)) + DoPrim(PrimOps.Cat, Seq(seqCat(seqs._1), seqCat(seqs._2)), Seq.empty[BigInt], UIntType(UnknownWidth)) } } } object toBits { def apply(e: Expression): Expression = e match { - case ex: WRef => hiercat(ex,ex.tpe) - case ex: WSubField => hiercat(ex,ex.tpe) - case ex: WSubIndex => hiercat(ex,ex.tpe) + case ex: WRef => hiercat(ex, ex.tpe) + case ex: WSubField => hiercat(ex, ex.tpe) + case ex: WSubIndex => hiercat(ex, ex.tpe) case t => error("Invalid operand expression for toBits!") } def hiercat(e: Expression, dt: Type): Expression = dt match { - case t:VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER),t.tpe))) - case t:BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe))) - case t:GroundType => e + case t: VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER), t.tpe))) + case t: BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe))) + case t: GroundType => e case t => error("Unknown type encountered in toBits!") } } @@ -64,16 +64,16 @@ object toBits { // TODO: make easier to understand object toBitMask { def apply(e: Expression, dataType: Type): Expression = e match { - case ex: WRef => hiermask(ex,ex.tpe,dataType) - case ex: WSubField => hiermask(ex,ex.tpe,dataType) - case ex: WSubIndex => hiermask(ex,ex.tpe,dataType) + case ex: WRef => hiermask(ex, ex.tpe, dataType) + case ex: WSubField => hiermask(ex, ex.tpe, dataType) + case ex: WSubIndex => hiermask(ex, ex.tpe, dataType) case t => error("Invalid operand expression for toBits!") } def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = (maskType, dataType) match { - case (mt:VectorType, dt:VectorType) => seqCat((0 until mt.size).reverse.map(i => hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe))) - case (mt:BundleType, dt:BundleType) => seqCat((mt.fields zip dt.fields).map{ case (mf,df) => - hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) }) - case (mt:UIntType, dt:GroundType) => seqCat(List.fill(bitWidth(dt).intValue)(e)) + case (mt: VectorType, dt: VectorType) => seqCat((0 until mt.size).reverse.map(i => hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe))) + case (mt: BundleType, dt: BundleType) => seqCat((mt.fields zip dt.fields).map { case (mf, df) => + hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) } ) + case (mt: UIntType, dt: GroundType) => seqCat(List.fill(bitWidth(dt).intValue)(e)) case (mt, dt) => error("Invalid type for mask component!") } } @@ -81,8 +81,8 @@ object toBitMask { object bitWidth { def apply(dt: Type): BigInt = widthOf(dt) def widthOf(dt: Type): BigInt = dt match { - case t:VectorType => t.size * bitWidth(t.tpe) - case t:BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) + case t: VectorType => t.size * bitWidth(t.tpe) + case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_) case UIntType(IntWidth(width)) => width case SIntType(IntWidth(width)) => width case t => error("Unknown type encountered in bitWidth!") @@ -101,12 +101,12 @@ object fromBits { } def getPartGround(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = { val intWidth = bitWidth(lhst) - val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset+intWidth-1,offset), UnknownType) - (offset + intWidth, Seq(Connect(NoInfo,lhs,sel))) + val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset+intWidth-1, offset), UnknownType) + (offset + intWidth, Seq(Connect(NoInfo, lhs, sel))) } def getPart(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = { lhst match { - case t:VectorType => { + case t: VectorType => { var currentOffset = offset var stmts = Seq.empty[Statement] for (i <- (0 until t.size)) { @@ -116,7 +116,7 @@ object fromBits { } (currentOffset, stmts) } - case t:BundleType => { + case t: BundleType => { var currentOffset = offset var stmts = Seq.empty[Statement] for (f <- t.fields.reverse) { @@ -126,7 +126,7 @@ object fromBits { } (currentOffset, stmts) } - case t:GroundType => getPartGround(lhs, t, rhs, offset) + case t: GroundType => getPartGround(lhs, t, rhs, offset) case t => error("Unknown type encountered in fromBits!") } } @@ -145,7 +145,7 @@ object MemPortUtils { ) def getFillWMask(mem: DefMemory) = { - val maskGran = getInfo(mem.info,"maskGran") + val maskGran = getInfo(mem.info, "maskGran") if (maskGran == None) false else maskGran.get == 1 } @@ -156,7 +156,7 @@ object MemPortUtils { def wPortToBundle(mem: DefMemory) = { val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, mem.dataType) BundleType( - if (containsInfo(mem.info,"maskGran")) defaultSeq :+ Field("mask", Default, create_mask(mem.dataType)) + if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("mask", Default, create_mask(mem.dataType)) else defaultSeq ) } @@ -164,7 +164,7 @@ object MemPortUtils { def wPortToFlattenBundle(mem: DefMemory) = { val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType)) BundleType( - if (containsInfo(mem.info,"maskGran")) { + if (containsInfo(mem.info, "maskGran")) { defaultSeq :+ { if (getFillWMask(mem)) Field("mask", Default, flattenType(mem.dataType)) else Field("mask", Default, flattenType(create_mask(mem.dataType))) @@ -175,26 +175,26 @@ object MemPortUtils { } // TODO: Don't use create_mask??? - def rwPortToBundle(mem: DefMemory) ={ + def rwPortToBundle(mem: DefMemory) = { val defaultSeq = defaultPortSeq(mem) ++ Seq( Field("wmode", Default, UIntType(IntWidth(1))), Field("wdata", Default, mem.dataType), Field("rdata", Flip, mem.dataType) ) BundleType( - if (containsInfo(mem.info,"maskGran")) defaultSeq :+ Field("wmask", Default, create_mask(mem.dataType)) + if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("wmask", Default, create_mask(mem.dataType)) else defaultSeq ) } - def rwPortToFlattenBundle(mem: DefMemory) ={ + def rwPortToFlattenBundle(mem: DefMemory) = { val defaultSeq = defaultPortSeq(mem) ++ Seq( Field("wmode", Default, UIntType(IntWidth(1))), Field("wdata", Default, flattenType(mem.dataType)), Field("rdata", Flip, flattenType(mem.dataType)) ) BundleType( - if (containsInfo(mem.info,"maskGran")) { + if (containsInfo(mem.info, "maskGran")) { defaultSeq :+ { if (getFillWMask(mem)) Field("wmask", Default, flattenType(mem.dataType)) else Field("wmask", Default, flattenType(create_mask(mem.dataType))) diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/ReplSeqMem.scala index a1b83efc..ae842a0b 100644 --- a/src/main/scala/firrtl/passes/ReplSeqMem.scala +++ b/src/main/scala/firrtl/passes/ReplSeqMem.scala @@ -1,3 +1,5 @@ +// See LICENSE for license details. + package firrtl.passes import com.typesafe.scalalogging.LazyLogging @@ -12,16 +14,6 @@ case object InputConfigFileName extends PassOption case object OutputConfigFileName extends PassOption case object PassCircuitName extends PassOption -object Error { - def apply(msg: String) = throw new Exception(msg) -} -object Warn { - def apply[T <: Any](msg: String, r: T) = { - println(Console.RED + msg + Console.RESET) - r - } -} - object PassConfigUtil { def getPassOptions(t: String, usage: String = "") = { @@ -41,7 +33,7 @@ object PassConfigUtil { case "-c" :: value :: tail => nextPassOption(map + (PassCircuitName -> value), tail) case option :: tail => - Error("Unknown option " + option + usage) + error("Unknown option " + option + usage) } } nextPassOption(Map[PassOption, String](), passArgList) @@ -53,7 +45,7 @@ class ConfWriter(filename: String) { val outputBuffer = new java.io.CharArrayWriter def append(m: DefMemory) = { // legacy - val maskGran = getInfo(m.info,"maskGran") + val maskGran = getInfo(m.info, "maskGran") val writers = m.writers map (x => if (maskGran == None) "write" else "mwrite") val readers = List.fill(m.readers.length)("read") val readwriters = m.readwriters map (x => if (maskGran == None) "rw" else "mrw") @@ -89,17 +81,17 @@ Optional Arguments: -i<filename> Specify the input configuration file (for additional optimizations) """ - val passOptions = PassConfigUtil.getPassOptions(t,usage) + val passOptions = PassConfigUtil.getPassOptions(t, usage) val outputConfig = passOptions.getOrElse( OutputConfigFileName, - Error("No output config file provided for ReplSeqMem!" + usage) + error("No output config file provided for ReplSeqMem!" + usage) ) val passCircuit = passOptions.getOrElse( PassCircuitName, - Error("No circuit name specified for ReplSeqMem!" + usage) + error("No circuit name specified for ReplSeqMem!" + usage) ) val target = CircuitName(passCircuit) - def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit,"-c:"+n.name)) + def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit, "-c:"+n.name)) } @@ -109,11 +101,11 @@ class ReplSeqMem(transID: TransID) extends Transform with LazyLogging { case Some(p) => p get CircuitName(circuit.main) match { case Some(ReplSeqMemAnnotation(t, _)) => { - val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName,"") + val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "") val inConfigFile = { if (inputFileName.isEmpty) None else if (new java.io.File(inputFileName).exists) Some(new YamlFileReader(inputFileName)) - else Error("Input configuration file does not exist!") + else error("Input configuration file does not exist!") } val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t).get(OutputConfigFileName).get) @@ -131,17 +123,17 @@ class ReplSeqMem(transID: TransID) extends Transform with LazyLogging { InferTypes, ResolveGenders ) foldLeft circuit - ){ + ) { (c, pass) => val x = Utils.time(pass.name)(pass run c) logger debug x.serialize x - }, + } , None, Some(map) ) } - case _ => TransformResult(circuit, None, Some(map)) + case _ => error("Unexpected transform annotation") } case _ => TransformResult(circuit, None, Some(map)) } diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala index 94be10e7..54c522d7 100644 --- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala @@ -1,12 +1,15 @@ +// See LICENSE for license details. + package firrtl.passes -import scala.collection.mutable.{HashMap,ArrayBuffer} +import scala.collection.mutable import firrtl.ir._ import AnalysisUtils._ import MemTransformUtils._ import firrtl._ import firrtl.Utils._ import MemPortUtils._ +import firrtl.Mappers._ class ReplaceMemMacros(writer: ConfWriter) extends Pass { @@ -15,41 +18,41 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { def run(c: Circuit) = { lazy val moduleNamespace = Namespace(c) - val memMods = ArrayBuffer[DefModule]() - val uniqueMems = ArrayBuffer[DefMemory]() + val memMods = mutable.ArrayBuffer[DefModule]() + val uniqueMems = mutable.ArrayBuffer[DefMemory]() def updateMemMods(m: Module) = { - val memPortMap = HashMap[String,Expression]() + val memPortMap = mutable.HashMap[String, Expression]() def updateMemStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info,"useMacro") => - if(!containsInfo(m.info,"maskGran")) { - m.writers foreach {w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression} - m.readwriters foreach {w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression} + case m: DefMemory if containsInfo(m.info, "useMacro") => + if(!containsInfo(m.info, "maskGran")) { + m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression } + m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression } } - val infoT = getInfo(m.info,"info") - val info = if (infoT == None) NoInfo else infoT.get match {case i: Info => i} - val ref = getInfo(m.info,"ref") + val infoT = getInfo(m.info, "info") + val info = if (infoT == None) NoInfo else infoT.get match { case i: Info => i } + val ref = getInfo(m.info, "ref") // prototype mem if (ref == None) { val newWrapperName = moduleNamespace.newName(m.name) val newMemBBName = moduleNamespace.newName(m.name + "_ext") val newMem = m.copy(name = newMemBBName) - memMods ++= createMemModule(newMem,newWrapperName) + memMods ++= createMemModule(newMem, newWrapperName) uniqueMems += newMem WDefInstance(info, m.name, newWrapperName, UnknownType) } else { - val r = ref.get match {case s: String => s} + val r = ref.get match { case s: String => s } WDefInstance(info, m.name, r, UnknownType) } - case b: Block => Block(b.stmts map updateMemStmts) + case b: Block => b map updateMemStmts case s => s } val updatedMems = updateMemStmts(m.body) - val updatedConns = updateStmtRefs(updatedMems,memPortMap.toMap) + val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap) m.copy(body = updatedConns) } @@ -66,64 +69,64 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { // from Albert def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = { assert(m.dataType != UnknownType) - val stmts = ArrayBuffer[Statement]() + val stmts = mutable.ArrayBuffer[Statement]() val wrapperioPorts = MemPortUtils.memToBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) val bbProto = m.copy(dataType = flattenType(m.dataType)) val bbioPorts = MemPortUtils.memToFlattenBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) - stmts += WDefInstance(NoInfo,m.name,m.name,UnknownType) + stmts += WDefInstance(NoInfo, m.name, m.name, UnknownType) val bbRef = createRef(m.name) - stmts ++= (m.readers zip bbProto.readers).map{ - case (x,y) => adaptReader(createRef(x),m,createSubField(bbRef,y),bbProto) - }.flatten - stmts ++= (m.writers zip bbProto.writers).map{ - case (x,y) => adaptWriter(createRef(x),m,createSubField(bbRef,y),bbProto) - }.flatten - stmts ++= (m.readwriters zip bbProto.readwriters).map{ - case (x,y) => adaptReadWriter(createRef(x),m,createSubField(bbRef,y),bbProto) - }.flatten - val wrapper = Module(NoInfo,wrapperName,wrapperioPorts,Block(stmts)) - val bb = ExtModule(NoInfo,m.name,bbioPorts) + stmts ++= (m.readers zip bbProto.readers).flatMap { + case (x, y) => adaptReader(createRef(x), m, createSubField(bbRef, y), bbProto) + } + stmts ++= (m.writers zip bbProto.writers).flatMap { + case (x, y) => adaptWriter(createRef(x), m, createSubField(bbRef, y), bbProto) + } + stmts ++= (m.readwriters zip bbProto.readwriters).flatMap { + case (x, y) => adaptReadWriter(createRef(x), m, createSubField(bbRef, y), bbProto) + } + val wrapper = Module(NoInfo, wrapperName, wrapperioPorts, Block(stmts)) + val bb = ExtModule(NoInfo, m.name, bbioPorts) // TODO: Annotate? -- use actual annotation map // add to conf file writer.append(m) - Seq(bb,wrapper) + Seq(bb, wrapper) } // TODO: get rid of copy pasta def adaptReader(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = Seq( - connectFields(bbPort,"addr",wrapperPort,"addr"), - connectFields(bbPort,"en",wrapperPort,"en"), - connectFields(bbPort,"clk",wrapperPort,"clk"), + connectFields(bbPort, "addr", wrapperPort, "addr"), + connectFields(bbPort, "en", wrapperPort, "en"), + connectFields(bbPort, "clk", wrapperPort, "clk"), fromBits( - WSubField(wrapperPort,"data",wrapperMem.dataType,UNKNOWNGENDER), - WSubField(bbPort,"data",bbMem.dataType,UNKNOWNGENDER) + WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER), + WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER) ) ) def adaptWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = { val defaultSeq = Seq( - connectFields(bbPort,"addr",wrapperPort,"addr"), - connectFields(bbPort,"en",wrapperPort,"en"), - connectFields(bbPort,"clk",wrapperPort,"clk"), + connectFields(bbPort, "addr", wrapperPort, "addr"), + connectFields(bbPort, "en", wrapperPort, "en"), + connectFields(bbPort, "clk", wrapperPort, "clk"), Connect( NoInfo, - WSubField(bbPort,"data",bbMem.dataType,UNKNOWNGENDER), - toBits(WSubField(wrapperPort,"data",wrapperMem.dataType,UNKNOWNGENDER)) + WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER), + toBits(WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER)) ) ) - if (containsInfo(wrapperMem.info,"maskGran")) { + if (containsInfo(wrapperMem.info, "maskGran")) { val wrapperMask = create_mask(wrapperMem.dataType) val fillWMask = getFillWMask(wrapperMem) val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) val rhs = { - if (fillWMask) toBitMask(WSubField(wrapperPort,"mask",wrapperMask,UNKNOWNGENDER),wrapperMem.dataType) - else toBits(WSubField(wrapperPort,"mask",wrapperMask,UNKNOWNGENDER)) + if (fillWMask) toBitMask(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType) + else toBits(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER)) } defaultSeq :+ Connect( NoInfo, - WSubField(bbPort,"mask",bbMask,UNKNOWNGENDER), + WSubField(bbPort, "mask", bbMask, UNKNOWNGENDER), rhs ) } @@ -132,31 +135,31 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass { def adaptReadWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = { val defaultSeq = Seq( - connectFields(bbPort,"addr",wrapperPort,"addr"), - connectFields(bbPort,"en",wrapperPort,"en"), - connectFields(bbPort,"clk",wrapperPort,"clk"), - connectFields(bbPort,"wmode",wrapperPort,"wmode"), + connectFields(bbPort, "addr", wrapperPort, "addr"), + connectFields(bbPort, "en", wrapperPort, "en"), + connectFields(bbPort, "clk", wrapperPort, "clk"), + connectFields(bbPort, "wmode", wrapperPort, "wmode"), Connect( NoInfo, - WSubField(bbPort,"wdata",bbMem.dataType,UNKNOWNGENDER), - toBits(WSubField(wrapperPort,"wdata",wrapperMem.dataType,UNKNOWNGENDER)) + WSubField(bbPort, "wdata", bbMem.dataType, UNKNOWNGENDER), + toBits(WSubField(wrapperPort, "wdata", wrapperMem.dataType, UNKNOWNGENDER)) ), fromBits( - WSubField(wrapperPort,"rdata",wrapperMem.dataType,UNKNOWNGENDER), - WSubField(bbPort,"rdata",bbMem.dataType,UNKNOWNGENDER) + WSubField(wrapperPort, "rdata", wrapperMem.dataType, UNKNOWNGENDER), + WSubField(bbPort, "rdata", bbMem.dataType, UNKNOWNGENDER) ) ) - if (containsInfo(wrapperMem.info,"maskGran")) { + if (containsInfo(wrapperMem.info, "maskGran")) { val wrapperMask = create_mask(wrapperMem.dataType) val fillWMask = getFillWMask(wrapperMem) val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask) val rhs = { - if (fillWMask) toBitMask(WSubField(wrapperPort,"wmask",wrapperMask,UNKNOWNGENDER),wrapperMem.dataType) - else toBits(WSubField(wrapperPort,"wmask",wrapperMask,UNKNOWNGENDER)) + if (fillWMask) toBitMask(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType) + else toBits(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER)) } defaultSeq :+ Connect( NoInfo, - WSubField(bbPort,"wmask",bbMask,UNKNOWNGENDER), + WSubField(bbPort, "wmask", bbMask, UNKNOWNGENDER), rhs ) } diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala index a4a910fd..d71b8ab8 100644 --- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala +++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala @@ -1,6 +1,8 @@ +// See LICENSE for license details. + package firrtl.passes -import scala.collection.mutable.{HashMap,ArrayBuffer} +import scala.collection.mutable import AnalysisUtils._ import MemTransformUtils._ import firrtl.ir._ @@ -10,27 +12,27 @@ import firrtl.Utils._ object MemTransformUtils { - def createRef(n: String) = WRef(n,UnknownType,ExpKind(),UNKNOWNGENDER) - def createSubField(exp: Expression, n: String) = WSubField(exp,n,UnknownType,UNKNOWNGENDER) + def createRef(n: String) = WRef(n, UnknownType, ExpKind(), UNKNOWNGENDER) + def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER) def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) = - Connect(NoInfo,createSubField(lref,lname),createSubField(rref,rname)) + Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname)) def getMemPortMap(m: DefMemory) = { - val memPortMap = HashMap[String,Expression]() - val defaultFields = Seq("addr","en","clk") + val memPortMap = mutable.HashMap[String, Expression]() + val defaultFields = Seq("addr", "en", "clk") val rFields = defaultFields :+ "data" val wFields = rFields :+ "mask" - val rwFields = defaultFields ++ Seq("wmode","wdata","rdata","wmask") + val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask") def updateMemPortMap(ports: Seq[String], fields: Seq[String], portType: String) = - for (p <- ports.zipWithIndex; f <- fields) { - val newPort = createSubField(createRef(m.name),portType+p._2) - val field = createSubField(newPort,f) - memPortMap(s"${m.name}.${p._1}.${f}") = field + for ((p, i) <- ports.zipWithIndex; f <- fields) { + val newPort = createSubField(createRef(m.name), portType+i) + val field = createSubField(newPort, f) + memPortMap(s"${m.name}.${p}.${f}") = field } - updateMemPortMap(m.readers,rFields,"R") - updateMemPortMap(m.writers,wFields,"W") - updateMemPortMap(m.readwriters,rwFields,"RW") + updateMemPortMap(m.readers, rFields, "R") + updateMemPortMap(m.writers, wFields, "W") + updateMemPortMap(m.readwriters, rwFields, "RW") memPortMap.toMap } def createMemProto(m: DefMemory) = { @@ -40,14 +42,14 @@ object MemTransformUtils { m.copy(readers = rports, writers = wports, readwriters = rwports) } - def updateStmtRefs(s: Statement, repl: Map[String,Expression]): Statement = { + def updateStmtRefs(s: Statement, repl: Map[String, Expression]): Statement = { def updateRef(e: Expression): Expression = e map updateRef match { - case e: WSubField => repl getOrElse (e.serialize,e) + case e: WSubField => repl getOrElse (e.serialize, e) case e => e } def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match { - case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt - case Connect(info, WSubIndex(EmptyExpression,_,_,_), exp) => EmptyStmt + case Connect(info, EmptyExpression, exp) => EmptyStmt + case Connect(info, WSubIndex(EmptyExpression, _, _, _), exp) => EmptyStmt case s => s } updateStmtRefs(s) @@ -60,27 +62,27 @@ object UpdateDuplicateMemMacros extends Pass { def name = "Convert memory port names to be more meaningful and tag duplicate memories" def run(c: Circuit) = { - val uniqueMems = ArrayBuffer[DefMemory]() + val uniqueMems = mutable.ArrayBuffer[DefMemory]() def updateMemMods(m: Module) = { - val memPortMap = HashMap[String,Expression]() + val memPortMap = mutable.HashMap[String, Expression]() def updateMemStmts(s: Statement): Statement = s match { - case m: DefMemory if containsInfo(m.info,"useMacro") => + case m: DefMemory if containsInfo(m.info, "useMacro") => val updatedMem = createMemProto(m) memPortMap ++= getMemPortMap(m) - val proto = uniqueMems find (x => eqMems(x,updatedMem)) + val proto = uniqueMems find (x => eqMems(x, updatedMem)) if (proto == None) { uniqueMems += updatedMem updatedMem } - else updatedMem.copy(info = appendInfo(updatedMem.info,"ref" -> proto.get.name)) - case b: Block => Block(b.stmts map updateMemStmts) + else updatedMem.copy(info = appendInfo(updatedMem.info, "ref" -> proto.get.name)) + case b: Block => b map updateMemStmts case s => s } val updatedMems = updateMemStmts(m.body) - val updatedConns = updateStmtRefs(updatedMems,memPortMap.toMap) + val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap) m.copy(body = updatedConns) } @@ -92,4 +94,4 @@ object UpdateDuplicateMemMacros extends Pass { } } -// TODO: Module namespace?
\ No newline at end of file +// TODO: Module namespace? diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index 57a274c2..54ef6003 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -17,6 +17,51 @@ class ReplSeqMemSpec extends SimpleTransformSpec { new EmitFirrtl(writer) ) + "ReplSeqMem Utility -- getConnectOrigin" should + "determine connect origin across nodes/PrimOps even if ConstProp isn't performed" in { + def checkConnectOrigin(hurdle: String, origin: String) = { + val input = s""" +circuit Top : + module Top : + input a: UInt<1> + input b: UInt<1> + input e: UInt<1> + output c: UInt<1> + output f: UInt<1> + node d = $hurdle + c <= d + f <= c +""".stripMargin + + val circuit = InferTypes.run(ToWorkingIR.run(parse(input))) + val m = circuit.modules.head.asInstanceOf[ir.Module] + val connects = AnalysisUtils.getConnects(m) + val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects,"f").serialize + require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !") + } + + val tests = List( + """mux(a, UInt<1>("h1"), UInt<1>("h0"))""" -> "a", + """mux(UInt<1>("h1"), a, b)""" -> "a", + """mux(UInt<1>("h0"), a, b)""" -> "b", + "mux(b, a, a)" -> "a", + """mux(a, a, UInt<1>("h0"))""" -> "a", + "mux(a, b, e)" -> "mux(a, b, e)", + """or(a, UInt<1>("h1"))""" -> """UInt<1>("h1")""", + """and(a, UInt<1>("h0"))""" -> """UInt<1>("h0")""", + """UInt<1>("h1")""" -> """UInt<1>("h1")""", + "asUInt(a)" -> "a", + "asSInt(a)" -> "a", + "asClock(a)" -> "a", + "a" -> "a", + "or(a, b)" -> "or(a, b)", + "bits(a, 0, 0)" -> "a" + ) + + tests.foreach{ case(hurdle, origin) => checkConnectOrigin(hurdle, origin) } + + } + "ReplSeqMem" should "generate blackbox wrappers (no wmask, r, w ports)" in { val input = """ circuit sram6t : @@ -116,7 +161,7 @@ circuit sram6t : val writer = new java.io.StringWriter execute(writer, aMap, input, check) val confOut = read(confLoc) - require(confOut==checkConf,"Conf file incorrect!") + require(confOut==checkConf, "Conf file incorrect!") (new java.io.File(confLoc)).delete() } } |
