aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndrew Waterman2018-07-03 15:45:00 -0700
committerJack Koenig2018-07-03 15:45:00 -0700
commitceac36d7ce1223078ca47bc097884532faacd7e1 (patch)
tree275334042d0a41d362108404ec89e18a357d4104 /src
parent2b405652a266114377816b8175ef9fad7d36ed14 (diff)
Improve code generation for smem wmode and [w]mask ports (#834)
[skip formal checks] LEC passes with Formality * Improve code generation for smem RW-port wmode port A common case for these port-enables is wen = valid & write ren = valid & !write which the RW-port transform currently turns into en = (valid & write) | (valid & !write) wmode = valid & write because it proved `wen` and `ren` are mutually exclusive via `write`. Synthesis tools can trivially optimize `en` to `valid`, so that's not a problem, but the wmode field can't be optimized if going into a black box. This PR instead sets `wmode` to whatever node was used to prove mutual exclusion, which is always a simpler expression. In this case: en = (valid & write) | (valid & !write) wmode = write * In RemoveCHIRRTL, infer mask relative to port definition Previously, it was inferred relative to the memory definition causing the mask condition to be redundantly conjoined with the enable signal. Also enable ReplSeqMems to ignore all ValidIfs (not just on Clocks) to improve QoR.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala28
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala10
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala3
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala34
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala88
5 files changed, 142 insertions, 21 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 6b3508a6..cc69be6f 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -82,12 +82,10 @@ object RemoveCHIRRTL extends Transform {
def set_enable(vec: Seq[MPort], en: String) = vec map (r =>
Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero)
)
- def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap {r =>
+ def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap { r =>
val tmask = createMask(sx.tpe)
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), data, tdata)) +:
- (create_exps(SubField(SubField(Reference(sx.name, ut), r.name, ut), mask, tmask))
- map (Connect(sx.info, _, zero))
- )
+ val portRef = SubField(Reference(sx.name, ut), r.name, ut)
+ Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask)))
}
val rds = (mports getOrElse (sx.name, EMPs)).readers
val wrs = (mports getOrElse (sx.name, EMPs)).writers
@@ -109,12 +107,15 @@ object RemoveCHIRRTL extends Transform {
val addrs = ArrayBuffer[String]()
val clks = ArrayBuffer[String]()
val ens = ArrayBuffer[String]()
+ val masks = ArrayBuffer.empty[Expression]
+ val portRef = SubField(Reference(sx.mem, ut), sx.name, ut)
sx.direction match {
case MReadWrite =>
- refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "rdata", "wdata", "wmask", rdwrite = true)
+ refs(sx.name) = DataRef(portRef, "rdata", "wdata", "wmask", rdwrite = true)
addrs += "addr"
clks += "clk"
ens += "en"
+ masks ++= create_exps(SubField(portRef, "wmask", createMask(sx.tpe)))
renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata")
renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata")
val es = create_all_exps(WRef(sx.name, sx.tpe))
@@ -124,10 +125,11 @@ object RemoveCHIRRTL extends Transform {
case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
}
case MWrite =>
- refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false)
+ refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false)
addrs += "addr"
clks += "clk"
ens += "en"
+ masks ++= create_exps(SubField(portRef, "mask", createMask(sx.tpe)))
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
@@ -135,12 +137,12 @@ object RemoveCHIRRTL extends Transform {
case (e, w) => renames.rename(e.serialize, w.serialize)
}
case MRead =>
- refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false)
+ refs(sx.name) = DataRef(portRef, "data", "data", "blah", rdwrite = false)
addrs += "addr"
clks += "clk"
sx.exps.head match {
case e: Reference if smems(sx.mem) =>
- raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut)
+ raddrs(e.name) = SubField(portRef, "en", ut)
case _ => ens += "en"
}
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
@@ -152,9 +154,11 @@ object RemoveCHIRRTL extends Transform {
case MInfer => // do nothing if it's not being used
}
Block(
- (addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++
- (clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++
- (ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one))))
+ (addrs map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
+ (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
+ (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++
+ masks.map(lhs => Connect(sx.info, lhs, zero))
+ )
case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
}
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 661d6df4..2d1d7f6b 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -89,11 +89,12 @@ object InferReadWritePass extends Pass {
val readwriters = collection.mutable.ArrayBuffer[String]()
val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
for (w <- mem.writers ; r <- mem.readers) {
- val wp = getProductTerms(connects)(memPortField(mem, w, "en"))
- val rp = getProductTerms(connects)(memPortField(mem, r, "en"))
+ val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en"))
+ val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en"))
+ val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b)))
val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
- if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) {
+ if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) {
val rw = namespace newName "rw"
val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
@@ -104,10 +105,11 @@ object InferReadWritePass extends Pass {
repl(memPortField(mem, r, "addr")) = EmptyExpression
repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode")
+ repl(memPortField(mem, w, "en")) = EmptyExpression
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
+ stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get)
stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
stmts += Connect(NoInfo, WSubField(rwExp, "en"),
DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 0424b1dd..e254dcc9 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -65,7 +65,8 @@ object AnalysisUtils {
if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e
case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
getOrigin(connects)(args.head)
- case ValidIf(cond, value, ClockType) => getOrigin(connects)(value)
+ // It is a correct optimization to treat ValidIf as a connection
+ case ValidIf(cond, value, _) => getOrigin(connects)(value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
connects get e.serialize match {
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
index 34e228be..bffb1b51 100644
--- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -7,6 +7,7 @@ import firrtl.ir._
import firrtl.passes._
import firrtl.Mappers._
import annotations._
+import FirrtlCheckers._
class InferReadWriteSpec extends SimpleTransformSpec {
class InferReadWriteCheckException extends PassException(
@@ -138,4 +139,37 @@ circuit sram6t :
compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
}
}
+
+ "wmode" should "be simplified" in {
+ val input = """
+circuit sram6t :
+ module sram6t :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}
+
+ io is invalid
+ smem mem : UInt<32> [2048]
+ node wen = and(io.valid, io.write)
+ node ren = and(io.valid, not(io.write))
+ when wen :
+ write mport _T_14 = mem[io.addr], clock
+ _T_14 <= io.dataIn
+ node _T_16 = eq(wen, UInt<1>("h0"))
+ when _T_16 :
+ wire _T_18 : UInt
+ _T_18 is invalid
+ when ren :
+ _T_18 <= io.addr
+ node _T_20 = or(_T_18, UInt<11>("h0"))
+ node _T_21 = bits(_T_20, 10, 0)
+ read mport _T_22 = mem[_T_21], clock
+ io.dataOut <= _T_22
+""".stripMargin
+
+ val annos = Seq(memlib.InferReadWriteAnnotation)
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ // Check correctness of firrtl
+ res should containLine (s"mem.rw.wmode <= wen")
+ }
}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index dcc23235..2eae3580 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -8,6 +8,7 @@ import firrtl.passes._
import firrtl.transforms._
import firrtl.passes.memlib._
import annotations._
+import FirrtlCheckers._
class ReplSeqMemSpec extends SimpleTransformSpec {
def emitter = new LowFirrtlEmitter
@@ -66,7 +67,6 @@ circuit Top :
val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
- println(res.annotations)
parse(res.getEmittedCircuit.value)
(new java.io.File(confLoc)).delete()
}
@@ -181,7 +181,8 @@ circuit Top :
"asClock(a)" -> "a",
"a" -> "a",
"or(a, b)" -> "or(a, b)",
- "bits(a, 0, 0)" -> "a"
+ "bits(a, 0, 0)" -> "a",
+ "validif(a, b)" -> "b"
)
tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) }
@@ -296,9 +297,88 @@ circuit CustomMemory :
require(numExtMods == 1)
(new java.io.File(confLoc)).delete()
}
+
+ "ReplSeqMem" should "should not have a mask if there is none" in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2] }
+
+ smem mem : UInt<8>[2][1024]
+ read mport r = mem[io.raddr], clock
+ io.out <= r
+
+ when io.en :
+ write mport w = mem[io.waddr], clock
+ w <= io.wdata
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ res.getEmittedCircuit.value shouldNot include ("mask")
+ (new java.io.File(confLoc)).delete()
+ }
+
+ "ReplSeqMem" should "should not conjoin enable signal with mask condition" in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
+
+ smem mem : UInt<8>[2][1024]
+ read mport r = mem[io.raddr], clock
+ io.out <= r
+
+ when io.en :
+ write mport w = mem[io.waddr], clock
+ when io.mask[0] :
+ w[0] <= io.wdata[0]
+ when io.mask[1] :
+ w[1] <= io.wdata[1]
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
+ res should containLine ("mem.W0_mask_0 <= validif(io_en, io_mask_0)")
+ res should containLine ("mem.W0_mask_1 <= validif(io_en, io_mask_1)")
+ (new java.io.File(confLoc)).delete()
+ }
+
+ "ReplSeqMem" should "should not conjoin enable signal with wmask condition (RW Port)" in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
+
+ io.out is invalid
+
+ smem mem : UInt<8>[2][1024]
+
+ when io.en :
+ write mport w = mem[io.waddr], clock
+ when io.mask[0] :
+ w[0] <= io.wdata[0]
+ when io.mask[1] :
+ w[1] <= io.wdata[1]
+ when not(io.en) :
+ read mport r = mem[io.raddr], clock
+ io.out <= r
+
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
+ InferReadWriteAnnotation)
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
+ res should containLine ("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)")
+ res should containLine ("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)")
+ (new java.io.File(confLoc)).delete()
+ }
}
// TODO: make more checks
-// readwrite vs. no readwrite
-// mask + no mask
// conf