aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/memlib
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib')
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala10
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala3
2 files changed, 8 insertions, 5 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 661d6df4..2d1d7f6b 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -89,11 +89,12 @@ object InferReadWritePass extends Pass {
val readwriters = collection.mutable.ArrayBuffer[String]()
val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
for (w <- mem.writers ; r <- mem.readers) {
- val wp = getProductTerms(connects)(memPortField(mem, w, "en"))
- val rp = getProductTerms(connects)(memPortField(mem, r, "en"))
+ val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en"))
+ val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en"))
+ val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b)))
val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
- if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) {
+ if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) {
val rw = namespace newName "rw"
val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
@@ -104,10 +105,11 @@ object InferReadWritePass extends Pass {
repl(memPortField(mem, r, "addr")) = EmptyExpression
repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode")
+ repl(memPortField(mem, w, "en")) = EmptyExpression
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
+ stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get)
stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
stmts += Connect(NoInfo, WSubField(rwExp, "en"),
DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 0424b1dd..e254dcc9 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -65,7 +65,8 @@ object AnalysisUtils {
if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e
case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
getOrigin(connects)(args.head)
- case ValidIf(cond, value, ClockType) => getOrigin(connects)(value)
+ // It is a correct optimization to treat ValidIf as a connection
+ case ValidIf(cond, value, _) => getOrigin(connects)(value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
connects get e.serialize match {