aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAngie2016-08-30 00:27:30 -0700
committerjackkoenig2016-09-06 00:17:18 -0700
commit6a05468ed0ece1ace3019666b16f2ae83ef76ef9 (patch)
tree5d4e4244c61845334184a45f4df960c2d7ccb313 /src
parenta82f30d90940fd3c0386dee6f1ef21850c3c91c9 (diff)
Address style feedback and add tests for getConnectOrigin utility
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Namespace.scala2
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala115
-rw-r--r--src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala138
-rw-r--r--src/main/scala/firrtl/passes/MemUtils.scala56
-rw-r--r--src/main/scala/firrtl/passes/ReplSeqMem.scala34
-rw-r--r--src/main/scala/firrtl/passes/ReplaceMemMacros.scala113
-rw-r--r--src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala54
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala47
8 files changed, 289 insertions, 270 deletions
diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala
index 93e0ec76..e7a1cd10 100644
--- a/src/main/scala/firrtl/Namespace.scala
+++ b/src/main/scala/firrtl/Namespace.scala
@@ -85,7 +85,7 @@ object Namespace {
namespace
}
- /** Initializes a [[Namespace]] for [[Module]] names in a [[Circuit]] */
+ /** Initializes a [[Namespace]] for [[ir.Module]] names in a [[ir.Circuit]] */
def apply(c: Circuit): Namespace = {
val namespace = new Namespace
c.modules foreach { m =>
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
index a59ca4af..af58c7c5 100644
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
@@ -1,15 +1,18 @@
+// See LICENSE for license details.
+
package firrtl.passes
-import scala.collection.mutable.{ArrayBuffer,HashMap}
+import scala.collection.mutable
import AnalysisUtils._
import firrtl.WrappedExpression._
import firrtl.ir._
import firrtl._
import firrtl.Utils._
+import firrtl.Mappers._
-case class AppendableInfo(fields: Map[String,Any]) extends Info {
- def append(a: Map[String,Any]) = this.copy(fields = fields ++ a)
- def append(a: (String,Any)): AppendableInfo = append(Map(a))
+case class AppendableInfo(fields: Map[String, Any]) extends Info {
+ def append(a: Map[String, Any]) = this.copy(fields = fields ++ a)
+ def append(a: (String, Any)): AppendableInfo = append(Map(a))
def get(f: String) = fields.get(f)
override def equals(b: Any) = b match {
case i: AppendableInfo => fields - "info" == i.fields - "info"
@@ -20,77 +23,64 @@ case class AppendableInfo(fields: Map[String,Any]) extends Info {
object AnalysisUtils {
def getConnects(m: Module) = {
- val connects = HashMap[String, Expression]()
- def getConnects(s: Statement): Unit = s match {
- case s: Connect =>
- connects(s.loc.serialize) = s.expr
- case s: PartialConnect =>
- connects(s.loc.serialize) = s.expr
- case s: DefNode =>
- connects(s.name) = s.value
- case s: Block =>
- s.stmts foreach getConnects
- case _ =>
+ val connects = mutable.HashMap[String, Expression]()
+ def getConnects(s: Statement): Statement = {
+ s map getConnects match {
+ case Connect(_, loc, expr) =>
+ connects(loc.serialize) = expr
+ case DefNode(_, name, value) =>
+ connects(name) = value
+ case _ => // do nothing
+ }
+ s // return because we only have map and not foreach
}
getConnects(m.body)
connects.toMap
}
- // only works in a module
+ // takes in a list of node-to-node connections in a given module and looks to find the origin of the LHS.
+ // if the source is a trivial primop/mux, etc. that has yet to be optimized via constant propagation,
+ // the function will try to search backwards past the primop/mux.
+ // use case: compare if two nodes have the same origin
+ // limitation: only works in a module (stops @ module inputs)
+ // TODO: more thorough (i.e. a + 0 = a)
def getConnectOrigin(connects: Map[String, Expression], node: String): Expression = {
- if (connects contains node) getConnectOrigin(connects,connects(node))
+ if (connects contains node) getOrigin(connects, connects(node))
else EmptyExpression
}
- def checkLit(e: Expression) = e match {
- case l : Literal => true
- case _ => false
- }
-
- def getOrigin(connects: Map[String, Expression], e: Expression) = e match {
- case DoPrim(_,_,_,_) => getConnectOrigin(connects,e)
- case l if (checkLit(l)) => e
- case _ => getConnectOrigin(connects,e.serialize)
- }
-
- // backward searches until PrimOp, Lit or non-trivial Mux appears
- // technically, you should keep searching through PrimOp, because a node + 0 is still itself,
- // a node shifted by 0 is still itself, etc.
- // TODO: handle validif???, more thorough
- private def getConnectOrigin(connects: Map[String, Expression], e: Expression): Expression = e match {
+ private def getOrigin(connects: Map[String, Expression], e: Expression): Expression = e match {
case Mux(cond, tv, fv, _) =>
- val fvOrigin = getOrigin(connects,fv)
- val tvOrigin = getOrigin(connects,tv)
- val condOrigin = getOrigin(connects,cond)
+ val fvOrigin = getOrigin(connects, fv)
+ val tvOrigin = getOrigin(connects, tv)
+ val condOrigin = getOrigin(connects, cond)
if (we(tvOrigin) == we(one) && we(fvOrigin) == we(zero)) condOrigin
else if (we(condOrigin) == we(one)) tvOrigin
else if (we(condOrigin) == we(zero)) fvOrigin
else if (we(tvOrigin) == we(fvOrigin)) tvOrigin
else if (we(fvOrigin) == we(zero) && we(condOrigin) == we(tvOrigin)) condOrigin
else e
- case DoPrim(op, args, consts, tpe) if op == PrimOps.Or && args.contains(one) => one
- case DoPrim(op, args, consts, tpe) if op == PrimOps.And && args.contains(zero) => zero
- case DoPrim(op, args, consts, tpe) if op == PrimOps.Bits =>
- val msb = consts(0)
- val lsb = consts(1)
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args.contains(one) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args.contains(zero) => zero
+ case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
val extractionWidth = (msb-lsb)+1
val nodeWidth = bitWidth(args.head.tpe)
// if you're extracting the full bitwidth, then keep searching for origin
- if (nodeWidth == extractionWidth) getOrigin(connects,args.head)
+ if (nodeWidth == extractionWidth) getOrigin(connects, args.head)
else e
- case DoPrim(op, args, _, _) if (op == PrimOps.AsUInt || op == PrimOps.AsSInt || op == PrimOps.AsClock) =>
- getOrigin(connects,args.head)
+ case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
+ getOrigin(connects, args.head)
case _: WRef | _: SubField | _: SubIndex | _: SubAccess if connects contains e.serialize =>
- getConnectOrigin(connects,e.serialize)
+ getConnectOrigin(connects, e.serialize)
case _ => e
}
- def appendInfo[T <: Info](info: T, add: Map[String,Any]) = info match {
+ def appendInfo[T <: Info](info: T, add: Map[String, Any]) = info match {
case i: AppendableInfo => i.append(add)
case _ => AppendableInfo(fields = add + ("info" -> info))
}
- def appendInfo[T <: Info](info: T, add: (String,Any)): AppendableInfo = appendInfo(info,Map(add))
- def getInfo[T <: Info](info: T, k: String) = info match{
+ def appendInfo[T <: Info](info: T, add: (String, Any)): AppendableInfo = appendInfo(info, Map(add))
+ def getInfo[T <: Info](info: T, k: String) = info match {
case i: AppendableInfo => i.get(k)
case _ => None
}
@@ -99,17 +89,8 @@ object AnalysisUtils {
case _ => false
}
- def eqMems(a: DefMemory, b: DefMemory) = {
- a.info == b.info &&
- a.dataType == b.dataType &&
- a.depth == b.depth &&
- a.writeLatency == b.writeLatency &&
- a.readLatency == b.readLatency &&
- a.readers == b.readers &&
- a.writers == b.writers &&
- a.readwriters == b.readwriters &&
- a.readUnderWrite == b.readUnderWrite
- }
+ // memories equivalent as long as all fields (except name) are the same
+ def eqMems(a: DefMemory, b: DefMemory) = a == b.copy(name = a.name)
}
@@ -124,9 +105,9 @@ object AnnotateMemMacros extends Pass {
// returns # of mask bits if used
def getMaskBits(wen: String, wmask: String): Option[Int] = {
- val wenOrigin = we(getConnectOrigin(connects,wen))
+ val wenOrigin = we(getConnectOrigin(connects, wen))
val one1 = we(one)
- val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects,x)))
+ val wmaskOrigin = connects.keys.toSeq.filter(_.startsWith(wmask)).map(x => we(getConnectOrigin(connects, x)))
// all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
val redundantMask = wmaskOrigin.map( x => (x == wenOrigin) || (x == one1) ).foldLeft(true)(_ && _)
if (redundantMask) None else Some(wmaskOrigin.length)
@@ -134,18 +115,18 @@ object AnnotateMemMacros extends Pass {
def updateStmts(s: Statement): Statement = s match {
// only annotate memories that are candidates for memory macro replacements
- // i.e. rw, w + r (read,write 1 cycle delay)
+ // i.e. rw, w + r (read, write 1 cycle delay)
case m: DefMemory if m.readLatency == 1 && m.writeLatency == 1 &&
(m.writers.length + m.readwriters.length) == 1 && m.readers.length <= 1 =>
val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode",s"${m.name}.$w.wmask"))
- val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en",s"${m.name}.$w.mask"))
+ val rwMasks = m.readwriters map (w => getMaskBits(s"${m.name}.$w.wmode", s"${m.name}.$w.wmask"))
+ val wMasks = m.writers map (w => getMaskBits(s"${m.name}.$w.en", s"${m.name}.$w.mask"))
val maskBits = (rwMasks ++ wMasks).head
val memAnnotations = Map("useMacro" -> true)
- val tempInfo = appendInfo(m.info,memAnnotations)
+ val tempInfo = appendInfo(m.info, memAnnotations)
if (maskBits == None) m.copy(info = tempInfo)
else m.copy(info = tempInfo.append("maskGran" -> dataBits/maskBits.get))
- case b: Block => Block(b.stmts map updateStmts)
+ case b: Block => b map updateStmts
case s => s
}
m.copy(body=updateStmts(m.body))
@@ -161,4 +142,4 @@ object AnnotateMemMacros extends Pass {
}
-// TODO: Add floorplan info? \ No newline at end of file
+// TODO: Add floorplan info?
diff --git a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
index a7f7703b..816a179b 100644
--- a/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
+++ b/src/main/scala/firrtl/passes/AnnotateValidMemConfigs.scala
@@ -1,3 +1,5 @@
+// See LICENSE for license details.
+
package firrtl.passes
import firrtl.ir._
@@ -5,7 +7,8 @@ import firrtl._
import net.jcazevedo.moultingyaml._
import net.jcazevedo.moultingyaml.DefaultYamlProtocol._
import AnalysisUtils._
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+import firrtl.Mappers._
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
@@ -17,13 +20,12 @@ object CustomYAMLProtocol extends DefaultYamlProtocol {
}
case class DimensionRules(
- min: Int,
- // step size
- inc: Int,
- max: Int,
- // these values should not be used, regardless of min,inc,max
- illegal: Option[List[Int]]
-){
+ min: Int,
+ // step size
+ inc: Int,
+ max: Int,
+ // these values should not be used, regardless of min,inc,max
+ illegal: Option[List[Int]]) {
def getValid = {
val range = (min to max by inc).toList
range.filterNot(illegal.getOrElse(List[Int]()).toSet)
@@ -31,9 +33,8 @@ case class DimensionRules(
}
case class MemDimension(
- rules: Option[DimensionRules],
- set: Option[List[Int]]
-){
+ rules: Option[DimensionRules],
+ set: Option[List[Int]]) {
require (
if(rules == None) set != None else set == None,
"Should specify either rules or a list of valid options, but not both"
@@ -42,44 +43,42 @@ case class MemDimension(
}
case class SRAMConfig(
- ymux: String = "",
- ybank: String = "",
- width: Int,
- depth: Int,
- xsplit: Int = 1,
- ysplit: Int = 1
-){
+ ymux: String = "",
+ ybank: String = "",
+ width: Int,
+ depth: Int,
+ xsplit: Int = 1,
+ ysplit: Int = 1) {
// how many duplicate copies of this SRAM are needed
def num = xsplit * ysplit
def serialize(pattern: String): String = {
- val fieldMap = getClass.getDeclaredFields.map{f =>
+ val fieldMap = getClass.getDeclaredFields.map { f =>
f.setAccessible(true)
f.getName -> f.get(this)
- }.toMap
+ } toMap
val fieldDelimiter = """\[.*?\]""".r
val configOptions = fieldDelimiter.findAllIn(pattern).toList
- configOptions.foldLeft(pattern)((b,a) => {
+ configOptions.foldLeft(pattern)((b, a) => {
// Expects the contents of [] are valid configuration fields (otherwise key match error)
val fieldVal = {
- try fieldMap(a.substring(1,a.length-1))
- catch { case e: Exception => Error("**SRAM config field incorrect**") }
+ try fieldMap(a.substring(1, a.length-1))
+ catch { case e: Exception => error("**SRAM config field incorrect**") }
}
- b.replace(a,fieldVal.toString)
- })
+ b.replace(a, fieldVal.toString)
+ } )
}
}
// Ex: https://www.ece.cmu.edu/~ece548/hw/hw5/meml80.pdf
case class SRAMRules(
- // column mux parameter (for adjusting aspect ratio)
- ymux: (Int,String),
- // vertical segmentation (banking -- tradeoff performance / area)
- ybank: (Int,String),
- width: MemDimension,
- depth: MemDimension
-){
+ // column mux parameter (for adjusting aspect ratio)
+ ymux: (Int, String),
+ // vertical segmentation (banking -- tradeoff performance / area)
+ ybank: (Int, String),
+ width: MemDimension,
+ depth: MemDimension) {
def getValidWidths = width.getValid
def getValidDepths = depth.getValid
def getValidConfig(width: Int, depth: Int): Option[SRAMConfig] = {
@@ -88,13 +87,12 @@ case class SRAMRules(
else
None
}
- def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue,m.depth)
+ def getValidConfig(m: DefMemory): Option[SRAMConfig] = getValidConfig(bitWidth(m.dataType).intValue, m.depth)
}
case class WMaskArg(
- t: String,
- f: String
-)
+ t: String,
+ f: String)
// vendor-specific compilers
case class SRAMCompiler(
@@ -116,8 +114,7 @@ case class SRAMCompiler(
defaultArgs: Option[String],
// default behavior (if not used) is to have wmask port width = datawidth/maskgran
// if true: wmask port width pre-filled to datawidth
- fillWMask: Boolean
-){
+ fillWMask: Boolean) {
require(portType == "RW" || portType == "R,W", "Memory must be single port RW or dual port R,W")
require(
(configFile != None && configPattern != None && wMaskArg != None) || configFile == None,
@@ -135,7 +132,7 @@ case class SRAMCompiler(
private val noMaskConfigOutputBuffer = new java.io.CharArrayWriter
def append(m: DefMemory) : DefMemory = {
- val validCombos = ArrayBuffer[SRAMConfig]()
+ val validCombos = mutable.ArrayBuffer[SRAMConfig]()
defaultSearchOrdering foreach { r =>
val config = r.getValidConfig(m)
if (config != None) validCombos += config.get
@@ -146,7 +143,7 @@ case class SRAMCompiler(
if (validCombos.nonEmpty) validCombos.head
else getBestAlternative(m)
}
- val usesMaskGran = containsInfo(m.info,"maskGran")
+ val usesMaskGran = containsInfo(m.info, "maskGran")
if (configPattern != None) {
val newConfig = usedConfig.serialize(configPattern.get) + "\n"
val currentBuff = {
@@ -156,8 +153,8 @@ case class SRAMCompiler(
if (!currentBuff.toString.contains(newConfig))
currentBuff.append(newConfig)
}
- val temp = appendInfo(m.info,"sramConfig" -> usedConfig)
- val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp,"maskGran" -> 1) else temp
+ val temp = appendInfo(m.info, "sramConfig" -> usedConfig)
+ val newInfo = if(usesMaskGran && fillWMask) appendInfo(temp, "maskGran" -> 1) else temp
m.copy(info = newInfo)
}
@@ -165,24 +162,24 @@ case class SRAMCompiler(
// handled w/ a separate set of registers ?
// split memory until width, depth achievable via given memory compiler
private def getInRange(m: SRAMConfig): Seq[SRAMConfig] = {
- val validXRange = ArrayBuffer[SRAMRules]()
- val validYRange = ArrayBuffer[SRAMRules]()
+ val validXRange = mutable.ArrayBuffer[SRAMRules]()
+ val validYRange = mutable.ArrayBuffer[SRAMRules]()
defaultSearchOrdering foreach { r =>
if (m.width <= r.getValidWidths.max) validXRange += r
if (m.depth <= r.getValidDepths.max) validYRange += r
}
if (validXRange.isEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2,depth = m.depth/2))
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = 2*m.ysplit, width = m.width/2, depth = m.depth/2))
else if (validXRange.isEmpty && validYRange.nonEmpty)
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2,depth = m.depth))
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
else if (validXRange.nonEmpty && validYRange.isEmpty)
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width,depth = m.depth/2))
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2))
else if (validXRange.intersect(validYRange).nonEmpty)
Seq(m)
else
- getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width,depth = m.depth/2)) ++
- getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2,depth = m.depth))
+ getInRange(SRAMConfig(xsplit = m.xsplit, ysplit = 2*m.ysplit, width = m.width, depth = m.depth/2)) ++
+ getInRange(SRAMConfig(xsplit = 2*m.xsplit, ysplit = m.ysplit, width = m.width/2, depth = m.depth))
}
private def getBestAlternative(m: DefMemory): SRAMConfig = {
@@ -190,39 +187,39 @@ case class SRAMCompiler(
val minNum = validConfigs.map(x => x.num).min
val validMinConfigs = validConfigs.filter(_.num == minNum)
val validMinConfigsSquareness = validMinConfigs.map(x => math.abs(x.width.toDouble/x.depth - 1) -> x).toMap
- val squarestAspectRatio = validMinConfigsSquareness.map(x => x._1).min
+ val squarestAspectRatio = validMinConfigsSquareness.map { case (aspectRatioDiff, _) => aspectRatioDiff } min
val validConfig = validMinConfigsSquareness(squarestAspectRatio)
- val validRules = ArrayBuffer[SRAMRules]()
+ val validRules = mutable.ArrayBuffer[SRAMRules]()
defaultSearchOrdering foreach { r =>
if (validConfig.width <= r.getValidWidths.max && validConfig.depth <= r.getValidDepths.max) validRules += r
}
// TODO: don't just take first option
+ // TODO: More optimal split if particular value is in range but not supported
+ // TODO: Support up to 2 read ports, 2 write ports; should be power of 2?
val bestRule = validRules.head
val memWidth = bestRule.getValidWidths.find(validConfig.width <= _).get
val memDepth = bestRule.getValidDepths.find(validConfig.depth <= _).get
bestRule.getValidConfig(width = memWidth, depth = memDepth).get.copy(xsplit = validConfig.xsplit, ysplit = validConfig.ysplit)
}
- def serialize() = {
- // TODO
- }
+ // TODO
+ def serialize = ???
}
// TODO: assumption that you would stick to just SRAMs or just RFs in a design -- is that true?
// Or is this where module-level transforms (rather than circuit-level) make sense?
-class YamlFileReader(file: String){
+class YamlFileReader(file: String) {
import CustomYAMLProtocol._
def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
if (new java.io.File(file).exists) {
val yamlString = scala.io.Source.fromFile(file).getLines.mkString("\n")
- val optionOut = yamlString.parseYamls.map(x =>
+ yamlString.parseYamls.flatMap(x =>
try Some(reader.read(x))
- catch {case e: Exception => None}
+ catch { case e: Exception => None }
)
- optionOut.filter(_ != None).map(_.get)
}
- else Error("Yaml file doesn't exist!")
+ else error("Yaml file doesn't exist!")
}
}
@@ -233,7 +230,7 @@ class YamlFileWriter(file: String) {
def append(in: YamlValue) = {
outputBuffer.append(separator + in.prettyPrint)
}
- def serialize = {
+ def dump = {
val outputFile = new java.io.PrintWriter(file)
outputFile.write(outputBuffer.toString)
outputFile.close()
@@ -249,12 +246,11 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
// TODO: Consider splitting InferRW to analysis + actual optimization pass, in case sp doesn't exist
// TODO: Don't get first available?
case class SRAMCompilerSet(
- sp: Option[SRAMCompiler] = None,
- dp: Option[SRAMCompiler] = None
- ){
- def serialize() = {
- if (sp != None) sp.get.serialize()
- if (dp != None) dp.get.serialize()
+ sp: Option[SRAMCompiler] = None,
+ dp: Option[SRAMCompiler] = None) {
+ def serialize = {
+ if (sp != None) sp.get.serialize
+ if (dp != None) dp.get.serialize
}
}
val sramCompilers = {
@@ -270,18 +266,18 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
def run(c: Circuit) = {
def annotateModMems(m: Module) = {
def updateStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info,"useMacro") => {
+ case m: DefMemory if containsInfo(m.info, "useMacro") => {
if (sramCompilers == None) m
else {
if (m.readwriters.length == 1)
- if (sramCompilers.get.sp == None) Error("Design needs RW port memory compiler!")
+ if (sramCompilers.get.sp == None) error("Design needs RW port memory compiler!")
else sramCompilers.get.sp.get.append(m)
else
- if (sramCompilers.get.dp == None) Error("Design needs R,W port memory compiler!")
+ if (sramCompilers.get.dp == None) error("Design needs R,W port memory compiler!")
else sramCompilers.get.dp.get.append(m)
}
}
- case b: Block => Block(b.stmts map updateStmts)
+ case b: Block => b map updateStmts
case s => s
}
m.copy(body=updateStmts(m.body))
@@ -292,4 +288,4 @@ class AnnotateValidMemConfigs(reader: Option[YamlFileReader]) extends Pass {
}
c.copy(modules = updatedMods)
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/passes/MemUtils.scala b/src/main/scala/firrtl/passes/MemUtils.scala
index 97915194..b235213a 100644
--- a/src/main/scala/firrtl/passes/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/MemUtils.scala
@@ -41,22 +41,22 @@ object seqCat {
case 2 => DoPrim(PrimOps.Cat, args, Seq.empty[BigInt], UIntType(UnknownWidth))
case _ => {
val seqs = args.splitAt(args.length/2)
- DoPrim(PrimOps.Cat, Seq(seqCat(seqs._1),seqCat(seqs._2)), Seq.empty[BigInt], UIntType(UnknownWidth))
+ DoPrim(PrimOps.Cat, Seq(seqCat(seqs._1), seqCat(seqs._2)), Seq.empty[BigInt], UIntType(UnknownWidth))
}
}
}
object toBits {
def apply(e: Expression): Expression = e match {
- case ex: WRef => hiercat(ex,ex.tpe)
- case ex: WSubField => hiercat(ex,ex.tpe)
- case ex: WSubIndex => hiercat(ex,ex.tpe)
+ case ex: WRef => hiercat(ex, ex.tpe)
+ case ex: WSubField => hiercat(ex, ex.tpe)
+ case ex: WSubIndex => hiercat(ex, ex.tpe)
case t => error("Invalid operand expression for toBits!")
}
def hiercat(e: Expression, dt: Type): Expression = dt match {
- case t:VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER),t.tpe)))
- case t:BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe)))
- case t:GroundType => e
+ case t: VectorType => seqCat((0 until t.size).reverse.map(i => hiercat(WSubIndex(e, i, t.tpe, UNKNOWNGENDER), t.tpe)))
+ case t: BundleType => seqCat(t.fields.map(f => hiercat(WSubField(e, f.name, f.tpe, UNKNOWNGENDER), f.tpe)))
+ case t: GroundType => e
case t => error("Unknown type encountered in toBits!")
}
}
@@ -64,16 +64,16 @@ object toBits {
// TODO: make easier to understand
object toBitMask {
def apply(e: Expression, dataType: Type): Expression = e match {
- case ex: WRef => hiermask(ex,ex.tpe,dataType)
- case ex: WSubField => hiermask(ex,ex.tpe,dataType)
- case ex: WSubIndex => hiermask(ex,ex.tpe,dataType)
+ case ex: WRef => hiermask(ex, ex.tpe, dataType)
+ case ex: WSubField => hiermask(ex, ex.tpe, dataType)
+ case ex: WSubIndex => hiermask(ex, ex.tpe, dataType)
case t => error("Invalid operand expression for toBits!")
}
def hiermask(e: Expression, maskType: Type, dataType: Type): Expression = (maskType, dataType) match {
- case (mt:VectorType, dt:VectorType) => seqCat((0 until mt.size).reverse.map(i => hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe)))
- case (mt:BundleType, dt:BundleType) => seqCat((mt.fields zip dt.fields).map{ case (mf,df) =>
- hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) })
- case (mt:UIntType, dt:GroundType) => seqCat(List.fill(bitWidth(dt).intValue)(e))
+ case (mt: VectorType, dt: VectorType) => seqCat((0 until mt.size).reverse.map(i => hiermask(WSubIndex(e, i, mt.tpe, UNKNOWNGENDER), mt.tpe, dt.tpe)))
+ case (mt: BundleType, dt: BundleType) => seqCat((mt.fields zip dt.fields).map { case (mf, df) =>
+ hiermask(WSubField(e, mf.name, mf.tpe, UNKNOWNGENDER), mf.tpe, df.tpe) } )
+ case (mt: UIntType, dt: GroundType) => seqCat(List.fill(bitWidth(dt).intValue)(e))
case (mt, dt) => error("Invalid type for mask component!")
}
}
@@ -81,8 +81,8 @@ object toBitMask {
object bitWidth {
def apply(dt: Type): BigInt = widthOf(dt)
def widthOf(dt: Type): BigInt = dt match {
- case t:VectorType => t.size * bitWidth(t.tpe)
- case t:BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_)
+ case t: VectorType => t.size * bitWidth(t.tpe)
+ case t: BundleType => t.fields.map(f => bitWidth(f.tpe)).foldLeft(BigInt(0))(_+_)
case UIntType(IntWidth(width)) => width
case SIntType(IntWidth(width)) => width
case t => error("Unknown type encountered in bitWidth!")
@@ -101,12 +101,12 @@ object fromBits {
}
def getPartGround(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = {
val intWidth = bitWidth(lhst)
- val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset+intWidth-1,offset), UnknownType)
- (offset + intWidth, Seq(Connect(NoInfo,lhs,sel)))
+ val sel = DoPrim(PrimOps.Bits, Seq(rhs), Seq(offset+intWidth-1, offset), UnknownType)
+ (offset + intWidth, Seq(Connect(NoInfo, lhs, sel)))
}
def getPart(lhs: Expression, lhst: Type, rhs: Expression, offset: BigInt): (BigInt, Seq[Statement]) = {
lhst match {
- case t:VectorType => {
+ case t: VectorType => {
var currentOffset = offset
var stmts = Seq.empty[Statement]
for (i <- (0 until t.size)) {
@@ -116,7 +116,7 @@ object fromBits {
}
(currentOffset, stmts)
}
- case t:BundleType => {
+ case t: BundleType => {
var currentOffset = offset
var stmts = Seq.empty[Statement]
for (f <- t.fields.reverse) {
@@ -126,7 +126,7 @@ object fromBits {
}
(currentOffset, stmts)
}
- case t:GroundType => getPartGround(lhs, t, rhs, offset)
+ case t: GroundType => getPartGround(lhs, t, rhs, offset)
case t => error("Unknown type encountered in fromBits!")
}
}
@@ -145,7 +145,7 @@ object MemPortUtils {
)
def getFillWMask(mem: DefMemory) = {
- val maskGran = getInfo(mem.info,"maskGran")
+ val maskGran = getInfo(mem.info, "maskGran")
if (maskGran == None) false
else maskGran.get == 1
}
@@ -156,7 +156,7 @@ object MemPortUtils {
def wPortToBundle(mem: DefMemory) = {
val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, mem.dataType)
BundleType(
- if (containsInfo(mem.info,"maskGran")) defaultSeq :+ Field("mask", Default, create_mask(mem.dataType))
+ if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("mask", Default, create_mask(mem.dataType))
else defaultSeq
)
}
@@ -164,7 +164,7 @@ object MemPortUtils {
def wPortToFlattenBundle(mem: DefMemory) = {
val defaultSeq = defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))
BundleType(
- if (containsInfo(mem.info,"maskGran")) {
+ if (containsInfo(mem.info, "maskGran")) {
defaultSeq :+ {
if (getFillWMask(mem)) Field("mask", Default, flattenType(mem.dataType))
else Field("mask", Default, flattenType(create_mask(mem.dataType)))
@@ -175,26 +175,26 @@ object MemPortUtils {
}
// TODO: Don't use create_mask???
- def rwPortToBundle(mem: DefMemory) ={
+ def rwPortToBundle(mem: DefMemory) = {
val defaultSeq = defaultPortSeq(mem) ++ Seq(
Field("wmode", Default, UIntType(IntWidth(1))),
Field("wdata", Default, mem.dataType),
Field("rdata", Flip, mem.dataType)
)
BundleType(
- if (containsInfo(mem.info,"maskGran")) defaultSeq :+ Field("wmask", Default, create_mask(mem.dataType))
+ if (containsInfo(mem.info, "maskGran")) defaultSeq :+ Field("wmask", Default, create_mask(mem.dataType))
else defaultSeq
)
}
- def rwPortToFlattenBundle(mem: DefMemory) ={
+ def rwPortToFlattenBundle(mem: DefMemory) = {
val defaultSeq = defaultPortSeq(mem) ++ Seq(
Field("wmode", Default, UIntType(IntWidth(1))),
Field("wdata", Default, flattenType(mem.dataType)),
Field("rdata", Flip, flattenType(mem.dataType))
)
BundleType(
- if (containsInfo(mem.info,"maskGran")) {
+ if (containsInfo(mem.info, "maskGran")) {
defaultSeq :+ {
if (getFillWMask(mem)) Field("wmask", Default, flattenType(mem.dataType))
else Field("wmask", Default, flattenType(create_mask(mem.dataType)))
diff --git a/src/main/scala/firrtl/passes/ReplSeqMem.scala b/src/main/scala/firrtl/passes/ReplSeqMem.scala
index a1b83efc..ae842a0b 100644
--- a/src/main/scala/firrtl/passes/ReplSeqMem.scala
+++ b/src/main/scala/firrtl/passes/ReplSeqMem.scala
@@ -1,3 +1,5 @@
+// See LICENSE for license details.
+
package firrtl.passes
import com.typesafe.scalalogging.LazyLogging
@@ -12,16 +14,6 @@ case object InputConfigFileName extends PassOption
case object OutputConfigFileName extends PassOption
case object PassCircuitName extends PassOption
-object Error {
- def apply(msg: String) = throw new Exception(msg)
-}
-object Warn {
- def apply[T <: Any](msg: String, r: T) = {
- println(Console.RED + msg + Console.RESET)
- r
- }
-}
-
object PassConfigUtil {
def getPassOptions(t: String, usage: String = "") = {
@@ -41,7 +33,7 @@ object PassConfigUtil {
case "-c" :: value :: tail =>
nextPassOption(map + (PassCircuitName -> value), tail)
case option :: tail =>
- Error("Unknown option " + option + usage)
+ error("Unknown option " + option + usage)
}
}
nextPassOption(Map[PassOption, String](), passArgList)
@@ -53,7 +45,7 @@ class ConfWriter(filename: String) {
val outputBuffer = new java.io.CharArrayWriter
def append(m: DefMemory) = {
// legacy
- val maskGran = getInfo(m.info,"maskGran")
+ val maskGran = getInfo(m.info, "maskGran")
val writers = m.writers map (x => if (maskGran == None) "write" else "mwrite")
val readers = List.fill(m.readers.length)("read")
val readwriters = m.readwriters map (x => if (maskGran == None) "rw" else "mrw")
@@ -89,17 +81,17 @@ Optional Arguments:
-i<filename> Specify the input configuration file (for additional optimizations)
"""
- val passOptions = PassConfigUtil.getPassOptions(t,usage)
+ val passOptions = PassConfigUtil.getPassOptions(t, usage)
val outputConfig = passOptions.getOrElse(
OutputConfigFileName,
- Error("No output config file provided for ReplSeqMem!" + usage)
+ error("No output config file provided for ReplSeqMem!" + usage)
)
val passCircuit = passOptions.getOrElse(
PassCircuitName,
- Error("No circuit name specified for ReplSeqMem!" + usage)
+ error("No circuit name specified for ReplSeqMem!" + usage)
)
val target = CircuitName(passCircuit)
- def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit,"-c:"+n.name))
+ def duplicate(n: Named) = this.copy(t=t.replace("-c:"+passCircuit, "-c:"+n.name))
}
@@ -109,11 +101,11 @@ class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
case Some(p) => p get CircuitName(circuit.main) match {
case Some(ReplSeqMemAnnotation(t, _)) => {
- val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName,"")
+ val inputFileName = PassConfigUtil.getPassOptions(t).getOrElse(InputConfigFileName, "")
val inConfigFile = {
if (inputFileName.isEmpty) None
else if (new java.io.File(inputFileName).exists) Some(new YamlFileReader(inputFileName))
- else Error("Input configuration file does not exist!")
+ else error("Input configuration file does not exist!")
}
val outConfigFile = new ConfWriter(PassConfigUtil.getPassOptions(t).get(OutputConfigFileName).get)
@@ -131,17 +123,17 @@ class ReplSeqMem(transID: TransID) extends Transform with LazyLogging {
InferTypes,
ResolveGenders
) foldLeft circuit
- ){
+ ) {
(c, pass) =>
val x = Utils.time(pass.name)(pass run c)
logger debug x.serialize
x
- },
+ } ,
None,
Some(map)
)
}
- case _ => TransformResult(circuit, None, Some(map))
+ case _ => error("Unexpected transform annotation")
}
case _ => TransformResult(circuit, None, Some(map))
}
diff --git a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
index 94be10e7..54c522d7 100644
--- a/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/ReplaceMemMacros.scala
@@ -1,12 +1,15 @@
+// See LICENSE for license details.
+
package firrtl.passes
-import scala.collection.mutable.{HashMap,ArrayBuffer}
+import scala.collection.mutable
import firrtl.ir._
import AnalysisUtils._
import MemTransformUtils._
import firrtl._
import firrtl.Utils._
import MemPortUtils._
+import firrtl.Mappers._
class ReplaceMemMacros(writer: ConfWriter) extends Pass {
@@ -15,41 +18,41 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
def run(c: Circuit) = {
lazy val moduleNamespace = Namespace(c)
- val memMods = ArrayBuffer[DefModule]()
- val uniqueMems = ArrayBuffer[DefMemory]()
+ val memMods = mutable.ArrayBuffer[DefModule]()
+ val uniqueMems = mutable.ArrayBuffer[DefMemory]()
def updateMemMods(m: Module) = {
- val memPortMap = HashMap[String,Expression]()
+ val memPortMap = mutable.HashMap[String, Expression]()
def updateMemStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info,"useMacro") =>
- if(!containsInfo(m.info,"maskGran")) {
- m.writers foreach {w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression}
- m.readwriters foreach {w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression}
+ case m: DefMemory if containsInfo(m.info, "useMacro") =>
+ if(!containsInfo(m.info, "maskGran")) {
+ m.writers foreach { w => memPortMap(s"${m.name}.${w}.mask") = EmptyExpression }
+ m.readwriters foreach { w => memPortMap(s"${m.name}.${w}.wmask") = EmptyExpression }
}
- val infoT = getInfo(m.info,"info")
- val info = if (infoT == None) NoInfo else infoT.get match {case i: Info => i}
- val ref = getInfo(m.info,"ref")
+ val infoT = getInfo(m.info, "info")
+ val info = if (infoT == None) NoInfo else infoT.get match { case i: Info => i }
+ val ref = getInfo(m.info, "ref")
// prototype mem
if (ref == None) {
val newWrapperName = moduleNamespace.newName(m.name)
val newMemBBName = moduleNamespace.newName(m.name + "_ext")
val newMem = m.copy(name = newMemBBName)
- memMods ++= createMemModule(newMem,newWrapperName)
+ memMods ++= createMemModule(newMem, newWrapperName)
uniqueMems += newMem
WDefInstance(info, m.name, newWrapperName, UnknownType)
}
else {
- val r = ref.get match {case s: String => s}
+ val r = ref.get match { case s: String => s }
WDefInstance(info, m.name, r, UnknownType)
}
- case b: Block => Block(b.stmts map updateMemStmts)
+ case b: Block => b map updateMemStmts
case s => s
}
val updatedMems = updateMemStmts(m.body)
- val updatedConns = updateStmtRefs(updatedMems,memPortMap.toMap)
+ val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap)
m.copy(body = updatedConns)
}
@@ -66,64 +69,64 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
// from Albert
def createMemModule(m: DefMemory, wrapperName: String): Seq[DefModule] = {
assert(m.dataType != UnknownType)
- val stmts = ArrayBuffer[Statement]()
+ val stmts = mutable.ArrayBuffer[Statement]()
val wrapperioPorts = MemPortUtils.memToBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
val bbProto = m.copy(dataType = flattenType(m.dataType))
val bbioPorts = MemPortUtils.memToFlattenBundle(m).fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
- stmts += WDefInstance(NoInfo,m.name,m.name,UnknownType)
+ stmts += WDefInstance(NoInfo, m.name, m.name, UnknownType)
val bbRef = createRef(m.name)
- stmts ++= (m.readers zip bbProto.readers).map{
- case (x,y) => adaptReader(createRef(x),m,createSubField(bbRef,y),bbProto)
- }.flatten
- stmts ++= (m.writers zip bbProto.writers).map{
- case (x,y) => adaptWriter(createRef(x),m,createSubField(bbRef,y),bbProto)
- }.flatten
- stmts ++= (m.readwriters zip bbProto.readwriters).map{
- case (x,y) => adaptReadWriter(createRef(x),m,createSubField(bbRef,y),bbProto)
- }.flatten
- val wrapper = Module(NoInfo,wrapperName,wrapperioPorts,Block(stmts))
- val bb = ExtModule(NoInfo,m.name,bbioPorts)
+ stmts ++= (m.readers zip bbProto.readers).flatMap {
+ case (x, y) => adaptReader(createRef(x), m, createSubField(bbRef, y), bbProto)
+ }
+ stmts ++= (m.writers zip bbProto.writers).flatMap {
+ case (x, y) => adaptWriter(createRef(x), m, createSubField(bbRef, y), bbProto)
+ }
+ stmts ++= (m.readwriters zip bbProto.readwriters).flatMap {
+ case (x, y) => adaptReadWriter(createRef(x), m, createSubField(bbRef, y), bbProto)
+ }
+ val wrapper = Module(NoInfo, wrapperName, wrapperioPorts, Block(stmts))
+ val bb = ExtModule(NoInfo, m.name, bbioPorts)
// TODO: Annotate? -- use actual annotation map
// add to conf file
writer.append(m)
- Seq(bb,wrapper)
+ Seq(bb, wrapper)
}
// TODO: get rid of copy pasta
def adaptReader(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = Seq(
- connectFields(bbPort,"addr",wrapperPort,"addr"),
- connectFields(bbPort,"en",wrapperPort,"en"),
- connectFields(bbPort,"clk",wrapperPort,"clk"),
+ connectFields(bbPort, "addr", wrapperPort, "addr"),
+ connectFields(bbPort, "en", wrapperPort, "en"),
+ connectFields(bbPort, "clk", wrapperPort, "clk"),
fromBits(
- WSubField(wrapperPort,"data",wrapperMem.dataType,UNKNOWNGENDER),
- WSubField(bbPort,"data",bbMem.dataType,UNKNOWNGENDER)
+ WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER),
+ WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER)
)
)
def adaptWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = {
val defaultSeq = Seq(
- connectFields(bbPort,"addr",wrapperPort,"addr"),
- connectFields(bbPort,"en",wrapperPort,"en"),
- connectFields(bbPort,"clk",wrapperPort,"clk"),
+ connectFields(bbPort, "addr", wrapperPort, "addr"),
+ connectFields(bbPort, "en", wrapperPort, "en"),
+ connectFields(bbPort, "clk", wrapperPort, "clk"),
Connect(
NoInfo,
- WSubField(bbPort,"data",bbMem.dataType,UNKNOWNGENDER),
- toBits(WSubField(wrapperPort,"data",wrapperMem.dataType,UNKNOWNGENDER))
+ WSubField(bbPort, "data", bbMem.dataType, UNKNOWNGENDER),
+ toBits(WSubField(wrapperPort, "data", wrapperMem.dataType, UNKNOWNGENDER))
)
)
- if (containsInfo(wrapperMem.info,"maskGran")) {
+ if (containsInfo(wrapperMem.info, "maskGran")) {
val wrapperMask = create_mask(wrapperMem.dataType)
val fillWMask = getFillWMask(wrapperMem)
val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
val rhs = {
- if (fillWMask) toBitMask(WSubField(wrapperPort,"mask",wrapperMask,UNKNOWNGENDER),wrapperMem.dataType)
- else toBits(WSubField(wrapperPort,"mask",wrapperMask,UNKNOWNGENDER))
+ if (fillWMask) toBitMask(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType)
+ else toBits(WSubField(wrapperPort, "mask", wrapperMask, UNKNOWNGENDER))
}
defaultSeq :+ Connect(
NoInfo,
- WSubField(bbPort,"mask",bbMask,UNKNOWNGENDER),
+ WSubField(bbPort, "mask", bbMask, UNKNOWNGENDER),
rhs
)
}
@@ -132,31 +135,31 @@ class ReplaceMemMacros(writer: ConfWriter) extends Pass {
def adaptReadWriter(wrapperPort: Expression, wrapperMem: DefMemory, bbPort: Expression, bbMem: DefMemory) = {
val defaultSeq = Seq(
- connectFields(bbPort,"addr",wrapperPort,"addr"),
- connectFields(bbPort,"en",wrapperPort,"en"),
- connectFields(bbPort,"clk",wrapperPort,"clk"),
- connectFields(bbPort,"wmode",wrapperPort,"wmode"),
+ connectFields(bbPort, "addr", wrapperPort, "addr"),
+ connectFields(bbPort, "en", wrapperPort, "en"),
+ connectFields(bbPort, "clk", wrapperPort, "clk"),
+ connectFields(bbPort, "wmode", wrapperPort, "wmode"),
Connect(
NoInfo,
- WSubField(bbPort,"wdata",bbMem.dataType,UNKNOWNGENDER),
- toBits(WSubField(wrapperPort,"wdata",wrapperMem.dataType,UNKNOWNGENDER))
+ WSubField(bbPort, "wdata", bbMem.dataType, UNKNOWNGENDER),
+ toBits(WSubField(wrapperPort, "wdata", wrapperMem.dataType, UNKNOWNGENDER))
),
fromBits(
- WSubField(wrapperPort,"rdata",wrapperMem.dataType,UNKNOWNGENDER),
- WSubField(bbPort,"rdata",bbMem.dataType,UNKNOWNGENDER)
+ WSubField(wrapperPort, "rdata", wrapperMem.dataType, UNKNOWNGENDER),
+ WSubField(bbPort, "rdata", bbMem.dataType, UNKNOWNGENDER)
)
)
- if (containsInfo(wrapperMem.info,"maskGran")) {
+ if (containsInfo(wrapperMem.info, "maskGran")) {
val wrapperMask = create_mask(wrapperMem.dataType)
val fillWMask = getFillWMask(wrapperMem)
val bbMask = if (fillWMask) flattenType(wrapperMem.dataType) else flattenType(wrapperMask)
val rhs = {
- if (fillWMask) toBitMask(WSubField(wrapperPort,"wmask",wrapperMask,UNKNOWNGENDER),wrapperMem.dataType)
- else toBits(WSubField(wrapperPort,"wmask",wrapperMask,UNKNOWNGENDER))
+ if (fillWMask) toBitMask(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER), wrapperMem.dataType)
+ else toBits(WSubField(wrapperPort, "wmask", wrapperMask, UNKNOWNGENDER))
}
defaultSeq :+ Connect(
NoInfo,
- WSubField(bbPort,"wmask",bbMask,UNKNOWNGENDER),
+ WSubField(bbPort, "wmask", bbMask, UNKNOWNGENDER),
rhs
)
}
diff --git a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
index a4a910fd..d71b8ab8 100644
--- a/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/UpdateDuplicateMemMacros.scala
@@ -1,6 +1,8 @@
+// See LICENSE for license details.
+
package firrtl.passes
-import scala.collection.mutable.{HashMap,ArrayBuffer}
+import scala.collection.mutable
import AnalysisUtils._
import MemTransformUtils._
import firrtl.ir._
@@ -10,27 +12,27 @@ import firrtl.Utils._
object MemTransformUtils {
- def createRef(n: String) = WRef(n,UnknownType,ExpKind(),UNKNOWNGENDER)
- def createSubField(exp: Expression, n: String) = WSubField(exp,n,UnknownType,UNKNOWNGENDER)
+ def createRef(n: String) = WRef(n, UnknownType, ExpKind(), UNKNOWNGENDER)
+ def createSubField(exp: Expression, n: String) = WSubField(exp, n, UnknownType, UNKNOWNGENDER)
def connectFields(lref: Expression, lname: String, rref: Expression, rname: String) =
- Connect(NoInfo,createSubField(lref,lname),createSubField(rref,rname))
+ Connect(NoInfo, createSubField(lref, lname), createSubField(rref, rname))
def getMemPortMap(m: DefMemory) = {
- val memPortMap = HashMap[String,Expression]()
- val defaultFields = Seq("addr","en","clk")
+ val memPortMap = mutable.HashMap[String, Expression]()
+ val defaultFields = Seq("addr", "en", "clk")
val rFields = defaultFields :+ "data"
val wFields = rFields :+ "mask"
- val rwFields = defaultFields ++ Seq("wmode","wdata","rdata","wmask")
+ val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")
def updateMemPortMap(ports: Seq[String], fields: Seq[String], portType: String) =
- for (p <- ports.zipWithIndex; f <- fields) {
- val newPort = createSubField(createRef(m.name),portType+p._2)
- val field = createSubField(newPort,f)
- memPortMap(s"${m.name}.${p._1}.${f}") = field
+ for ((p, i) <- ports.zipWithIndex; f <- fields) {
+ val newPort = createSubField(createRef(m.name), portType+i)
+ val field = createSubField(newPort, f)
+ memPortMap(s"${m.name}.${p}.${f}") = field
}
- updateMemPortMap(m.readers,rFields,"R")
- updateMemPortMap(m.writers,wFields,"W")
- updateMemPortMap(m.readwriters,rwFields,"RW")
+ updateMemPortMap(m.readers, rFields, "R")
+ updateMemPortMap(m.writers, wFields, "W")
+ updateMemPortMap(m.readwriters, rwFields, "RW")
memPortMap.toMap
}
def createMemProto(m: DefMemory) = {
@@ -40,14 +42,14 @@ object MemTransformUtils {
m.copy(readers = rports, writers = wports, readwriters = rwports)
}
- def updateStmtRefs(s: Statement, repl: Map[String,Expression]): Statement = {
+ def updateStmtRefs(s: Statement, repl: Map[String, Expression]): Statement = {
def updateRef(e: Expression): Expression = e map updateRef match {
- case e: WSubField => repl getOrElse (e.serialize,e)
+ case e: WSubField => repl getOrElse (e.serialize, e)
case e => e
}
def updateStmtRefs(s: Statement): Statement = s map updateStmtRefs map updateRef match {
- case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt
- case Connect(info, WSubIndex(EmptyExpression,_,_,_), exp) => EmptyStmt
+ case Connect(info, EmptyExpression, exp) => EmptyStmt
+ case Connect(info, WSubIndex(EmptyExpression, _, _, _), exp) => EmptyStmt
case s => s
}
updateStmtRefs(s)
@@ -60,27 +62,27 @@ object UpdateDuplicateMemMacros extends Pass {
def name = "Convert memory port names to be more meaningful and tag duplicate memories"
def run(c: Circuit) = {
- val uniqueMems = ArrayBuffer[DefMemory]()
+ val uniqueMems = mutable.ArrayBuffer[DefMemory]()
def updateMemMods(m: Module) = {
- val memPortMap = HashMap[String,Expression]()
+ val memPortMap = mutable.HashMap[String, Expression]()
def updateMemStmts(s: Statement): Statement = s match {
- case m: DefMemory if containsInfo(m.info,"useMacro") =>
+ case m: DefMemory if containsInfo(m.info, "useMacro") =>
val updatedMem = createMemProto(m)
memPortMap ++= getMemPortMap(m)
- val proto = uniqueMems find (x => eqMems(x,updatedMem))
+ val proto = uniqueMems find (x => eqMems(x, updatedMem))
if (proto == None) {
uniqueMems += updatedMem
updatedMem
}
- else updatedMem.copy(info = appendInfo(updatedMem.info,"ref" -> proto.get.name))
- case b: Block => Block(b.stmts map updateMemStmts)
+ else updatedMem.copy(info = appendInfo(updatedMem.info, "ref" -> proto.get.name))
+ case b: Block => b map updateMemStmts
case s => s
}
val updatedMems = updateMemStmts(m.body)
- val updatedConns = updateStmtRefs(updatedMems,memPortMap.toMap)
+ val updatedConns = updateStmtRefs(updatedMems, memPortMap.toMap)
m.copy(body = updatedConns)
}
@@ -92,4 +94,4 @@ object UpdateDuplicateMemMacros extends Pass {
}
}
-// TODO: Module namespace? \ No newline at end of file
+// TODO: Module namespace?
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index 57a274c2..54ef6003 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -17,6 +17,51 @@ class ReplSeqMemSpec extends SimpleTransformSpec {
new EmitFirrtl(writer)
)
+ "ReplSeqMem Utility -- getConnectOrigin" should
+ "determine connect origin across nodes/PrimOps even if ConstProp isn't performed" in {
+ def checkConnectOrigin(hurdle: String, origin: String) = {
+ val input = s"""
+circuit Top :
+ module Top :
+ input a: UInt<1>
+ input b: UInt<1>
+ input e: UInt<1>
+ output c: UInt<1>
+ output f: UInt<1>
+ node d = $hurdle
+ c <= d
+ f <= c
+""".stripMargin
+
+ val circuit = InferTypes.run(ToWorkingIR.run(parse(input)))
+ val m = circuit.modules.head.asInstanceOf[ir.Module]
+ val connects = AnalysisUtils.getConnects(m)
+ val calculatedOrigin = AnalysisUtils.getConnectOrigin(connects,"f").serialize
+ require(calculatedOrigin == origin, s"getConnectOrigin returns incorrect origin $calculatedOrigin !")
+ }
+
+ val tests = List(
+ """mux(a, UInt<1>("h1"), UInt<1>("h0"))""" -> "a",
+ """mux(UInt<1>("h1"), a, b)""" -> "a",
+ """mux(UInt<1>("h0"), a, b)""" -> "b",
+ "mux(b, a, a)" -> "a",
+ """mux(a, a, UInt<1>("h0"))""" -> "a",
+ "mux(a, b, e)" -> "mux(a, b, e)",
+ """or(a, UInt<1>("h1"))""" -> """UInt<1>("h1")""",
+ """and(a, UInt<1>("h0"))""" -> """UInt<1>("h0")""",
+ """UInt<1>("h1")""" -> """UInt<1>("h1")""",
+ "asUInt(a)" -> "a",
+ "asSInt(a)" -> "a",
+ "asClock(a)" -> "a",
+ "a" -> "a",
+ "or(a, b)" -> "or(a, b)",
+ "bits(a, 0, 0)" -> "a"
+ )
+
+ tests.foreach{ case(hurdle, origin) => checkConnectOrigin(hurdle, origin) }
+
+ }
+
"ReplSeqMem" should "generate blackbox wrappers (no wmask, r, w ports)" in {
val input = """
circuit sram6t :
@@ -116,7 +161,7 @@ circuit sram6t :
val writer = new java.io.StringWriter
execute(writer, aMap, input, check)
val confOut = read(confLoc)
- require(confOut==checkConf,"Conf file incorrect!")
+ require(confOut==checkConf, "Conf file incorrect!")
(new java.io.File(confLoc)).delete()
}
}