// SPDX-License-Identifier: Apache-2.0 package firrtl.passes package memlib import firrtl._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} import firrtl.options.{HasShellOptions, ShellOption} import MemPortUtils.memPortField import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects} import WrappedExpression.weq import annotations._ import firrtl.stage.{Forms, RunFirrtlTransformAnnotation} case object InferReadWriteAnnotation extends NoTargetAnnotation // This pass examine the enable signals of the read & write ports of memories // whose readLatency is greater than 1 (usually SeqMem in Chisel). // If any product term of the enable signal of the read port is the complement // of any product term of the enable signal of the write port, then the readwrite // port is inferred. object InferReadWritePass extends Pass { type Netlist = collection.mutable.HashMap[String, Expression] type Statements = collection.mutable.ArrayBuffer[Statement] type PortSet = collection.mutable.HashSet[String] private implicit def toString(e: Expression): String = e.serialize def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match { // No ConstProp yet... // TODO: do const prop before case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) => getProductTerms(connects)(cond) case Mux(cond, tval, fval, _) if weq(fval, zero) => getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval) // Visit each term of AND operation case DoPrim(op, args, consts, tpe) if op == And => e +: (args.flatMap(getProductTerms(connects))) // Visit connected nodes to references case _: WRef | _: WSubField | _: WSubIndex => connects.get(e) match { case None => Seq(e) case Some(ex) => e +: getProductTerms(connects)(ex) } // Otherwise just return itself case _ => Seq(e) } def checkComplement(a: Expression, b: Expression) = (a, b) match { // b ?= Not(a) case (_, DoPrim(Not, args, _, _)) => weq(args.head, a) // a ?= Not(b) case (DoPrim(Not, args, _, _), _) => weq(args.head, b) // b ?= Eq(a, 0) or b ?= Eq(0, a) case (_, DoPrim(Eq, args, _, _)) => weq(args.head, a) && weq(args(1), zero) || weq(args(1), a) && weq(args.head, zero) // a ?= Eq(b, 0) or b ?= Eq(0, a) case (DoPrim(Eq, args, _, _), _) => weq(args.head, b) && weq(args(1), zero) || weq(args(1), b) && weq(args.head, zero) case _ => false } def replaceExp(repl: Netlist)(e: Expression): Expression = e.map(replaceExp(repl)) match { case ex: WSubField => repl.getOrElse(ex.serialize, ex) case ex => ex } def replaceStmt(repl: Netlist)(s: Statement): Statement = s.map(replaceStmt(repl)).map(replaceExp(repl)) match { case Connect(_, EmptyExpression, _) => EmptyStmt case sx => sx } /* If the ports share the same address in an undefined-collision SyncReadMem, reads issued while the write * is enabled are *always* undefined; we may treat the read as if it were gated by the complement of w.en. * Though not a strict requirement, this currently applies only to single-cycle read/write memories. * N.B. for aggregate-typed memories, the spec is conservative and 'undefined' is not a function of the * write mask, allowing optimization regardless of mask value. This must be revisited if the spec changes. */ private def canOptimizeCollidingRW(mem: DefMemory): Boolean = { mem.readUnderWrite == ReadUnderWrite.Undefined && mem.readLatency == 1 && mem.writeLatency == 1 } def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match { // infer readwrite ports only for non combinational memories case mem: DefMemory if mem.readLatency > 0 => val readers = new PortSet val writers = new PortSet val readwriters = collection.mutable.ArrayBuffer[String]() val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) for { w <- mem.writers r <- mem.readers } { val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en")) val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en")) val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b))) val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) val waddr = getOrigin(connects)(memPortField(mem, w, "addr")) val raddr = getOrigin(connects)(memPortField(mem, r, "addr")) val optimizeCollision = (weq(waddr, raddr) && canOptimizeCollidingRW(mem)) if (weq(wclk, rclk) && (proofOfMutualExclusion.nonEmpty || optimizeCollision)) { val rw = namespace.newName("rw") val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw readers += r writers += w repl(memPortField(mem, r, "clk")) = EmptyExpression repl(memPortField(mem, r, "en")) = EmptyExpression repl(memPortField(mem, r, "addr")) = EmptyExpression repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") repl(memPortField(mem, w, "clk")) = WSubField(rwExp, "clk") repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") stmts += Connect( NoInfo, WSubField(rwExp, "en"), DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType) ) if (optimizeCollision) { repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode") repl(memPortField(mem, w, "addr")) = WSubField(rwExp, "addr") } else { repl(memPortField(mem, w, "en")) = EmptyExpression repl(memPortField(mem, w, "addr")) = EmptyExpression stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) stmts += Connect( NoInfo, WSubField(rwExp, "addr"), Mux( connects(memPortField(mem, w, "en")), connects(memPortField(mem, w, "addr")), connects(memPortField(mem, r, "addr")), UnknownType ) ) } } } if (readwriters.isEmpty) mem else mem.copy( readers = mem.readers.filterNot(readers), writers = mem.writers.filterNot(writers), readwriters = mem.readwriters ++ readwriters ) case sx => sx.map(inferReadWriteStmt(connects, repl, stmts)) } def inferReadWrite(m: DefModule) = { val connects = getConnects(m) val repl = new Netlist val stmts = new Statements (m.map(inferReadWriteStmt(connects, repl, stmts)) .map(replaceStmt(repl))) match { case m: ExtModule => m case m: Module => m.copy(body = Block(m.body +: stmts.toSeq)) } } def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite)) } // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions { override def prerequisites = Forms.MidForm override def optionalPrerequisites = Seq.empty override def optionalPrerequisiteOf = Forms.MidEmitters override def invalidates(a: Transform) = false val options = Seq( new ShellOption[Unit]( longOption = "infer-rw", toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)), helpText = "Enable read/write port inference for memories", shortOption = Some("firw") ) ) def transforms = Seq( InferReadWritePass, CheckInitialization, InferTypes, ResolveKinds, ResolveFlows ) def execute(state: CircuitState): CircuitState = { val runTransform = state.annotations.contains(InferReadWriteAnnotation) if (runTransform) { val ret = runTransforms(state) state.copy(circuit = ret.circuit, annotations = ret.annotations, renames = ret.renames) } else { state } } }