diff options
| author | chick | 2020-08-14 19:47:53 -0700 |
|---|---|---|
| committer | Jack Koenig | 2020-08-14 19:47:53 -0700 |
| commit | 6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch) | |
| tree | 2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/passes/memlib/InferReadWrite.scala | |
| parent | b516293f703c4de86397862fee1897aded2ae140 (diff) | |
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib/InferReadWrite.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/InferReadWrite.scala | 101 |
1 files changed, 55 insertions, 46 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 4847a698..e290633e 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -10,12 +10,11 @@ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} import firrtl.options.{HasShellOptions, ShellOption} import MemPortUtils.memPortField -import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} +import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects} import WrappedExpression.weq import annotations._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} - case object InferReadWriteAnnotation extends NoTargetAnnotation // This pass examine the enable signals of the read & write ports of memories @@ -40,12 +39,13 @@ object InferReadWritePass extends Pass { getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval) // Visit each term of AND operation case DoPrim(op, args, consts, tpe) if op == And => - e +: (args flatMap getProductTerms(connects)) + e +: (args.flatMap(getProductTerms(connects))) // Visit connected nodes to references - case _: WRef | _: WSubField | _: WSubIndex => connects get e match { - case None => Seq(e) - case Some(ex) => e +: getProductTerms(connects)(ex) - } + case _: WRef | _: WSubField | _: WSubIndex => + connects.get(e) match { + case None => Seq(e) + case Some(ex) => e +: getProductTerms(connects)(ex) + } // Otherwise just return itself case _ => Seq(e) } @@ -58,96 +58,103 @@ object InferReadWritePass extends Pass { // b ?= Eq(a, 0) or b ?= Eq(0, a) case (_, DoPrim(Eq, args, _, _)) => weq(args.head, a) && weq(args(1), zero) || - weq(args(1), a) && weq(args.head, zero) + weq(args(1), a) && weq(args.head, zero) // a ?= Eq(b, 0) or b ?= Eq(0, a) case (DoPrim(Eq, args, _, _), _) => weq(args.head, b) && weq(args(1), zero) || - weq(args(1), b) && weq(args.head, zero) + weq(args(1), b) && weq(args.head, zero) case _ => false } - def replaceExp(repl: Netlist)(e: Expression): Expression = - e map replaceExp(repl) match { - case ex: WSubField => repl getOrElse (ex.serialize, ex) + e.map(replaceExp(repl)) match { + case ex: WSubField => repl.getOrElse(ex.serialize, ex) case ex => ex } def replaceStmt(repl: Netlist)(s: Statement): Statement = - s map replaceStmt(repl) map replaceExp(repl) match { + s.map(replaceStmt(repl)).map(replaceExp(repl)) match { case Connect(_, EmptyExpression, _) => EmptyStmt - case sx => sx + case sx => sx } - def inferReadWriteStmt(connects: Connects, - repl: Netlist, - stmts: Statements) - (s: Statement): Statement = s match { + def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match { // infer readwrite ports only for non combinational memories case mem: DefMemory if mem.readLatency > 0 => val readers = new PortSet val writers = new PortSet val readwriters = collection.mutable.ArrayBuffer[String]() val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) - for (w <- mem.writers ; r <- mem.readers) { + for { + w <- mem.writers + r <- mem.readers + } { val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en")) val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en")) - val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b))) + val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b))) val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) { - val rw = namespace newName "rw" + val rw = namespace.newName("rw") val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw readers += r writers += w - repl(memPortField(mem, r, "clk")) = EmptyExpression - repl(memPortField(mem, r, "en")) = EmptyExpression + repl(memPortField(mem, r, "clk")) = EmptyExpression + repl(memPortField(mem, r, "en")) = EmptyExpression repl(memPortField(mem, r, "addr")) = EmptyExpression repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") - repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = EmptyExpression + repl(memPortField(mem, w, "clk")) = EmptyExpression + repl(memPortField(mem, w, "en")) = EmptyExpression repl(memPortField(mem, w, "addr")) = EmptyExpression repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk) - stmts += Connect(NoInfo, WSubField(rwExp, "en"), - DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), - connects(memPortField(mem, w, "en"))), Nil, BoolType)) - stmts += Connect(NoInfo, WSubField(rwExp, "addr"), - Mux(connects(memPortField(mem, w, "en")), - connects(memPortField(mem, w, "addr")), - connects(memPortField(mem, r, "addr")), UnknownType)) + stmts += Connect( + NoInfo, + WSubField(rwExp, "en"), + DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType) + ) + stmts += Connect( + NoInfo, + WSubField(rwExp, "addr"), + Mux( + connects(memPortField(mem, w, "en")), + connects(memPortField(mem, w, "addr")), + connects(memPortField(mem, r, "addr")), + UnknownType + ) + ) } } - if (readwriters.isEmpty) mem else mem copy ( - readers = mem.readers filterNot readers, - writers = mem.writers filterNot writers, - readwriters = mem.readwriters ++ readwriters) - case sx => sx map inferReadWriteStmt(connects, repl, stmts) + if (readwriters.isEmpty) mem + else + mem.copy( + readers = mem.readers.filterNot(readers), + writers = mem.writers.filterNot(writers), + readwriters = mem.readwriters ++ readwriters + ) + case sx => sx.map(inferReadWriteStmt(connects, repl, stmts)) } def inferReadWrite(m: DefModule) = { val connects = getConnects(m) val repl = new Netlist val stmts = new Statements - (m map inferReadWriteStmt(connects, repl, stmts) - map replaceStmt(repl)) match { + (m.map(inferReadWriteStmt(connects, repl, stmts)) + .map(replaceStmt(repl))) match { case m: ExtModule => m - case m: Module => m copy (body = Block(m.body +: stmts.toSeq)) + case m: Module => m.copy(body = Block(m.body +: stmts.toSeq)) } } - def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite) + def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite)) } // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform - with DependencyAPIMigration - with SeqTransformBased - with HasShellOptions { +class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions { override def prerequisites = Forms.MidForm override def optionalPrerequisites = Seq.empty @@ -159,7 +166,9 @@ class InferReadWrite extends Transform longOption = "infer-rw", toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)), helpText = "Enable read/write port inference for memories", - shortOption = Some("firw") ) ) + shortOption = Some("firw") + ) + ) def transforms = Seq( InferReadWritePass, |
