aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib/InferReadWrite.scala')
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala101
1 files changed, 55 insertions, 46 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 4847a698..e290633e 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -10,12 +10,11 @@ import firrtl.PrimOps._
import firrtl.Utils.{one, zero, BoolType}
import firrtl.options.{HasShellOptions, ShellOption}
import MemPortUtils.memPortField
-import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin}
+import firrtl.passes.memlib.AnalysisUtils.{getConnects, getOrigin, Connects}
import WrappedExpression.weq
import annotations._
import firrtl.stage.{Forms, RunFirrtlTransformAnnotation}
-
case object InferReadWriteAnnotation extends NoTargetAnnotation
// This pass examine the enable signals of the read & write ports of memories
@@ -40,12 +39,13 @@ object InferReadWritePass extends Pass {
getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval)
// Visit each term of AND operation
case DoPrim(op, args, consts, tpe) if op == And =>
- e +: (args flatMap getProductTerms(connects))
+ e +: (args.flatMap(getProductTerms(connects)))
// Visit connected nodes to references
- case _: WRef | _: WSubField | _: WSubIndex => connects get e match {
- case None => Seq(e)
- case Some(ex) => e +: getProductTerms(connects)(ex)
- }
+ case _: WRef | _: WSubField | _: WSubIndex =>
+ connects.get(e) match {
+ case None => Seq(e)
+ case Some(ex) => e +: getProductTerms(connects)(ex)
+ }
// Otherwise just return itself
case _ => Seq(e)
}
@@ -58,96 +58,103 @@ object InferReadWritePass extends Pass {
// b ?= Eq(a, 0) or b ?= Eq(0, a)
case (_, DoPrim(Eq, args, _, _)) =>
weq(args.head, a) && weq(args(1), zero) ||
- weq(args(1), a) && weq(args.head, zero)
+ weq(args(1), a) && weq(args.head, zero)
// a ?= Eq(b, 0) or b ?= Eq(0, a)
case (DoPrim(Eq, args, _, _), _) =>
weq(args.head, b) && weq(args(1), zero) ||
- weq(args(1), b) && weq(args.head, zero)
+ weq(args(1), b) && weq(args.head, zero)
case _ => false
}
-
def replaceExp(repl: Netlist)(e: Expression): Expression =
- e map replaceExp(repl) match {
- case ex: WSubField => repl getOrElse (ex.serialize, ex)
+ e.map(replaceExp(repl)) match {
+ case ex: WSubField => repl.getOrElse(ex.serialize, ex)
case ex => ex
}
def replaceStmt(repl: Netlist)(s: Statement): Statement =
- s map replaceStmt(repl) map replaceExp(repl) match {
+ s.map(replaceStmt(repl)).map(replaceExp(repl)) match {
case Connect(_, EmptyExpression, _) => EmptyStmt
- case sx => sx
+ case sx => sx
}
- def inferReadWriteStmt(connects: Connects,
- repl: Netlist,
- stmts: Statements)
- (s: Statement): Statement = s match {
+ def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements)(s: Statement): Statement = s match {
// infer readwrite ports only for non combinational memories
case mem: DefMemory if mem.readLatency > 0 =>
val readers = new PortSet
val writers = new PortSet
val readwriters = collection.mutable.ArrayBuffer[String]()
val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
- for (w <- mem.writers ; r <- mem.readers) {
+ for {
+ w <- mem.writers
+ r <- mem.readers
+ } {
val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en"))
val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en"))
- val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b)))
+ val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms.exists(b => checkComplement(a, b)))
val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) {
- val rw = namespace newName "rw"
+ val rw = namespace.newName("rw")
val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
readers += r
writers += w
- repl(memPortField(mem, r, "clk")) = EmptyExpression
- repl(memPortField(mem, r, "en")) = EmptyExpression
+ repl(memPortField(mem, r, "clk")) = EmptyExpression
+ repl(memPortField(mem, r, "en")) = EmptyExpression
repl(memPortField(mem, r, "addr")) = EmptyExpression
repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
- repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = EmptyExpression
+ repl(memPortField(mem, w, "clk")) = EmptyExpression
+ repl(memPortField(mem, w, "en")) = EmptyExpression
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get)
stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
- stmts += Connect(NoInfo, WSubField(rwExp, "en"),
- DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
- connects(memPortField(mem, w, "en"))), Nil, BoolType))
- stmts += Connect(NoInfo, WSubField(rwExp, "addr"),
- Mux(connects(memPortField(mem, w, "en")),
- connects(memPortField(mem, w, "addr")),
- connects(memPortField(mem, r, "addr")), UnknownType))
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "en"),
+ DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType)
+ )
+ stmts += Connect(
+ NoInfo,
+ WSubField(rwExp, "addr"),
+ Mux(
+ connects(memPortField(mem, w, "en")),
+ connects(memPortField(mem, w, "addr")),
+ connects(memPortField(mem, r, "addr")),
+ UnknownType
+ )
+ )
}
}
- if (readwriters.isEmpty) mem else mem copy (
- readers = mem.readers filterNot readers,
- writers = mem.writers filterNot writers,
- readwriters = mem.readwriters ++ readwriters)
- case sx => sx map inferReadWriteStmt(connects, repl, stmts)
+ if (readwriters.isEmpty) mem
+ else
+ mem.copy(
+ readers = mem.readers.filterNot(readers),
+ writers = mem.writers.filterNot(writers),
+ readwriters = mem.readwriters ++ readwriters
+ )
+ case sx => sx.map(inferReadWriteStmt(connects, repl, stmts))
}
def inferReadWrite(m: DefModule) = {
val connects = getConnects(m)
val repl = new Netlist
val stmts = new Statements
- (m map inferReadWriteStmt(connects, repl, stmts)
- map replaceStmt(repl)) match {
+ (m.map(inferReadWriteStmt(connects, repl, stmts))
+ .map(replaceStmt(repl))) match {
case m: ExtModule => m
- case m: Module => m copy (body = Block(m.body +: stmts.toSeq))
+ case m: Module => m.copy(body = Block(m.body +: stmts.toSeq))
}
}
- def run(c: Circuit) = c copy (modules = c.modules map inferReadWrite)
+ def run(c: Circuit) = c.copy(modules = c.modules.map(inferReadWrite))
}
// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl"
// To use this transform, circuit name should be annotated with its TransId.
-class InferReadWrite extends Transform
- with DependencyAPIMigration
- with SeqTransformBased
- with HasShellOptions {
+class InferReadWrite extends Transform with DependencyAPIMigration with SeqTransformBased with HasShellOptions {
override def prerequisites = Forms.MidForm
override def optionalPrerequisites = Seq.empty
@@ -159,7 +166,9 @@ class InferReadWrite extends Transform
longOption = "infer-rw",
toAnnotationSeq = (_: Unit) => Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)),
helpText = "Enable read/write port inference for memories",
- shortOption = Some("firw") ) )
+ shortOption = Some("firw")
+ )
+ )
def transforms = Seq(
InferReadWritePass,