aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib')
-rw-r--r--src/main/scala/firrtl/passes/memlib/DecorateMems.scala5
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala101
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemConf.scala65
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemIR.scala56
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemLibOptions.scala3
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala16
-rw-r--r--src/main/scala/firrtl/passes/memlib/MemUtils.scala69
-rw-r--r--src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala43
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala118
-rw-r--r--src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala50
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala36
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala12
-rw-r--r--src/main/scala/firrtl/passes/memlib/ToMemIR.scala9
-rw-r--r--src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala136
-rw-r--r--src/main/scala/firrtl/passes/memlib/YamlUtils.scala15
15 files changed, 399 insertions, 335 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
index 14bd9e44..d237c36a 100644
--- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
+++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala
@@ -19,8 +19,9 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform
import CustomYAMLProtocol._
val configs = r.parse[Config]
val oldAnnos = state.annotations
- val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) =>
- (annos, pins :+ config.pin.name)
+ val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) {
+ case ((annos, pins), config) =>
+ (annos, pins :+ config.pin.name)
}
state.copy(annotations = PinAnnotation(pins.toSeq) +: as)
}
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 4847a698..e290633e 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -10,12 +10,11 @@ import firrtl.PrimOps._
import firrtl.Utils.{one, zero, BoolType}
import firrtl.options.{HasShellOptions, ShellOption}
import MemPortUtils.memPortField
-import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin}
+import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects}
import WrappedExpression.weq
import annotations._
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
-
case object InferReadWriteAnnotation extends NoTargetAnnotation
// This pass examine the enable signals of the read & write ports of memories
@@ -40,12 +39,13 @@ object InferReadWritePass extends Pass {
getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval)
// Visit each term of AND operation
case DoPrim(op, args, consts, tpe) if op == And =>
- e +: (args flatMap getProductTerms(connects))
+ e +: (args.flatMap(getProductTerms(connects)))
// Visit connected nodes to references
- case _: WRef | _: WSubField | _: WSubIndex => connects get e match {
- case None => Seq(e)
- case Some(ex) => e +: getProductTerms(connects)(ex)
- }
+ case _: WRef | _: WSubField | _: WSubIndex =>
+ connects.get(e) match {
+ case None => Seq(e)
+ case Some(ex) => e +: getProductTerms(connects)(ex)
+ }
// Otherwise just return itself
case _ => Seq(e)
}
@@ -58,96 +58,103 @@ object InferReadWritePass extends Pass {
// b ?= Eq(a, 0) or b ?= Eq(0, a)
case (_, DoPrim(Eq, args, _, _)) =>
weq(args.head, a) && weq(args(1), zero) ||
- weq(args(1), a) && weq(args.head, zero)
+ weq(args(1), a) && weq(args.head, zero)
// a ?= Eq(b, 0) or b ?= Eq(0, a)
case (DoPrim(Eq, args, _, _), _) =>
weq(args.head, b) && weq(args(1), zero) ||
- weq(args(1), b) && weq(args.head, zero)
+ weq(args(1), b) && weq(args.head, zero)
case _ => false
}
-
def replaceExp(repl: Netlist)(e: Expression): Expression =
- e map replaceExp(repl) match {
- case ex: WSubField => repl getOrElse (ex.serialize, ex)
+ e.map(replaceExp(repl)) match {
+ case ex: WSubField => repl.getOrElse(ex.serialize, ex)
case ex => ex
}
def replaceStmt(repl: Netlist)(s: Statement): Statement =
- s map replaceStmt(repl) map replaceExp(repl) match {
+ s.map(replaceStmt(repl)).map(replaceExp(repl)) match {
case Connect(_, EmptyExpression, _) => EmptyStmt
- case sx => sx
+ case sx => sx
}
- def inferReadWriteStmt(connects: Connects,
- repl: Netlist,
- stmts: Statements)
- (s: Statement): Statement = s match {
+ def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match {
// infer readwrite ports only for non combinational memories
case mem: DefMemory if mem.readLatency > 0 =>
val readers = new PortSet
val writers = new PortSet
val readwriters = collection.mutable.ArrayBuffer[String]()
val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
- for (w <- mem.writers ; r <- mem.readers) {
+ for {
+ w <- mem.writers
+ r <- mem.readers
+ } {
val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en"))
val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en"))
- val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b)))
+ val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b)))
val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) {
- val rw = namespace newName "rw"
+ val rw = namespace.newName("rw")
val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
readers += r
writers += w
- repl(memPortField(mem, r, "clk")) = EmptyExpression
- repl(memPortField(mem, r, "en")) = EmptyExpression
+ repl(memPortField(mem, r, "clk")) = EmptyExpression
+ repl(memPortField(mem, r, "en")) = EmptyExpression
repl(memPortField(mem, r, "addr")) = EmptyExpression
repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
- repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = EmptyExpression
+ repl(memPortField(mem, w, "clk")) = EmptyExpression
+ repl(memPortField(mem, w, "en")) = EmptyExpression
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get)
stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
- stmts += Connect(NoInfo, WSubField(rwExp, "en"),
- DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
- connects(memPortField(mem, w, "en"))), Nil, BoolType))
- stmts += Connect(NoInfo, WSubField(rwExp, "addr"),
- Mux(connects(memPortField(mem, w, "en")),
- connects(memPortField(mem, w, "addr")),
- connects(memPortField(mem, r, "addr")), UnknownType))
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "en"),
+ DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType)
+ )
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "addr"),
+ Mux(
+ connects(memPortField(mem, w, "en")),
+ connects(memPortField(mem, w, "addr")),
+ connects(memPortField(mem, r, "addr")),
+ UnknownType
+ )
+ )
}
}
- if (readwriters.isEmpty) mem else mem copy (
- readers = mem.readers filterNot readers,
- writers = mem.writers filterNot writers,
- readwriters = mem.readwriters ++ readwriters)
- case sx => sx map inferReadWriteStmt(connects, repl, stmts)
+ if (readwriters.isEmpty) mem
+ else
+ mem.copy(
+ readers = mem.readers.filterNot(readers),
+ writers = mem.writers.filterNot(writers),
+ readwriters = mem.readwriters ++ readwriters
+ )
+ case sx => sx.map(inferReadWriteStmt(connects, repl, stmts))
}
def inferReadWrite(m: DefModule) = {
val connects = getConnects(m)
val repl = new Netlist
val stmts = new Statements
- (m map inferReadWriteStmt(connects, repl, stmts)
- map replaceStmt(repl)) match {
+ (m.map(inferReadWriteStmt(connects, repl, stmts))
+ .map(replaceStmt(repl))) match {
case m: ExtModule => m
- case m: Module => m copy (body = Block(m.body +: stmts.toSeq))
+ case m: Module => m.copy(body = Block(m.body +: stmts.toSeq))
}
}
- def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite))
}
// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
// To use this transform, circuit name should be annotated with its TransId.
-class InferReadWrite extends Transform
- with DependencyAPIMigration
- with SeqTransformBased
- with HasShellOptions {
+class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions {
override def prerequisites = Forms.MidForm
override def optionalPrerequisites = Seq.empty
@@ -159,7 +166,9 @@ class InferReadWrite extends Transform
longOption = "infer-rw",
toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)),
helpText = "Enable read/write port inference for memories",
- shortOption = Some("firw") ) )
+ shortOption = Some("firw")
+ )
+ )
def transforms = Seq(
InferReadWritePass,
diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala
index 3809c47c..871a1093 100644
--- a/src/main/scala/firrtl/passes/memlib/MemConf.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala
@@ -3,7 +3,6 @@
package firrtl.passes
package memlib
-
sealed abstract class MemPort(val name: String) { override def toString = name }
case object ReadPort extends MemPort("read")
@@ -19,22 +18,27 @@ object MemPort {
def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s)
def fromString(s: String): Map[MemPort, Int] = {
- s.split(",").toSeq.map(MemPort.apply).map(_ match {
- case Some(x) => x
- case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
- }).groupBy(identity).mapValues(_.size).toMap
+ s.split(",")
+ .toSeq
+ .map(MemPort.apply)
+ .map(_ match {
+ case Some(x) => x
+ case _ => throw new Exception(s"Error parsing MemPort string : ${s}")
+ })
+ .groupBy(identity)
+ .mapValues(_.size)
+ .toMap
}
}
case class MemConf(
- name: String,
- depth: BigInt,
- width: Int,
- ports: Map[MemPort, Int],
- maskGranularity: Option[Int]
-) {
+ name: String,
+ depth: BigInt,
+ width: Int,
+ ports: Map[MemPort, Int],
+ maskGranularity: Option[Int]) {
- private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",")
+ private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",")
private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("")
// Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes
@@ -49,21 +53,34 @@ object MemConf {
val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r
def fromString(s: String): Seq[MemConf] = {
- s.split("\n").toSeq.map(_ match {
- case MemConf.regex(name, depth, width, ports, maskGran) => Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt)))
- case "" => None
- case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
- }).flatten
+ s.split("\n")
+ .toSeq
+ .map(_ match {
+ case MemConf.regex(name, depth, width, ports, maskGran) =>
+ Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt)))
+ case "" => None
+ case _ => throw new Exception(s"Error parsing MemConf string : ${s}")
+ })
+ .flatten
}
- def apply(name: String, depth: BigInt, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = {
+ def apply(
+ name: String,
+ depth: BigInt,
+ width: Int,
+ readPorts: Int,
+ writePorts: Int,
+ readWritePorts: Int,
+ maskGranularity: Option[Int]
+ ): MemConf = {
val ports: Seq[(MemPort, Int)] = (if (maskGranularity.isEmpty) {
- (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++
- (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts))
- } else {
- (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++
- (if (readWritePorts == 0) Seq() else Seq(MaskedReadWritePort -> readWritePorts))
- }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts))
+ (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts))
+ } else {
+ (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++
+ (if (readWritePorts == 0) Seq()
+ else Seq(MaskedReadWritePort -> readWritePorts))
+ }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts))
new MemConf(name, depth, width, ports.toMap, maskGranularity)
}
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala
index 3731ea86..c8cd3e8d 100644
--- a/src/main/scala/firrtl/passes/memlib/MemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala
@@ -19,38 +19,38 @@ object DefAnnotatedMemory {
m.readwriters,
m.readUnderWrite,
None, // mask granularity annotation
- None // No reference yet to another memory
+ None // No reference yet to another memory
)
}
}
case class DefAnnotatedMemory(
- info: Info,
- name: String,
- dataType: Type,
- depth: BigInt,
- writeLatency: Int,
- readLatency: Int,
- readers: Seq[String],
- writers: Seq[String],
- readwriters: Seq[String],
- readUnderWrite: ReadUnderWrite.Value,
- maskGran: Option[BigInt],
- memRef: Option[(String, String)] /* (Module, Mem) */
- //pins: Seq[Pin],
- ) extends Statement with IsDeclaration {
+ info: Info,
+ name: String,
+ dataType: Type,
+ depth: BigInt,
+ writeLatency: Int,
+ readLatency: Int,
+ readers: Seq[String],
+ writers: Seq[String],
+ readwriters: Seq[String],
+ readUnderWrite: ReadUnderWrite.Value,
+ maskGran: Option[BigInt],
+ memRef: Option[(String, String)] /* (Module, Mem) */
+ //pins: Seq[Pin],
+) extends Statement
+ with IsDeclaration {
override def serialize: String = this.toMem.serialize
- def mapStmt(f: Statement => Statement): Statement = this
- def mapExpr(f: Expression => Expression): Statement = this
- def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
- def mapString(f: String => String): Statement = this.copy(name = f(name))
- def toMem = DefMemory(info, name, dataType, depth,
- writeLatency, readLatency, readers, writers,
- readwriters, readUnderWrite)
- def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
- def foreachStmt(f: Statement => Unit): Unit = ()
- def foreachExpr(f: Expression => Unit): Unit = ()
- def foreachType(f: Type => Unit): Unit = f(dataType)
- def foreachString(f: String => Unit): Unit = f(name)
- def foreachInfo(f: Info => Unit): Unit = f(info)
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement = this
+ def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def toMem =
+ DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, readwriters, readUnderWrite)
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = ()
+ def foreachType(f: Type => Unit): Unit = f(dataType)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
index f0c9ebf4..1db132f7 100644
--- a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala
@@ -7,8 +7,7 @@ import firrtl.options.{RegisteredLibrary, ShellOption}
class MemLibOptions extends RegisteredLibrary {
val name: String = "MemLib Options"
- val options: Seq[ShellOption[_]] = Seq( new InferReadWrite,
- new ReplSeqMem )
+ val options: Seq[ShellOption[_]] = Seq(new InferReadWrite, new ReplSeqMem)
.flatMap(_.options)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
index b6a9a23d..f153fa2b 100644
--- a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala
@@ -11,12 +11,12 @@ import MemPortUtils.{MemPortMap}
object MemTransformUtils {
/** Replaces references to old memory port names with new memory port names
- */
+ */
def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = {
//TODO(izraelevitz): check speed
def updateRef(e: Expression): Expression = {
- val ex = e map updateRef
- repl getOrElse (ex.serialize, ex)
+ val ex = e.map(updateRef)
+ repl.getOrElse(ex.serialize, ex)
}
def hasEmptyExpr(stmt: Statement): Boolean = {
@@ -24,16 +24,16 @@ object MemTransformUtils {
def testEmptyExpr(e: Expression): Expression = {
e match {
case EmptyExpression => foundEmpty = true
- case _ =>
+ case _ =>
}
- e map testEmptyExpr // map must return; no foreach
+ e.map(testEmptyExpr) // map must return; no foreach
}
- stmt map testEmptyExpr
+ stmt.map(testEmptyExpr)
foundEmpty
}
def updateStmtRefs(s: Statement): Statement =
- s map updateStmtRefs map updateRef match {
+ s.map(updateStmtRefs).map(updateRef) match {
case c: Connect if hasEmptyExpr(c) => EmptyStmt
case s => s
}
@@ -42,6 +42,6 @@ object MemTransformUtils {
}
def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem)
- def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField =
+ def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField =
MemPortUtils.memPortField(s.toMem, p, f)
}
diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
index 69c6b284..f325c0ba 100644
--- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala
@@ -7,19 +7,19 @@ import firrtl.ir._
import firrtl.Utils._
/** Given a mask, return a bitmask corresponding to the desired datatype.
- * Requirements:
- * - The mask type and datatype must be equivalent, except any ground type in
- * datatype must be matched by a 1-bit wide UIntType.
- * - The mask must be a reference, subfield, or subindex
- * The bitmask is a series of concatenations of the single mask bit over the
- * length of the corresponding ground type, e.g.:
- *{{{
- * wire mask: {x: UInt<1>, y: UInt<1>}
- * wire data: {x: UInt<2>, y: SInt<2>}
- * // this would return:
- * cat(cat(mask.x, mask.x), cat(mask.y, mask.y))
- * }}}
- */
+ * Requirements:
+ * - The mask type and datatype must be equivalent, except any ground type in
+ * datatype must be matched by a 1-bit wide UIntType.
+ * - The mask must be a reference, subfield, or subindex
+ * The bitmask is a series of concatenations of the single mask bit over the
+ * length of the corresponding ground type, e.g.:
+ * {{{
+ * wire mask: {x: UInt<1>, y: UInt<1>}
+ * wire data: {x: UInt<2>, y: SInt<2>}
+ * // this would return:
+ * cat(cat(mask.x, mask.x), cat(mask.y, mask.y))
+ * }}}
+ */
object toBitMask {
def apply(mask: Expression, dataType: Type): Expression = mask match {
case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType)
@@ -28,12 +28,13 @@ object toBitMask {
private def hiermask(mask: Expression, dataType: Type): Expression =
(mask.tpe, dataType) match {
case (mt: VectorType, dt: VectorType) =>
- seqCat((0 until mt.size).reverse map { i =>
+ seqCat((0 until mt.size).reverse.map { i =>
hiermask(WSubIndex(mask, i, mt.tpe, UnknownFlow), dt.tpe)
})
case (mt: BundleType, dt: BundleType) =>
- seqCat((mt.fields zip dt.fields) map { case (mf, df) =>
- hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe)
+ seqCat((mt.fields.zip(dt.fields)).map {
+ case (mf, df) =>
+ hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe)
})
case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) =>
seqCat(List.fill(bitWidth(dt).intValue)(mask))
@@ -44,7 +45,7 @@ object toBitMask {
object createMask {
def apply(dt: Type): Type = dt match {
case t: VectorType => VectorType(apply(t.tpe), t.size)
- case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe))))
+ case t: BundleType => BundleType(t.fields.map(f => f.copy(tpe = apply(f.tpe))))
case GroundType(w) if w == IntWidth(0) => UIntType(IntWidth(0))
case t: GroundType => BoolType
}
@@ -56,27 +57,33 @@ object MemPortUtils {
type Modules = collection.mutable.ArrayBuffer[DefModule]
def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq(
- Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))),
+ Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1).max(1)))),
Field("en", Default, BoolType),
Field("clk", Default, ClockType)
)
// Todo: merge it with memToBundle
def memType(mem: DefMemory): BundleType = {
- val rType = BundleType(defaultPortSeq(mem) :+
- Field("data", Flip, mem.dataType))
- val wType = BundleType(defaultPortSeq(mem) ++ Seq(
- Field("data", Default, mem.dataType),
- Field("mask", Default, createMask(mem.dataType))))
- val rwType = BundleType(defaultPortSeq(mem) ++ Seq(
- Field("rdata", Flip, mem.dataType),
- Field("wmode", Default, BoolType),
- Field("wdata", Default, mem.dataType),
- Field("wmask", Default, createMask(mem.dataType))))
+ val rType = BundleType(
+ defaultPortSeq(mem) :+
+ Field("data", Flip, mem.dataType)
+ )
+ val wType = BundleType(
+ defaultPortSeq(mem) ++ Seq(Field("data", Default, mem.dataType), Field("mask", Default, createMask(mem.dataType)))
+ )
+ val rwType = BundleType(
+ defaultPortSeq(mem) ++ Seq(
+ Field("rdata", Flip, mem.dataType),
+ Field("wmode", Default, BoolType),
+ Field("wdata", Default, mem.dataType),
+ Field("wmask", Default, createMask(mem.dataType))
+ )
+ )
BundleType(
- (mem.readers map (Field(_, Flip, rType))) ++
- (mem.writers map (Field(_, Flip, wType))) ++
- (mem.readwriters map (Field(_, Flip, rwType))))
+ (mem.readers.map(Field(_, Flip, rType))) ++
+ (mem.writers.map(Field(_, Flip, wType))) ++
+ (mem.readwriters.map(Field(_, Flip, rwType)))
+ )
}
def memPortField(s: DefMemory, p: String, f: String): WSubField = {
diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
index c51a0adc..30529119 100644
--- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
+++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala
@@ -9,27 +9,27 @@ import firrtl.Mappers._
import MemPortUtils._
import MemTransformUtils._
-
/** Changes memory port names to standard port names (i.e. RW0 instead T_408)
- */
+ */
object RenameAnnotatedMemoryPorts extends Pass {
+
/** Renames memory ports to a standard naming scheme:
- * - R0, R1, ... for each read port
- * - W0, W1, ... for each write port
- * - RW0, RW1, ... for each readwrite port
- */
+ * - R0, R1, ... for each read port
+ * - W0, W1, ... for each write port
+ * - RW0, RW1, ... for each readwrite port
+ */
def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = {
- val rports = m.readers.indices map (i => s"R$i")
- val wports = m.writers.indices map (i => s"W$i")
- val rwports = m.readwriters.indices map (i => s"RW$i")
- m copy (readers = rports, writers = wports, readwriters = rwports)
+ val rports = m.readers.indices.map(i => s"R$i")
+ val wports = m.writers.indices.map(i => s"W$i")
+ val rwports = m.readwriters.indices.map(i => s"RW$i")
+ m.copy(readers = rports, writers = wports, readwriters = rwports)
}
/** Maps the serialized form of all memory port field names to the
- * corresponding new memory port field Expression.
- * E.g.:
- * - ("m.read.addr") becomes (m.R0.addr)
- */
+ * corresponding new memory port field Expression.
+ * E.g.:
+ * - ("m.read.addr") becomes (m.R0.addr)
+ */
def getMemPortMap(m: DefAnnotatedMemory, memPortMap: MemPortMap): Unit = {
val defaultFields = Seq("addr", "en", "clk")
val rFields = defaultFields :+ "data"
@@ -37,7 +37,10 @@ object RenameAnnotatedMemoryPorts extends Pass {
val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask")
def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit =
- for ((p, i) <- ports.zipWithIndex; f <- fields) {
+ for {
+ (p, i) <- ports.zipWithIndex
+ f <- fields
+ } {
val newPort = WSubField(WRef(m.name), newPortKind + i)
val field = WSubField(newPort, f)
memPortMap(s"${m.name}.$p.$f") = field
@@ -55,16 +58,16 @@ object RenameAnnotatedMemoryPorts extends Pass {
val updatedMem = createMemProto(m)
getMemPortMap(m, memPortMap)
updatedMem
- case s => s map updateMemStmts(memPortMap)
+ case s => s.map(updateMemStmts(memPortMap))
}
/** Replaces candidate memories and their references with standard port names
- */
+ */
def updateMemMods(m: DefModule) = {
val memPortMap = new MemPortMap
- (m map updateMemStmts(memPortMap)
- map updateStmtRefs(memPortMap))
+ (m.map(updateMemStmts(memPortMap))
+ .map(updateStmtRefs(memPortMap)))
}
- def run(c: Circuit) = c copy (modules = c.modules map updateMemMods)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(updateMemMods))
}
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
index bfbc163a..fc381e88 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala
@@ -13,7 +13,6 @@ import firrtl.annotations._
import firrtl.stage.Forms
import wiring._
-
/** Annotates the name of the pins to add for WiringTransform */
case class PinAnnotation(pins: Seq[String]) extends NoTargetAnnotation
@@ -35,14 +34,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
/** Return true if mask granularity is per bit, false if per byte or unspecified
*/
private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match {
- case None => false
+ case None => false
case Some(v) => v == 1
}
private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType))
+ defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)
+ )
private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
- defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)))
+ defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))
+ )
/** Catch incorrect memory instantiations when there are masked memories with unsupported aggregate types.
*
@@ -82,7 +83,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
)
private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType(
(defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match {
- case None => Nil
+ case None => Nil
case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType)))
case Some(_) => {
checkMaskDatatype(mem)
@@ -111,7 +112,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
Field("wdata", Default, flattenType(mem.dataType)),
Field("rdata", Flip, flattenType(mem.dataType))
) ++ (mem.maskGran match {
- case None => Nil
+ case None => Nil
case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType)))
case Some(_) => {
checkMaskDatatype(mem)
@@ -122,32 +123,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
def memToBundle(s: DefAnnotatedMemory) = BundleType(
s.readers.map(Field(_, Flip, rPortToBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToBundle(s))))
+ s.writers.map(Field(_, Flip, wPortToBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))
+ )
def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType(
s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++
- s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
- s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))))
+ s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++
+ s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))
+ )
/** Creates a wrapper module and external module to replace a candidate memory
- * The wrapper module has the same type as the memory it replaces
- * The external module
- */
+ * The wrapper module has the same type as the memory it replaces
+ * The external module
+ */
def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = {
assert(m.dataType != UnknownType)
val wrapperIoType = memToBundle(m)
- val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val wrapperIoPorts = wrapperIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
// Creates a type with the write/readwrite masks omitted if necessary
val bbIoType = memToFlattenBundle(m)
- val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe))
+ val bbIoPorts = bbIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe))
val bbRef = WRef(m.name, bbIoType)
val hasMask = m.maskGran.isDefined
val fillMask = getFillWMask(m)
def portRef(p: String) = WRef(p, field_type(wrapperIoType, p))
val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++
- (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++
- (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++
- (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask)))
+ (m.readers.flatMap(r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++
+ (m.writers.flatMap(w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++
+ (m.readwriters.flatMap(rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask)))
val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts))
val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty)
// TODO: Annotate? -- use actual annotation map
@@ -160,16 +163,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
// TODO(shunshou): get rid of copy pasta
// Connects the clk, en, and addr fields from the wrapperPort to the bbPort
def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] =
- Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f))
+ Seq("clk", "en", "addr").map(f => connectFields(bbPort, f, wrapperPort, f))
// Generates mask bits (concatenates an aggregate to ground type)
// depending on mask granularity (# bits = data width / mask granularity)
def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression =
if (fillMask) toBitMask(mask, dataType) else toBits(mask)
- def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
+ def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] =
defaultConnects(wrapperPort, bbPort) :+
- fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data"))
+ fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data"))
def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = {
val wrapperData = WSubField(wrapperPort, "data")
@@ -177,11 +180,12 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData))
hasMask match {
case false => defaultSeq
- case true => defaultSeq :+ Connect(
- NoInfo,
- WSubField(bbPort, "mask"),
- maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
- )
+ case true =>
+ defaultSeq :+ Connect(
+ NoInfo,
+ WSubField(bbPort, "mask"),
+ maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask)
+ )
}
}
@@ -190,61 +194,67 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq(
fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")),
connectFields(bbPort, "wmode", wrapperPort, "wmode"),
- Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData)))
+ Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData))
+ )
hasMask match {
case false => defaultSeq
- case true => defaultSeq :+ Connect(
- NoInfo,
- WSubField(bbPort, "wmask"),
- maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
- )
+ case true =>
+ defaultSeq :+ Connect(
+ NoInfo,
+ WSubField(bbPort, "wmask"),
+ maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask)
+ )
}
}
/** Mapping from (module, memory name) pairs to blackbox names */
private type NameMap = collection.mutable.HashMap[(String, String), String]
+
/** Construct NameMap by assigning unique names for each memory blackbox */
def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = {
s match {
- case m: DefAnnotatedMemory => m.memRef match {
- case None => nameMap(mname -> m.name) = namespace newName m.name
- case Some(_) =>
- }
+ case m: DefAnnotatedMemory =>
+ m.memRef match {
+ case None => nameMap(mname -> m.name) = namespace.newName(m.name)
+ case Some(_) =>
+ }
case _ =>
}
- s map constructNameMap(namespace, nameMap, mname)
+ s.map(constructNameMap(namespace, nameMap, mname))
}
- def updateMemStmts(namespace: Namespace,
- nameMap: NameMap,
- mname: String,
- memPortMap: MemPortMap,
- memMods: Modules)
- (s: Statement): Statement = s match {
+ def updateMemStmts(
+ namespace: Namespace,
+ nameMap: NameMap,
+ mname: String,
+ memPortMap: MemPortMap,
+ memMods: Modules
+ )(s: Statement
+ ): Statement = s match {
case m: DefAnnotatedMemory =>
if (m.maskGran.isEmpty) {
- m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
- m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
+ m.writers.foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression }
+ m.readwriters.foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression }
}
m.memRef match {
case None =>
// prototype mem
val newWrapperName = nameMap(mname -> m.name)
- val newMemBBName = namespace newName s"${newWrapperName}_ext"
- val newMem = m copy (name = newMemBBName)
+ val newMemBBName = namespace.newName(s"${newWrapperName}_ext")
+ val newMem = m.copy(name = newMemBBName)
memMods ++= createMemModule(newMem, newWrapperName)
WDefInstance(m.info, m.name, newWrapperName, UnknownType)
case Some((module, mem)) =>
WDefInstance(m.info, m.name, nameMap(module -> mem), UnknownType)
}
- case sx => sx map updateMemStmts(namespace, nameMap, mname, memPortMap, memMods)
+ case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods))
}
def updateMemMods(namespace: Namespace, nameMap: NameMap, memMods: Modules)(m: DefModule) = {
val memPortMap = new MemPortMap
- (m map updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods)
- map updateStmtRefs(memPortMap))
+ (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods))
+ .map(updateStmtRefs(memPortMap)))
}
def execute(state: CircuitState): CircuitState = {
@@ -252,15 +262,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM
val namespace = Namespace(c)
val memMods = new Modules
val nameMap = new NameMap
- c.modules map (m => m map constructNameMap(namespace, nameMap, m.name))
- val modules = c.modules map updateMemMods(namespace, nameMap, memMods)
+ c.modules.map(m => m.map(constructNameMap(namespace, nameMap, m.name)))
+ val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods))
// print conf
writer.serialize()
val pannos = state.annotations.collect { case a: PinAnnotation => a }
val pins = pannos match {
- case Seq() => Nil
+ case Seq() => Nil
case Seq(PinAnnotation(pins)) => pins
- case _ => throwInternalError("Something went wrong")
+ case _ => throwInternalError("Something went wrong")
}
val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) =>
seq ++ memMods.collect {
diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
index 87321ea0..79e07640 100644
--- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
+++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala
@@ -7,7 +7,7 @@ import firrtl._
import firrtl.annotations._
import firrtl.options.{HasShellOptions, ShellOption}
import Utils.error
-import java.io.{File, CharArrayWriter, PrintWriter}
+import java.io.{CharArrayWriter, File, PrintWriter}
import wiring._
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
@@ -50,7 +50,15 @@ class ConfWriter(filename: String) {
// assert that we don't overflow going from BigInt to Int conversion
require(bitWidth(m.dataType) <= Int.MaxValue)
m.maskGran.foreach { case x => require(x <= Int.MaxValue) }
- val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt))
+ val conf = MemConf(
+ m.name,
+ m.depth,
+ bitWidth(m.dataType).toInt,
+ m.readers.length,
+ m.writers.length,
+ m.readwriters.length,
+ m.maskGran.map(_.toInt)
+ )
outputBuffer.append(conf.toString)
}
def serialize() = {
@@ -113,27 +121,31 @@ class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigrat
val options = Seq(
new ShellOption[String](
longOption = "repl-seq-mem",
- toAnnotationSeq = (a: String) => Seq( passes.memlib.ReplSeqMemAnnotation.parse(a),
- RunFirrtlTransformAnnotation(new ReplSeqMem) ),
+ toAnnotationSeq =
+ (a: String) => Seq(passes.memlib.ReplSeqMemAnnotation.parse(a), RunFirrtlTransformAnnotation(new ReplSeqMem)),
helpText = "Blackbox and emit a configuration file for each sequential memory",
shortOption = Some("frsq"),
- helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") ) )
+ helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>")
+ )
+ )
def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] =
- Seq(new SimpleMidTransform(Legalize),
- new SimpleMidTransform(ToMemIR),
- new SimpleMidTransform(ResolveMaskGranularity),
- new SimpleMidTransform(RenameAnnotatedMemoryPorts),
- new ResolveMemoryReference,
- new CreateMemoryAnnotations(inConfigFile),
- new ReplaceMemMacros(outConfigFile),
- new WiringTransform,
- new SimpleMidTransform(RemoveEmpty),
- new SimpleMidTransform(CheckInitialization),
- new SimpleMidTransform(InferTypes),
- Uniquify,
- new SimpleMidTransform(ResolveKinds),
- new SimpleMidTransform(ResolveFlows))
+ Seq(
+ new SimpleMidTransform(Legalize),
+ new SimpleMidTransform(ToMemIR),
+ new SimpleMidTransform(ResolveMaskGranularity),
+ new SimpleMidTransform(RenameAnnotatedMemoryPorts),
+ new ResolveMemoryReference,
+ new CreateMemoryAnnotations(inConfigFile),
+ new ReplaceMemMacros(outConfigFile),
+ new WiringTransform,
+ new SimpleMidTransform(RemoveEmpty),
+ new SimpleMidTransform(CheckInitialization),
+ new SimpleMidTransform(InferTypes),
+ Uniquify,
+ new SimpleMidTransform(ResolveKinds),
+ new SimpleMidTransform(ResolveFlows)
+ )
def execute(state: CircuitState): CircuitState = {
val annos = state.annotations.collect { case a: ReplSeqMemAnnotation => a }
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 41c47dce..434c7602 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -28,10 +28,10 @@ object AnalysisUtils {
connects(value.serialize) = WInvalid
case _ => // do nothing
}
- s map getConnects(connects)
+ s.map(getConnects(connects))
}
val connects = new Connects
- m map getConnects(connects)
+ m.map(getConnects(connects))
connects
}
@@ -56,8 +56,8 @@ object AnalysisUtils {
else if (weq(tvOrigin, fvOrigin)) tvOrigin
else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin
else e
- case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one
- case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero
+ case DoPrim(PrimOps.Or, args, consts, tpe) if args.exists(weq(_, one)) => one
+ case DoPrim(PrimOps.And, args, consts, tpe) if args.exists(weq(_, zero)) => zero
case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) =>
val extractionWidth = (msb - lsb) + 1
val nodeWidth = bitWidth(args.head.tpe)
@@ -69,10 +69,10 @@ object AnalysisUtils {
case ValidIf(cond, value, _) => getOrigin(connects)(value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
- connects get e.serialize match {
- case Some(ex) => getOrigin(connects)(ex)
- case None => e
- }
+ connects.get(e.serialize) match {
+ case Some(ex) => getOrigin(connects)(ex)
+ case None => e
+ }
case _ => e
}
}
@@ -90,10 +90,9 @@ object ResolveMaskGranularity extends Pass {
*/
def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = {
val wenOrigin = getOrigin(connects)(wen)
- val wmaskOrigin = connects.keys filter
- (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)}
+ val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) }
// all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking)
- val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one))
+ val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one))
if (redundantMask) None else Some(wmaskOrigin.size)
}
@@ -103,18 +102,17 @@ object ResolveMaskGranularity extends Pass {
def updateStmts(connects: Connects)(s: Statement): Statement = s match {
case m: DefAnnotatedMemory =>
val dataBits = bitWidth(m.dataType)
- val rwMasks = m.readwriters map (rw =>
- getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
- val wMasks = m.writers map (w =>
- getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
+ val rwMasks =
+ m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask")))
+ val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask")))
val maskGran = (rwMasks ++ wMasks).head match {
- case None => None
+ case None => None
case Some(maskBits) => Some(dataBits / maskBits)
}
m.copy(maskGran = maskGran)
- case sx => sx map updateStmts(connects)
+ case sx => sx.map(updateStmts(connects))
}
- def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m))
- def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems)
+ def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m)))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems))
}
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
index b5ff10c6..e80e0c4a 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala
@@ -14,7 +14,7 @@ case class NoDedupMemAnnotation(target: ComponentName) extends SingleTargetAnnot
}
/** Resolves annotation ref to memories that exactly match (except name) another memory
- */
+ */
class ResolveMemoryReference extends Transform with DependencyAPIMigration {
override def prerequisites = Forms.MidForm
@@ -45,10 +45,12 @@ class ResolveMemoryReference extends Transform with DependencyAPIMigration {
/** If a candidate memory is identical except for name to another, add an
* annotation that references the name of the other memory.
*/
- def updateMemStmts(mname: String,
- existingMems: AnnotatedMemories,
- noDedupMap: Map[String, Set[String]])
- (s: Statement): Statement = s match {
+ def updateMemStmts(
+ mname: String,
+ existingMems: AnnotatedMemories,
+ noDedupMap: Map[String, Set[String]]
+ )(s: Statement
+ ): Statement = s match {
// If not dedupable, no need to add to existing (since nothing can dedup with it)
// We just return the DefAnnotatedMemory as is in the default case below
case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) =>
diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
index 554a3572..9fe7f852 100644
--- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
+++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala
@@ -14,16 +14,17 @@ import firrtl.ir._
* - undefined read-under-write behavior
*/
object ToMemIR extends Pass {
+
/** Only annotate memories that are candidates for memory macro replacements
* i.e. rw, w + r (read, write 1 cycle delay) and read-under-write "undefined."
*/
import ReadUnderWrite._
def updateStmts(s: Statement): Statement = s match {
- case m @ DefMemory(_,_,_,_,1,1,r,w,rw,Undefined) if (w.length + rw.length) == 1 && r.length <= 1 =>
+ case m @ DefMemory(_, _, _, _, 1, 1, r, w, rw, Undefined) if (w.length + rw.length) == 1 && r.length <= 1 =>
DefAnnotatedMemory(m)
- case sx => sx map updateStmts
+ case sx => sx.map(updateStmts)
}
- def annotateModMems(m: DefModule) = m map updateStmts
- def run(c: Circuit) = c copy (modules = c.modules map annotateModMems)
+ def annotateModMems(m: DefModule) = m.map(updateStmts)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(annotateModMems))
}
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
index dd644323..a2b14343 100644
--- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
+++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
@@ -24,19 +24,19 @@ object MemDelayAndReadwriteTransformer {
case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect])
// Utilities for generating hardware
- def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType)
- def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType)
- def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r)
- def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe))
+ def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType)
+ def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType)
+ def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r)
+ def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe))
// Utilities for working with WithValid groups
def connect(l: WithValid, r: WithValid): Seq[Connect] = {
- val paired = (l.valid +: l.payload) zip (r.valid +: r.payload)
+ val paired = (l.valid +: l.payload).zip(r.valid +: r.payload)
paired.map { case (le, re) => connect(le, re) }
}
def condConnect(l: WithValid, r: WithValid): Seq[Connect] = {
- connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) }
+ connect(l.valid, r.valid) +: (l.payload.zip(r.payload)).map { case (le, re) => condConnect(r.valid)(le, re) }
}
// Internal representation of a pipeline stage with an associated valid signal
@@ -47,20 +47,23 @@ object MemDelayAndReadwriteTransformer {
private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_")
// Pipeline a group of signals with an associated valid signal. Gate registers when possible.
- def pipelineWithValid(ns: Namespace)(
- clock: Expression,
- depth: Int,
- src: WithValid,
- nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = {
+ def pipelineWithValid(
+ ns: Namespace
+ )(clock: Expression,
+ depth: Int,
+ src: WithValid,
+ nameTemplate: Option[WithValid] = None
+ ): (WithValid, Seq[Statement], Seq[Connect]) = {
def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e)
val template = nameTemplate.getOrElse(src)
- val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev =>
- def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind)
- val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef))
- val regs = (ref.valid +: ref.payload).map(asReg)
- PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref)))
+ val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) {
+ case prev =>
+ def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind)
+ val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef))
+ val regs = (ref.valid +: ref.payload).map(asReg)
+ PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref)))
}
(stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns))
}
@@ -84,10 +87,10 @@ class MemDelayAndReadwriteTransformer(m: DefModule) {
private def findMemConns(s: Statement): Unit = s match {
case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr
- case _ => s.foreach(findMemConns)
+ case _ => s.foreach(findMemConns)
}
- private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match {
+ private def swapMemRefs(e: Expression): Expression = e.map(swapMemRefs) match {
case sf: WSubField => exprReplacements.getOrElse(we(sf), sf)
case ex => ex
}
@@ -105,51 +108,57 @@ class MemDelayAndReadwriteTransformer(m: DefModule) {
val rRespDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) mem.readLatency else 0
val wCmdDelay = mem.writeLatency - 1
- val readStmts = (mem.readers ++ mem.readwriters).map { case r =>
- def oldDriver(f: String) = netlist(we(memPortField(mem, r, f)))
- def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f)
- val clk = oldDriver("clk")
-
- // Pack sources of read command inputs into WithValid object -> different for readwriter
- val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en")
- val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr")))
- val cmdSink = WithValid(newField("en"), Seq(newField("addr")))
- val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
- val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
-
- // Pipeline read response using *last* command pipe stage enable as the valid signal
- val resp = WithValid(cmdPiped.valid, Seq(newField("data")))
- val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names
- val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate)
-
- // Make sure references to the read data get appropriately substituted
- val oldRDataName = if (rMap.contains(r)) "rdata" else "data"
- exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head
-
- // Return all statements; they're separated so connects can go after all declarations
- SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns)
+ val readStmts = (mem.readers ++ mem.readwriters).map {
+ case r =>
+ def oldDriver(f: String) = netlist(we(memPortField(mem, r, f)))
+ def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f)
+ val clk = oldDriver("clk")
+
+ // Pack sources of read command inputs into WithValid object -> different for readwriter
+ val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en")
+ val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr")))
+ val cmdSink = WithValid(newField("en"), Seq(newField("addr")))
+ val (cmdPiped, cmdDecls, cmdConns) =
+ pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
+ val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
+
+ // Pipeline read response using *last* command pipe stage enable as the valid signal
+ val resp = WithValid(cmdPiped.valid, Seq(newField("data")))
+ val respPipeNameTemplate =
+ Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names
+ val (respPiped, respDecls, respConns) =
+ pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate)
+
+ // Make sure references to the read data get appropriately substituted
+ val oldRDataName = if (rMap.contains(r)) "rdata" else "data"
+ exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head
+
+ // Return all statements; they're separated so connects can go after all declarations
+ SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns)
}
- val writeStmts = (mem.writers ++ mem.readwriters).map { case w =>
- def oldDriver(f: String) = netlist(we(memPortField(mem, w, f)))
- def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f)
- val clk = oldDriver("clk")
-
- // Pack sources of write command inputs into WithValid object -> different for readwriter
- val cmdSrc = if (wMap.contains(w)) {
- val en = AND(oldDriver("en"), oldDriver("wmode"))
- WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata")))
- } else {
- WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data")))
- }
-
- // Pipeline write command, connect to memory
- val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data")))
- val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
- val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
-
- // Return all statements; they're separated so connects can go after all declarations
- SplitStatements(cmdDecls, cmdConns ++ cmdPortConns)
+ val writeStmts = (mem.writers ++ mem.readwriters).map {
+ case w =>
+ def oldDriver(f: String) = netlist(we(memPortField(mem, w, f)))
+ def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f)
+ val clk = oldDriver("clk")
+
+ // Pack sources of write command inputs into WithValid object -> different for readwriter
+ val cmdSrc = if (wMap.contains(w)) {
+ val en = AND(oldDriver("en"), oldDriver("wmode"))
+ WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata")))
+ } else {
+ WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data")))
+ }
+
+ // Pipeline write command, connect to memory
+ val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data")))
+ val (cmdPiped, cmdDecls, cmdConns) =
+ pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink))
+ val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk)
+
+ // Return all statements; they're separated so connects can go after all declarations
+ SplitStatements(cmdDecls, cmdConns ++ cmdPortConns)
}
newConns ++= (readStmts ++ writeStmts).flatMap(_.conns)
@@ -171,8 +180,7 @@ object VerilogMemDelays extends Pass {
override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf)
override val optionalPrerequisiteOf =
- Seq( Dependency[VerilogEmitter],
- Dependency[SystemVerilogEmitter] )
+ Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter])
override def invalidates(a: Transform): Boolean = a match {
case _: transforms.ConstantPropagation | ResolveFlows => true
@@ -180,5 +188,5 @@ object VerilogMemDelays extends Pass {
}
def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed
- def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform))
+ def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform))
}
diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
index a43adfe2..b5f91e7b 100644
--- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
+++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala
@@ -6,7 +6,6 @@ import net.jcazevedo.moultingyaml._
import java.io.{CharArrayWriter, File, PrintWriter}
import firrtl.FileUtils
-
object CustomYAMLProtocol extends DefaultYamlProtocol {
// bottom depends on top
implicit val _pin = yamlFormat1(Pin)
@@ -20,17 +19,15 @@ case class Source(name: String, module: String)
case class Top(name: String)
case class Config(pin: Pin, source: Source, top: Top)
-
class YamlFileReader(file: String) {
- def parse[A](implicit reader: YamlReader[A]) : Seq[A] = {
+ def parse[A](implicit reader: YamlReader[A]): Seq[A] = {
if (new File(file).exists) {
val yamlString = FileUtils.getText(file)
- yamlString.parseYamls flatMap (x =>
- try Some(reader read x)
+ yamlString.parseYamls.flatMap(x =>
+ try Some(reader.read(x))
catch { case e: Exception => None }
)
- }
- else sys.error("Yaml file doesn't exist!")
+ } else sys.error("Yaml file doesn't exist!")
}
}
@@ -38,11 +35,11 @@ class YamlFileWriter(file: String) {
val outputBuffer = new CharArrayWriter
val separator = "--- \n"
def append(in: YamlValue): Unit = {
- outputBuffer append s"$separator${in.prettyPrint}"
+ outputBuffer.append(s"$separator${in.prettyPrint}")
}
def dump(): Unit = {
val outputFile = new PrintWriter(file)
- outputFile write outputBuffer.toString
+ outputFile.write(outputBuffer.toString)
outputFile.close()
}
}