diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes/memlib | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib')
15 files changed, 399 insertions, 335 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala index 14bd9e44..d237c36a 100644 --- a/src/main/scala/firrtl/passes/memlib/DecorateMems.scala +++ b/src/main/scala/firrtl/passes/memlib/DecorateMems.scala @@ -19,8 +19,9 @@ class CreateMemoryAnnotations(reader: Option[YamlFileReader]) extends Transform import CustomYAMLProtocol._ val configs = r.parse[Config] val oldAnnos = state.annotations - val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { case ((annos, pins), config) => - (annos, pins :+ config.pin.name) + val (as, pins) = configs.foldLeft((oldAnnos, Seq.empty[String])) { + case ((annos, pins), config) => + (annos, pins :+ config.pin.name) } state.copy(annotations = PinAnnotation(pins.toSeq) +: as) } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 4847a698..e290633e 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -10,12 +10,11 @@ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} import firrtl.options.{HasShellOptions, ShellOption} import MemPortUtils.memPortField -import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} +import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects} import WrappedExpression.weq import annotations._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} - case object InferReadWriteAnnotation extends NoTargetAnnotation // This pass examine the enable signals of the read & write ports of memories @@ -40,12 +39,13 @@ object InferReadWritePass extends Pass { getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval) // Visit each term of AND operation case DoPrim(op, args, consts, tpe) if op == And => - e +: (args flatMap getProductTerms(connects)) + e +: (args.flatMap(getProductTerms(connects))) // Visit connected nodes to references - case _: WRef | _: WSubField | _: WSubIndex => connects get e match { - case None => Seq(e) - case Some(ex) => e +: getProductTerms(connects)(ex) - } + case _: WRef | _: WSubField | _: WSubIndex => + connects.get(e) match { + case None => Seq(e) + case Some(ex) => e +: getProductTerms(connects)(ex) + } // Otherwise just return itself case _ => Seq(e) } @@ -58,96 +58,103 @@ object InferReadWritePass extends Pass { // b ?= Eq(a, 0) or b ?= Eq(0, a) case (_, DoPrim(Eq, args, _, _)) => weq(args.head, a) && weq(args(1), zero) || - weq(args(1), a) && weq(args.head, zero) + weq(args(1), a) && weq(args.head, zero) // a ?= Eq(b, 0) or b ?= Eq(0, a) case (DoPrim(Eq, args, _, _), _) => weq(args.head, b) && weq(args(1), zero) || - weq(args(1), b) && weq(args.head, zero) + weq(args(1), b) && weq(args.head, zero) case _ => false } - def replaceExp(repl: Netlist)(e: Expression): Expression = - e map replaceExp(repl) match { - case ex: WSubField => repl getOrElse (ex.serialize, ex) + e.map(replaceExp(repl)) match { + case ex: WSubField => repl.getOrElse(ex.serialize, ex) case ex => ex } def replaceStmt(repl: Netlist)(s: Statement): Statement = - s map replaceStmt(repl) map replaceExp(repl) match { + s.map(replaceStmt(repl)).map(replaceExp(repl)) match { case Connect(_, EmptyExpression, _) => EmptyStmt - case sx => sx + case sx => sx } - def inferReadWriteStmt(connects: Connects, - repl: Netlist, - stmts: Statements) - (s: Statement): Statement = s match { + def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match { // infer readwrite ports only for non combinational memories case mem: DefMemory if mem.readLatency > 0 => val readers = new PortSet val writers = new PortSet val readwriters = collection.mutable.ArrayBuffer[String]() val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) - for (w <- mem.writers ; r <- mem.readers) { + for { + w <- mem.writers + r <- mem.readers + } { val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en")) val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en")) - val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b))) + val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b))) val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) { - val rw = namespace newName "rw" + val rw = namespace.newName("rw") val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw readers += r writers += w - repl(memPortField(mem, r, "clk")) = EmptyExpression - repl(memPortField(mem, r, "en")) = EmptyExpression + repl(memPortField(mem, r, "clk")) = EmptyExpression + repl(memPortField(mem, r, "en")) = EmptyExpression repl(memPortField(mem, r, "addr")) = EmptyExpression repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") - repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = EmptyExpression + repl(memPortField(mem, w, "clk")) = EmptyExpression + repl(memPortField(mem, w, "en")) = EmptyExpression repl(memPortField(mem, w, "addr")) = EmptyExpression repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk) - stmts += Connect(NoInfo, WSubField(rwExp, "en"), - DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), - connects(memPortField(mem, w, "en"))), Nil, BoolType)) - stmts += Connect(NoInfo, WSubField(rwExp, "addr"), - Mux(connects(memPortField(mem, w, "en")), - connects(memPortField(mem, w, "addr")), - connects(memPortField(mem, r, "addr")), UnknownType)) + stmts += Connect( + NoInfo, + WSubField(rwExp, "en"), + DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType) + ) + stmts += Connect( + NoInfo, + WSubField(rwExp, "addr"), + Mux( + connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), + UnknownType + ) + ) } } - if (readwriters.isEmpty) mem else mem copy ( - readers = mem.readers filterNot readers, - writers = mem.writers filterNot writers, - readwriters = mem.readwriters ++ readwriters) - case sx => sx map inferReadWriteStmt(connects, repl, stmts) + if (readwriters.isEmpty) mem + else + mem.copy( + readers = mem.readers.filterNot(readers), + writers = mem.writers.filterNot(writers), + readwriters = mem.readwriters ++ readwriters + ) + case sx => sx.map(inferReadWriteStmt(connects, repl, stmts)) } def inferReadWrite(m: DefModule) = { val connects = getConnects(m) val repl = new Netlist val stmts = new Statements - (m map inferReadWriteStmt(connects, repl, stmts) - map replaceStmt(repl)) match { + (m.map(inferReadWriteStmt(connects, repl, stmts)) + .map(replaceStmt(repl))) match { case m: ExtModule => m - case m: Module => m copy (body = Block(m.body +: stmts.toSeq)) + case m: Module => m.copy(body = Block(m.body +: stmts.toSeq)) } } - def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) + def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite)) } // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform - with DependencyAPIMigration - with SeqTransformBased - with HasShellOptions { +class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions { override def prerequisites = Forms.MidForm override def optionalPrerequisites = Seq.empty @@ -159,7 +166,9 @@ class InferReadWrite extends Transform longOption = "infer-rw", toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)), helpText = "Enable read/write port inference for memories", - shortOption = Some("firw") ) ) + shortOption = Some("firw") + ) + ) def transforms = Seq( InferReadWritePass, diff --git a/src/main/scala/firrtl/passes/memlib/MemConf.scala b/src/main/scala/firrtl/passes/memlib/MemConf.scala index 3809c47c..871a1093 100644 --- a/src/main/scala/firrtl/passes/memlib/MemConf.scala +++ b/src/main/scala/firrtl/passes/memlib/MemConf.scala @@ -3,7 +3,6 @@ package firrtl.passes package memlib - sealed abstract class MemPort(val name: String) { override def toString = name } case object ReadPort extends MemPort("read") @@ -19,22 +18,27 @@ object MemPort { def apply(s: String): Option[MemPort] = MemPort.all.find(_.name == s) def fromString(s: String): Map[MemPort, Int] = { - s.split(",").toSeq.map(MemPort.apply).map(_ match { - case Some(x) => x - case _ => throw new Exception(s"Error parsing MemPort string : ${s}") - }).groupBy(identity).mapValues(_.size).toMap + s.split(",") + .toSeq + .map(MemPort.apply) + .map(_ match { + case Some(x) => x + case _ => throw new Exception(s"Error parsing MemPort string : ${s}") + }) + .groupBy(identity) + .mapValues(_.size) + .toMap } } case class MemConf( - name: String, - depth: BigInt, - width: Int, - ports: Map[MemPort, Int], - maskGranularity: Option[Int] -) { + name: String, + depth: BigInt, + width: Int, + ports: Map[MemPort, Int], + maskGranularity: Option[Int]) { - private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") } mkString (",") + private def portsStr = ports.map { case (port, num) => Seq.fill(num)(port.name).mkString(",") }.mkString(",") private def maskGranStr = maskGranularity.map((p) => s"mask_gran $p").getOrElse("") // Assert that all of the entries in the port map are greater than zero to make it easier to compare two of these case classes @@ -49,21 +53,34 @@ object MemConf { val regex = raw"\s*name\s+(\w+)\s+depth\s+(\d+)\s+width\s+(\d+)\s+ports\s+([^\s]+)\s+(?:mask_gran\s+(\d+))?\s*".r def fromString(s: String): Seq[MemConf] = { - s.split("\n").toSeq.map(_ match { - case MemConf.regex(name, depth, width, ports, maskGran) => Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))) - case "" => None - case _ => throw new Exception(s"Error parsing MemConf string : ${s}") - }).flatten + s.split("\n") + .toSeq + .map(_ match { + case MemConf.regex(name, depth, width, ports, maskGran) => + Some(MemConf(name, BigInt(depth), width.toInt, MemPort.fromString(ports), Option(maskGran).map(_.toInt))) + case "" => None + case _ => throw new Exception(s"Error parsing MemConf string : ${s}") + }) + .flatten } - def apply(name: String, depth: BigInt, width: Int, readPorts: Int, writePorts: Int, readWritePorts: Int, maskGranularity: Option[Int]): MemConf = { + def apply( + name: String, + depth: BigInt, + width: Int, + readPorts: Int, + writePorts: Int, + readWritePorts: Int, + maskGranularity: Option[Int] + ): MemConf = { val ports: Seq[(MemPort, Int)] = (if (maskGranularity.isEmpty) { - (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++ - (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts)) - } else { - (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++ - (if (readWritePorts == 0) Seq() else Seq(MaskedReadWritePort -> readWritePorts)) - }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts)) + (if (writePorts == 0) Seq() else Seq(WritePort -> writePorts)) ++ + (if (readWritePorts == 0) Seq() else Seq(ReadWritePort -> readWritePorts)) + } else { + (if (writePorts == 0) Seq() else Seq(MaskedWritePort -> writePorts)) ++ + (if (readWritePorts == 0) Seq() + else Seq(MaskedReadWritePort -> readWritePorts)) + }) ++ (if (readPorts == 0) Seq() else Seq(ReadPort -> readPorts)) new MemConf(name, depth, width, ports.toMap, maskGranularity) } } diff --git a/src/main/scala/firrtl/passes/memlib/MemIR.scala b/src/main/scala/firrtl/passes/memlib/MemIR.scala index 3731ea86..c8cd3e8d 100644 --- a/src/main/scala/firrtl/passes/memlib/MemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/MemIR.scala @@ -19,38 +19,38 @@ object DefAnnotatedMemory { m.readwriters, m.readUnderWrite, None, // mask granularity annotation - None // No reference yet to another memory + None // No reference yet to another memory ) } } case class DefAnnotatedMemory( - info: Info, - name: String, - dataType: Type, - depth: BigInt, - writeLatency: Int, - readLatency: Int, - readers: Seq[String], - writers: Seq[String], - readwriters: Seq[String], - readUnderWrite: ReadUnderWrite.Value, - maskGran: Option[BigInt], - memRef: Option[(String, String)] /* (Module, Mem) */ - //pins: Seq[Pin], - ) extends Statement with IsDeclaration { + info: Info, + name: String, + dataType: Type, + depth: BigInt, + writeLatency: Int, + readLatency: Int, + readers: Seq[String], + writers: Seq[String], + readwriters: Seq[String], + readUnderWrite: ReadUnderWrite.Value, + maskGran: Option[BigInt], + memRef: Option[(String, String)] /* (Module, Mem) */ + //pins: Seq[Pin], +) extends Statement + with IsDeclaration { override def serialize: String = this.toMem.serialize - def mapStmt(f: Statement => Statement): Statement = this - def mapExpr(f: Expression => Expression): Statement = this - def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) - def mapString(f: String => String): Statement = this.copy(name = f(name)) - def toMem = DefMemory(info, name, dataType, depth, - writeLatency, readLatency, readers, writers, - readwriters, readUnderWrite) - def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) - def foreachStmt(f: Statement => Unit): Unit = () - def foreachExpr(f: Expression => Unit): Unit = () - def foreachType(f: Type => Unit): Unit = f(dataType) - def foreachString(f: String => Unit): Unit = f(name) - def foreachInfo(f: Info => Unit): Unit = f(info) + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = this + def mapType(f: Type => Type): Statement = this.copy(dataType = f(dataType)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def toMem = + DefMemory(info, name, dataType, depth, writeLatency, readLatency, readers, writers, readwriters, readUnderWrite) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = () + def foreachType(f: Type => Unit): Unit = f(dataType) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) } diff --git a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala index f0c9ebf4..1db132f7 100644 --- a/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala +++ b/src/main/scala/firrtl/passes/memlib/MemLibOptions.scala @@ -7,8 +7,7 @@ import firrtl.options.{RegisteredLibrary, ShellOption} class MemLibOptions extends RegisteredLibrary { val name: String = "MemLib Options" - val options: Seq[ShellOption[_]] = Seq( new InferReadWrite, - new ReplSeqMem ) + val options: Seq[ShellOption[_]] = Seq(new InferReadWrite, new ReplSeqMem) .flatMap(_.options) } diff --git a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala index b6a9a23d..f153fa2b 100644 --- a/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemTransformUtils.scala @@ -11,12 +11,12 @@ import MemPortUtils.{MemPortMap} object MemTransformUtils { /** Replaces references to old memory port names with new memory port names - */ + */ def updateStmtRefs(repl: MemPortMap)(s: Statement): Statement = { //TODO(izraelevitz): check speed def updateRef(e: Expression): Expression = { - val ex = e map updateRef - repl getOrElse (ex.serialize, ex) + val ex = e.map(updateRef) + repl.getOrElse(ex.serialize, ex) } def hasEmptyExpr(stmt: Statement): Boolean = { @@ -24,16 +24,16 @@ object MemTransformUtils { def testEmptyExpr(e: Expression): Expression = { e match { case EmptyExpression => foundEmpty = true - case _ => + case _ => } - e map testEmptyExpr // map must return; no foreach + e.map(testEmptyExpr) // map must return; no foreach } - stmt map testEmptyExpr + stmt.map(testEmptyExpr) foundEmpty } def updateStmtRefs(s: Statement): Statement = - s map updateStmtRefs map updateRef match { + s.map(updateStmtRefs).map(updateRef) match { case c: Connect if hasEmptyExpr(c) => EmptyStmt case s => s } @@ -42,6 +42,6 @@ object MemTransformUtils { } def defaultPortSeq(mem: DefAnnotatedMemory): Seq[Field] = MemPortUtils.defaultPortSeq(mem.toMem) - def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField = + def memPortField(s: DefAnnotatedMemory, p: String, f: String): WSubField = MemPortUtils.memPortField(s.toMem, p, f) } diff --git a/src/main/scala/firrtl/passes/memlib/MemUtils.scala b/src/main/scala/firrtl/passes/memlib/MemUtils.scala index 69c6b284..f325c0ba 100644 --- a/src/main/scala/firrtl/passes/memlib/MemUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/MemUtils.scala @@ -7,19 +7,19 @@ import firrtl.ir._ import firrtl.Utils._ /** Given a mask, return a bitmask corresponding to the desired datatype. - * Requirements: - * - The mask type and datatype must be equivalent, except any ground type in - * datatype must be matched by a 1-bit wide UIntType. - * - The mask must be a reference, subfield, or subindex - * The bitmask is a series of concatenations of the single mask bit over the - * length of the corresponding ground type, e.g.: - *{{{ - * wire mask: {x: UInt<1>, y: UInt<1>} - * wire data: {x: UInt<2>, y: SInt<2>} - * // this would return: - * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) - * }}} - */ + * Requirements: + * - The mask type and datatype must be equivalent, except any ground type in + * datatype must be matched by a 1-bit wide UIntType. + * - The mask must be a reference, subfield, or subindex + * The bitmask is a series of concatenations of the single mask bit over the + * length of the corresponding ground type, e.g.: + * {{{ + * wire mask: {x: UInt<1>, y: UInt<1>} + * wire data: {x: UInt<2>, y: SInt<2>} + * // this would return: + * cat(cat(mask.x, mask.x), cat(mask.y, mask.y)) + * }}} + */ object toBitMask { def apply(mask: Expression, dataType: Type): Expression = mask match { case ex @ (_: WRef | _: WSubField | _: WSubIndex) => hiermask(ex, dataType) @@ -28,12 +28,13 @@ object toBitMask { private def hiermask(mask: Expression, dataType: Type): Expression = (mask.tpe, dataType) match { case (mt: VectorType, dt: VectorType) => - seqCat((0 until mt.size).reverse map { i => + seqCat((0 until mt.size).reverse.map { i => hiermask(WSubIndex(mask, i, mt.tpe, UnknownFlow), dt.tpe) }) case (mt: BundleType, dt: BundleType) => - seqCat((mt.fields zip dt.fields) map { case (mf, df) => - hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe) + seqCat((mt.fields.zip(dt.fields)).map { + case (mf, df) => + hiermask(WSubField(mask, mf.name, mf.tpe, UnknownFlow), df.tpe) }) case (UIntType(width), dt: GroundType) if width == IntWidth(BigInt(1)) => seqCat(List.fill(bitWidth(dt).intValue)(mask)) @@ -44,7 +45,7 @@ object toBitMask { object createMask { def apply(dt: Type): Type = dt match { case t: VectorType => VectorType(apply(t.tpe), t.size) - case t: BundleType => BundleType(t.fields map (f => f copy (tpe=apply(f.tpe)))) + case t: BundleType => BundleType(t.fields.map(f => f.copy(tpe = apply(f.tpe)))) case GroundType(w) if w == IntWidth(0) => UIntType(IntWidth(0)) case t: GroundType => BoolType } @@ -56,27 +57,33 @@ object MemPortUtils { type Modules = collection.mutable.ArrayBuffer[DefModule] def defaultPortSeq(mem: DefMemory): Seq[Field] = Seq( - Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1) max 1))), + Field("addr", Default, UIntType(IntWidth(getUIntWidth(mem.depth - 1).max(1)))), Field("en", Default, BoolType), Field("clk", Default, ClockType) ) // Todo: merge it with memToBundle def memType(mem: DefMemory): BundleType = { - val rType = BundleType(defaultPortSeq(mem) :+ - Field("data", Flip, mem.dataType)) - val wType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("data", Default, mem.dataType), - Field("mask", Default, createMask(mem.dataType)))) - val rwType = BundleType(defaultPortSeq(mem) ++ Seq( - Field("rdata", Flip, mem.dataType), - Field("wmode", Default, BoolType), - Field("wdata", Default, mem.dataType), - Field("wmask", Default, createMask(mem.dataType)))) + val rType = BundleType( + defaultPortSeq(mem) :+ + Field("data", Flip, mem.dataType) + ) + val wType = BundleType( + defaultPortSeq(mem) ++ Seq(Field("data", Default, mem.dataType), Field("mask", Default, createMask(mem.dataType))) + ) + val rwType = BundleType( + defaultPortSeq(mem) ++ Seq( + Field("rdata", Flip, mem.dataType), + Field("wmode", Default, BoolType), + Field("wdata", Default, mem.dataType), + Field("wmask", Default, createMask(mem.dataType)) + ) + ) BundleType( - (mem.readers map (Field(_, Flip, rType))) ++ - (mem.writers map (Field(_, Flip, wType))) ++ - (mem.readwriters map (Field(_, Flip, rwType)))) + (mem.readers.map(Field(_, Flip, rType))) ++ + (mem.writers.map(Field(_, Flip, wType))) ++ + (mem.readwriters.map(Field(_, Flip, rwType))) + ) } def memPortField(s: DefMemory, p: String, f: String): WSubField = { diff --git a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala index c51a0adc..30529119 100644 --- a/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala +++ b/src/main/scala/firrtl/passes/memlib/RenameAnnotatedMemoryPorts.scala @@ -9,27 +9,27 @@ import firrtl.Mappers._ import MemPortUtils._ import MemTransformUtils._ - /** Changes memory port names to standard port names (i.e. RW0 instead T_408) - */ + */ object RenameAnnotatedMemoryPorts extends Pass { + /** Renames memory ports to a standard naming scheme: - * - R0, R1, ... for each read port - * - W0, W1, ... for each write port - * - RW0, RW1, ... for each readwrite port - */ + * - R0, R1, ... for each read port + * - W0, W1, ... for each write port + * - RW0, RW1, ... for each readwrite port + */ def createMemProto(m: DefAnnotatedMemory): DefAnnotatedMemory = { - val rports = m.readers.indices map (i => s"R$i") - val wports = m.writers.indices map (i => s"W$i") - val rwports = m.readwriters.indices map (i => s"RW$i") - m copy (readers = rports, writers = wports, readwriters = rwports) + val rports = m.readers.indices.map(i => s"R$i") + val wports = m.writers.indices.map(i => s"W$i") + val rwports = m.readwriters.indices.map(i => s"RW$i") + m.copy(readers = rports, writers = wports, readwriters = rwports) } /** Maps the serialized form of all memory port field names to the - * corresponding new memory port field Expression. - * E.g.: - * - ("m.read.addr") becomes (m.R0.addr) - */ + * corresponding new memory port field Expression. + * E.g.: + * - ("m.read.addr") becomes (m.R0.addr) + */ def getMemPortMap(m: DefAnnotatedMemory, memPortMap: MemPortMap): Unit = { val defaultFields = Seq("addr", "en", "clk") val rFields = defaultFields :+ "data" @@ -37,7 +37,10 @@ object RenameAnnotatedMemoryPorts extends Pass { val rwFields = defaultFields ++ Seq("wmode", "wdata", "rdata", "wmask") def updateMemPortMap(ports: Seq[String], fields: Seq[String], newPortKind: String): Unit = - for ((p, i) <- ports.zipWithIndex; f <- fields) { + for { + (p, i) <- ports.zipWithIndex + f <- fields + } { val newPort = WSubField(WRef(m.name), newPortKind + i) val field = WSubField(newPort, f) memPortMap(s"${m.name}.$p.$f") = field @@ -55,16 +58,16 @@ object RenameAnnotatedMemoryPorts extends Pass { val updatedMem = createMemProto(m) getMemPortMap(m, memPortMap) updatedMem - case s => s map updateMemStmts(memPortMap) + case s => s.map(updateMemStmts(memPortMap)) } /** Replaces candidate memories and their references with standard port names - */ + */ def updateMemMods(m: DefModule) = { val memPortMap = new MemPortMap - (m map updateMemStmts(memPortMap) - map updateStmtRefs(memPortMap)) + (m.map(updateMemStmts(memPortMap)) + .map(updateStmtRefs(memPortMap))) } - def run(c: Circuit) = c copy (modules = c.modules map updateMemMods) + def run(c: Circuit) = c.copy(modules = c.modules.map(updateMemMods)) } diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala index bfbc163a..fc381e88 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemMacros.scala @@ -13,7 +13,6 @@ import firrtl.annotations._ import firrtl.stage.Forms import wiring._ - /** Annotates the name of the pins to add for WiringTransform */ case class PinAnnotation(pins: Seq[String]) extends NoTargetAnnotation @@ -35,14 +34,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM /** Return true if mask granularity is per bit, false if per byte or unspecified */ private def getFillWMask(mem: DefAnnotatedMemory) = mem.maskGran match { - case None => false + case None => false case Some(v) => v == 1 } private def rPortToBundle(mem: DefAnnotatedMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType)) + defaultPortSeq(mem) :+ Field("data", Flip, mem.dataType) + ) private def rPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( - defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType))) + defaultPortSeq(mem) :+ Field("data", Flip, flattenType(mem.dataType)) + ) /** Catch incorrect memory instantiations when there are masked memories with unsupported aggregate types. * @@ -82,7 +83,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM ) private def wPortToFlattenBundle(mem: DefAnnotatedMemory) = BundleType( (defaultPortSeq(mem) :+ Field("data", Default, flattenType(mem.dataType))) ++ (mem.maskGran match { - case None => Nil + case None => Nil case Some(_) if getFillWMask(mem) => Seq(Field("mask", Default, flattenType(mem.dataType))) case Some(_) => { checkMaskDatatype(mem) @@ -111,7 +112,7 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM Field("wdata", Default, flattenType(mem.dataType)), Field("rdata", Flip, flattenType(mem.dataType)) ) ++ (mem.maskGran match { - case None => Nil + case None => Nil case Some(_) if (getFillWMask(mem)) => Seq(Field("wmask", Default, flattenType(mem.dataType))) case Some(_) => { checkMaskDatatype(mem) @@ -122,32 +123,34 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM def memToBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToBundle(s)))) + s.writers.map(Field(_, Flip, wPortToBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToBundle(s))) + ) def memToFlattenBundle(s: DefAnnotatedMemory) = BundleType( s.readers.map(Field(_, Flip, rPortToFlattenBundle(s))) ++ - s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ - s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s)))) + s.writers.map(Field(_, Flip, wPortToFlattenBundle(s))) ++ + s.readwriters.map(Field(_, Flip, rwPortToFlattenBundle(s))) + ) /** Creates a wrapper module and external module to replace a candidate memory - * The wrapper module has the same type as the memory it replaces - * The external module - */ + * The wrapper module has the same type as the memory it replaces + * The external module + */ def createMemModule(m: DefAnnotatedMemory, wrapperName: String): Seq[DefModule] = { assert(m.dataType != UnknownType) val wrapperIoType = memToBundle(m) - val wrapperIoPorts = wrapperIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val wrapperIoPorts = wrapperIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) // Creates a type with the write/readwrite masks omitted if necessary val bbIoType = memToFlattenBundle(m) - val bbIoPorts = bbIoType.fields map (f => Port(NoInfo, f.name, Input, f.tpe)) + val bbIoPorts = bbIoType.fields.map(f => Port(NoInfo, f.name, Input, f.tpe)) val bbRef = WRef(m.name, bbIoType) val hasMask = m.maskGran.isDefined val fillMask = getFillWMask(m) def portRef(p: String) = WRef(p, field_type(wrapperIoType, p)) val stmts = Seq(WDefInstance(NoInfo, m.name, m.name, UnknownType)) ++ - (m.readers flatMap (r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++ - (m.writers flatMap (w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++ - (m.readwriters flatMap (rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask))) + (m.readers.flatMap(r => adaptReader(portRef(r), WSubField(bbRef, r)))) ++ + (m.writers.flatMap(w => adaptWriter(portRef(w), WSubField(bbRef, w), hasMask, fillMask))) ++ + (m.readwriters.flatMap(rw => adaptReadWriter(portRef(rw), WSubField(bbRef, rw), hasMask, fillMask))) val wrapper = Module(NoInfo, wrapperName, wrapperIoPorts, Block(stmts)) val bb = ExtModule(NoInfo, m.name, bbIoPorts, m.name, Seq.empty) // TODO: Annotate? -- use actual annotation map @@ -160,16 +163,16 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM // TODO(shunshou): get rid of copy pasta // Connects the clk, en, and addr fields from the wrapperPort to the bbPort def defaultConnects(wrapperPort: WRef, bbPort: WSubField): Seq[Connect] = - Seq("clk", "en", "addr") map (f => connectFields(bbPort, f, wrapperPort, f)) + Seq("clk", "en", "addr").map(f => connectFields(bbPort, f, wrapperPort, f)) // Generates mask bits (concatenates an aggregate to ground type) // depending on mask granularity (# bits = data width / mask granularity) def maskBits(mask: WSubField, dataType: Type, fillMask: Boolean): Expression = if (fillMask) toBitMask(mask, dataType) else toBits(mask) - def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = + def adaptReader(wrapperPort: WRef, bbPort: WSubField): Seq[Statement] = defaultConnects(wrapperPort, bbPort) :+ - fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) + fromBits(WSubField(wrapperPort, "data"), WSubField(bbPort, "data")) def adaptWriter(wrapperPort: WRef, bbPort: WSubField, hasMask: Boolean, fillMask: Boolean): Seq[Statement] = { val wrapperData = WSubField(wrapperPort, "data") @@ -177,11 +180,12 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM Connect(NoInfo, WSubField(bbPort, "data"), toBits(wrapperData)) hasMask match { case false => defaultSeq - case true => defaultSeq :+ Connect( - NoInfo, - WSubField(bbPort, "mask"), - maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) - ) + case true => + defaultSeq :+ Connect( + NoInfo, + WSubField(bbPort, "mask"), + maskBits(WSubField(wrapperPort, "mask"), wrapperData.tpe, fillMask) + ) } } @@ -190,61 +194,67 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM val defaultSeq = defaultConnects(wrapperPort, bbPort) ++ Seq( fromBits(WSubField(wrapperPort, "rdata"), WSubField(bbPort, "rdata")), connectFields(bbPort, "wmode", wrapperPort, "wmode"), - Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData))) + Connect(NoInfo, WSubField(bbPort, "wdata"), toBits(wrapperWData)) + ) hasMask match { case false => defaultSeq - case true => defaultSeq :+ Connect( - NoInfo, - WSubField(bbPort, "wmask"), - maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) - ) + case true => + defaultSeq :+ Connect( + NoInfo, + WSubField(bbPort, "wmask"), + maskBits(WSubField(wrapperPort, "wmask"), wrapperWData.tpe, fillMask) + ) } } /** Mapping from (module, memory name) pairs to blackbox names */ private type NameMap = collection.mutable.HashMap[(String, String), String] + /** Construct NameMap by assigning unique names for each memory blackbox */ def constructNameMap(namespace: Namespace, nameMap: NameMap, mname: String)(s: Statement): Statement = { s match { - case m: DefAnnotatedMemory => m.memRef match { - case None => nameMap(mname -> m.name) = namespace newName m.name - case Some(_) => - } + case m: DefAnnotatedMemory => + m.memRef match { + case None => nameMap(mname -> m.name) = namespace.newName(m.name) + case Some(_) => + } case _ => } - s map constructNameMap(namespace, nameMap, mname) + s.map(constructNameMap(namespace, nameMap, mname)) } - def updateMemStmts(namespace: Namespace, - nameMap: NameMap, - mname: String, - memPortMap: MemPortMap, - memMods: Modules) - (s: Statement): Statement = s match { + def updateMemStmts( + namespace: Namespace, + nameMap: NameMap, + mname: String, + memPortMap: MemPortMap, + memMods: Modules + )(s: Statement + ): Statement = s match { case m: DefAnnotatedMemory => if (m.maskGran.isEmpty) { - m.writers foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } - m.readwriters foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } + m.writers.foreach { w => memPortMap(s"${m.name}.$w.mask") = EmptyExpression } + m.readwriters.foreach { w => memPortMap(s"${m.name}.$w.wmask") = EmptyExpression } } m.memRef match { case None => // prototype mem val newWrapperName = nameMap(mname -> m.name) - val newMemBBName = namespace newName s"${newWrapperName}_ext" - val newMem = m copy (name = newMemBBName) + val newMemBBName = namespace.newName(s"${newWrapperName}_ext") + val newMem = m.copy(name = newMemBBName) memMods ++= createMemModule(newMem, newWrapperName) WDefInstance(m.info, m.name, newWrapperName, UnknownType) case Some((module, mem)) => WDefInstance(m.info, m.name, nameMap(module -> mem), UnknownType) } - case sx => sx map updateMemStmts(namespace, nameMap, mname, memPortMap, memMods) + case sx => sx.map(updateMemStmts(namespace, nameMap, mname, memPortMap, memMods)) } def updateMemMods(namespace: Namespace, nameMap: NameMap, memMods: Modules)(m: DefModule) = { val memPortMap = new MemPortMap - (m map updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods) - map updateStmtRefs(memPortMap)) + (m.map(updateMemStmts(namespace, nameMap, m.name, memPortMap, memMods)) + .map(updateStmtRefs(memPortMap))) } def execute(state: CircuitState): CircuitState = { @@ -252,15 +262,15 @@ class ReplaceMemMacros(writer: ConfWriter) extends Transform with DependencyAPIM val namespace = Namespace(c) val memMods = new Modules val nameMap = new NameMap - c.modules map (m => m map constructNameMap(namespace, nameMap, m.name)) - val modules = c.modules map updateMemMods(namespace, nameMap, memMods) + c.modules.map(m => m.map(constructNameMap(namespace, nameMap, m.name))) + val modules = c.modules.map(updateMemMods(namespace, nameMap, memMods)) // print conf writer.serialize() val pannos = state.annotations.collect { case a: PinAnnotation => a } val pins = pannos match { - case Seq() => Nil + case Seq() => Nil case Seq(PinAnnotation(pins)) => pins - case _ => throwInternalError("Something went wrong") + case _ => throwInternalError("Something went wrong") } val annos = pins.foldLeft(Seq[Annotation]()) { (seq, pin) => seq ++ memMods.collect { diff --git a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala index 87321ea0..79e07640 100644 --- a/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala +++ b/src/main/scala/firrtl/passes/memlib/ReplaceMemTransform.scala @@ -7,7 +7,7 @@ import firrtl._ import firrtl.annotations._ import firrtl.options.{HasShellOptions, ShellOption} import Utils.error -import java.io.{File, CharArrayWriter, PrintWriter} +import java.io.{CharArrayWriter, File, PrintWriter} import wiring._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} @@ -50,7 +50,15 @@ class ConfWriter(filename: String) { // assert that we don't overflow going from BigInt to Int conversion require(bitWidth(m.dataType) <= Int.MaxValue) m.maskGran.foreach { case x => require(x <= Int.MaxValue) } - val conf = MemConf(m.name, m.depth, bitWidth(m.dataType).toInt, m.readers.length, m.writers.length, m.readwriters.length, m.maskGran.map(_.toInt)) + val conf = MemConf( + m.name, + m.depth, + bitWidth(m.dataType).toInt, + m.readers.length, + m.writers.length, + m.readwriters.length, + m.maskGran.map(_.toInt) + ) outputBuffer.append(conf.toString) } def serialize() = { @@ -113,27 +121,31 @@ class ReplSeqMem extends Transform with HasShellOptions with DependencyAPIMigrat val options = Seq( new ShellOption[String]( longOption = "repl-seq-mem", - toAnnotationSeq = (a: String) => Seq( passes.memlib.ReplSeqMemAnnotation.parse(a), - RunFirrtlTransformAnnotation(new ReplSeqMem) ), + toAnnotationSeq = + (a: String) => Seq(passes.memlib.ReplSeqMemAnnotation.parse(a), RunFirrtlTransformAnnotation(new ReplSeqMem)), helpText = "Blackbox and emit a configuration file for each sequential memory", shortOption = Some("frsq"), - helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") ) ) + helpValueName = Some("-c:<circuit>:-i:<file>:-o:<file>") + ) + ) def transforms(inConfigFile: Option[YamlFileReader], outConfigFile: ConfWriter): Seq[Transform] = - Seq(new SimpleMidTransform(Legalize), - new SimpleMidTransform(ToMemIR), - new SimpleMidTransform(ResolveMaskGranularity), - new SimpleMidTransform(RenameAnnotatedMemoryPorts), - new ResolveMemoryReference, - new CreateMemoryAnnotations(inConfigFile), - new ReplaceMemMacros(outConfigFile), - new WiringTransform, - new SimpleMidTransform(RemoveEmpty), - new SimpleMidTransform(CheckInitialization), - new SimpleMidTransform(InferTypes), - Uniquify, - new SimpleMidTransform(ResolveKinds), - new SimpleMidTransform(ResolveFlows)) + Seq( + new SimpleMidTransform(Legalize), + new SimpleMidTransform(ToMemIR), + new SimpleMidTransform(ResolveMaskGranularity), + new SimpleMidTransform(RenameAnnotatedMemoryPorts), + new ResolveMemoryReference, + new CreateMemoryAnnotations(inConfigFile), + new ReplaceMemMacros(outConfigFile), + new WiringTransform, + new SimpleMidTransform(RemoveEmpty), + new SimpleMidTransform(CheckInitialization), + new SimpleMidTransform(InferTypes), + Uniquify, + new SimpleMidTransform(ResolveKinds), + new SimpleMidTransform(ResolveFlows) + ) def execute(state: CircuitState): CircuitState = { val annos = state.annotations.collect { case a: ReplSeqMemAnnotation => a } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 41c47dce..434c7602 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -28,10 +28,10 @@ object AnalysisUtils { connects(value.serialize) = WInvalid case _ => // do nothing } - s map getConnects(connects) + s.map(getConnects(connects)) } val connects = new Connects - m map getConnects(connects) + m.map(getConnects(connects)) connects } @@ -56,8 +56,8 @@ object AnalysisUtils { else if (weq(tvOrigin, fvOrigin)) tvOrigin else if (weq(fvOrigin, zero) && weq(condOrigin, tvOrigin)) condOrigin else e - case DoPrim(PrimOps.Or, args, consts, tpe) if args exists (weq(_, one)) => one - case DoPrim(PrimOps.And, args, consts, tpe) if args exists (weq(_, zero)) => zero + case DoPrim(PrimOps.Or, args, consts, tpe) if args.exists(weq(_, one)) => one + case DoPrim(PrimOps.And, args, consts, tpe) if args.exists(weq(_, zero)) => zero case DoPrim(PrimOps.Bits, args, Seq(msb, lsb), tpe) => val extractionWidth = (msb - lsb) + 1 val nodeWidth = bitWidth(args.head.tpe) @@ -69,10 +69,10 @@ object AnalysisUtils { case ValidIf(cond, value, _) => getOrigin(connects)(value) // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => - connects get e.serialize match { - case Some(ex) => getOrigin(connects)(ex) - case None => e - } + connects.get(e.serialize) match { + case Some(ex) => getOrigin(connects)(ex) + case None => e + } case _ => e } } @@ -90,10 +90,9 @@ object ResolveMaskGranularity extends Pass { */ def getMaskBits(connects: Connects, wen: Expression, wmask: Expression): Option[Int] = { val wenOrigin = getOrigin(connects)(wen) - val wmaskOrigin = connects.keys filter - (_ startsWith wmask.serialize) map {s: String => getOrigin(connects, s)} + val wmaskOrigin = connects.keys.filter(_.startsWith(wmask.serialize)).map { s: String => getOrigin(connects, s) } // all wmask bits are equal to wmode/wen or all wmask bits = 1(for redundancy checking) - val redundantMask = wmaskOrigin forall (x => weq(x, wenOrigin) || weq(x, one)) + val redundantMask = wmaskOrigin.forall(x => weq(x, wenOrigin) || weq(x, one)) if (redundantMask) None else Some(wmaskOrigin.size) } @@ -103,18 +102,17 @@ object ResolveMaskGranularity extends Pass { def updateStmts(connects: Connects)(s: Statement): Statement = s match { case m: DefAnnotatedMemory => val dataBits = bitWidth(m.dataType) - val rwMasks = m.readwriters map (rw => - getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) - val wMasks = m.writers map (w => - getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) + val rwMasks = + m.readwriters.map(rw => getMaskBits(connects, memPortField(m, rw, "wmode"), memPortField(m, rw, "wmask"))) + val wMasks = m.writers.map(w => getMaskBits(connects, memPortField(m, w, "en"), memPortField(m, w, "mask"))) val maskGran = (rwMasks ++ wMasks).head match { - case None => None + case None => None case Some(maskBits) => Some(dataBits / maskBits) } m.copy(maskGran = maskGran) - case sx => sx map updateStmts(connects) + case sx => sx.map(updateStmts(connects)) } - def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m)) - def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems) + def annotateModMems(m: DefModule): DefModule = m.map(updateStmts(getConnects(m))) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(annotateModMems)) } diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala index b5ff10c6..e80e0c4a 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMemoryReference.scala @@ -14,7 +14,7 @@ case class NoDedupMemAnnotation(target: ComponentName) extends SingleTargetAnnot } /** Resolves annotation ref to memories that exactly match (except name) another memory - */ + */ class ResolveMemoryReference extends Transform with DependencyAPIMigration { override def prerequisites = Forms.MidForm @@ -45,10 +45,12 @@ class ResolveMemoryReference extends Transform with DependencyAPIMigration { /** If a candidate memory is identical except for name to another, add an * annotation that references the name of the other memory. */ - def updateMemStmts(mname: String, - existingMems: AnnotatedMemories, - noDedupMap: Map[String, Set[String]]) - (s: Statement): Statement = s match { + def updateMemStmts( + mname: String, + existingMems: AnnotatedMemories, + noDedupMap: Map[String, Set[String]] + )(s: Statement + ): Statement = s match { // If not dedupable, no need to add to existing (since nothing can dedup with it) // We just return the DefAnnotatedMemory as is in the default case below case m: DefAnnotatedMemory if dedupable(noDedupMap, mname, m.name) => diff --git a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala index 554a3572..9fe7f852 100644 --- a/src/main/scala/firrtl/passes/memlib/ToMemIR.scala +++ b/src/main/scala/firrtl/passes/memlib/ToMemIR.scala @@ -14,16 +14,17 @@ import firrtl.ir._ * - undefined read-under-write behavior */ object ToMemIR extends Pass { + /** Only annotate memories that are candidates for memory macro replacements * i.e. rw, w + r (read, write 1 cycle delay) and read-under-write "undefined." */ import ReadUnderWrite._ def updateStmts(s: Statement): Statement = s match { - case m @ DefMemory(_,_,_,_,1,1,r,w,rw,Undefined) if (w.length + rw.length) == 1 && r.length <= 1 => + case m @ DefMemory(_, _, _, _, 1, 1, r, w, rw, Undefined) if (w.length + rw.length) == 1 && r.length <= 1 => DefAnnotatedMemory(m) - case sx => sx map updateStmts + case sx => sx.map(updateStmts) } - def annotateModMems(m: DefModule) = m map updateStmts - def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) + def annotateModMems(m: DefModule) = m.map(updateStmts) + def run(c: Circuit) = c.copy(modules = c.modules.map(annotateModMems)) } diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index dd644323..a2b14343 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -24,19 +24,19 @@ object MemDelayAndReadwriteTransformer { case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect]) // Utilities for generating hardware - def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) - def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) - def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) - def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) + def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) + def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) + def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) + def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) // Utilities for working with WithValid groups def connect(l: WithValid, r: WithValid): Seq[Connect] = { - val paired = (l.valid +: l.payload) zip (r.valid +: r.payload) + val paired = (l.valid +: l.payload).zip(r.valid +: r.payload) paired.map { case (le, re) => connect(le, re) } } def condConnect(l: WithValid, r: WithValid): Seq[Connect] = { - connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) } + connect(l.valid, r.valid) +: (l.payload.zip(r.payload)).map { case (le, re) => condConnect(r.valid)(le, re) } } // Internal representation of a pipeline stage with an associated valid signal @@ -47,20 +47,23 @@ object MemDelayAndReadwriteTransformer { private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_") // Pipeline a group of signals with an associated valid signal. Gate registers when possible. - def pipelineWithValid(ns: Namespace)( - clock: Expression, - depth: Int, - src: WithValid, - nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = { + def pipelineWithValid( + ns: Namespace + )(clock: Expression, + depth: Int, + src: WithValid, + nameTemplate: Option[WithValid] = None + ): (WithValid, Seq[Statement], Seq[Connect]) = { def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e) val template = nameTemplate.getOrElse(src) - val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev => - def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) - val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) - val regs = (ref.valid +: ref.payload).map(asReg) - PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) + val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { + case prev => + def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) + val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) + val regs = (ref.valid +: ref.payload).map(asReg) + PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) } (stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns)) } @@ -84,10 +87,10 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { private def findMemConns(s: Statement): Unit = s match { case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr - case _ => s.foreach(findMemConns) + case _ => s.foreach(findMemConns) } - private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match { + private def swapMemRefs(e: Expression): Expression = e.map(swapMemRefs) match { case sf: WSubField => exprReplacements.getOrElse(we(sf), sf) case ex => ex } @@ -105,51 +108,57 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { val rRespDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) mem.readLatency else 0 val wCmdDelay = mem.writeLatency - 1 - val readStmts = (mem.readers ++ mem.readwriters).map { case r => - def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) - def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f) - val clk = oldDriver("clk") - - // Pack sources of read command inputs into WithValid object -> different for readwriter - val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") - val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) - val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) - val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) - val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) - - // Pipeline read response using *last* command pipe stage enable as the valid signal - val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) - val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names - val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) - - // Make sure references to the read data get appropriately substituted - val oldRDataName = if (rMap.contains(r)) "rdata" else "data" - exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head - - // Return all statements; they're separated so connects can go after all declarations - SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) + val readStmts = (mem.readers ++ mem.readwriters).map { + case r => + def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) + def newField(f: String) = memPortField(newMem, rMap.getOrElse(r, r), f) + val clk = oldDriver("clk") + + // Pack sources of read command inputs into WithValid object -> different for readwriter + val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") + val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) + val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) + val (cmdPiped, cmdDecls, cmdConns) = + pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Pipeline read response using *last* command pipe stage enable as the valid signal + val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) + val respPipeNameTemplate = + Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names + val (respPiped, respDecls, respConns) = + pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) + + // Make sure references to the read data get appropriately substituted + val oldRDataName = if (rMap.contains(r)) "rdata" else "data" + exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) } - val writeStmts = (mem.writers ++ mem.readwriters).map { case w => - def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) - def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f) - val clk = oldDriver("clk") - - // Pack sources of write command inputs into WithValid object -> different for readwriter - val cmdSrc = if (wMap.contains(w)) { - val en = AND(oldDriver("en"), oldDriver("wmode")) - WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) - } else { - WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) - } - - // Pipeline write command, connect to memory - val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) - val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) - val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) - - // Return all statements; they're separated so connects can go after all declarations - SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) + val writeStmts = (mem.writers ++ mem.readwriters).map { + case w => + def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) + def newField(f: String) = memPortField(newMem, wMap.getOrElse(w, w), f) + val clk = oldDriver("clk") + + // Pack sources of write command inputs into WithValid object -> different for readwriter + val cmdSrc = if (wMap.contains(w)) { + val en = AND(oldDriver("en"), oldDriver("wmode")) + WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) + } else { + WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) + } + + // Pipeline write command, connect to memory + val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) + val (cmdPiped, cmdDecls, cmdConns) = + pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) } newConns ++= (readStmts ++ writeStmts).flatMap(_.conns) @@ -171,8 +180,7 @@ object VerilogMemDelays extends Pass { override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf) override val optionalPrerequisiteOf = - Seq( Dependency[VerilogEmitter], - Dependency[SystemVerilogEmitter] ) + Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) override def invalidates(a: Transform): Boolean = a match { case _: transforms.ConstantPropagation | ResolveFlows => true @@ -180,5 +188,5 @@ object VerilogMemDelays extends Pass { } def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed - def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) } diff --git a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala index a43adfe2..b5f91e7b 100644 --- a/src/main/scala/firrtl/passes/memlib/YamlUtils.scala +++ b/src/main/scala/firrtl/passes/memlib/YamlUtils.scala @@ -6,7 +6,6 @@ import net.jcazevedo.moultingyaml._ import java.io.{CharArrayWriter, File, PrintWriter} import firrtl.FileUtils - object CustomYAMLProtocol extends DefaultYamlProtocol { // bottom depends on top implicit val _pin = yamlFormat1(Pin) @@ -20,17 +19,15 @@ case class Source(name: String, module: String) case class Top(name: String) case class Config(pin: Pin, source: Source, top: Top) - class YamlFileReader(file: String) { - def parse[A](implicit reader: YamlReader[A]) : Seq[A] = { + def parse[A](implicit reader: YamlReader[A]): Seq[A] = { if (new File(file).exists) { val yamlString = FileUtils.getText(file) - yamlString.parseYamls flatMap (x => - try Some(reader read x) + yamlString.parseYamls.flatMap(x => + try Some(reader.read(x)) catch { case e: Exception => None } ) - } - else sys.error("Yaml file doesn't exist!") + } else sys.error("Yaml file doesn't exist!") } } @@ -38,11 +35,11 @@ class YamlFileWriter(file: String) { val outputBuffer = new CharArrayWriter val separator = "--- \n" def append(in: YamlValue): Unit = { - outputBuffer append s"$separator${in.prettyPrint}" + outputBuffer.append(s"$separator${in.prettyPrint}") } def dump(): Unit = { val outputFile = new PrintWriter(file) - outputFile write outputBuffer.toString + outputFile.write(outputBuffer.toString) outputFile.close() } } |
