aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala28
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala10
-rw-r--r--src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala3
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala34
-rw-r--r--src/test/scala/firrtlTests/ReplSeqMemTests.scala88
5 files changed, 142 insertions, 21 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 6b3508a6..cc69be6f 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -82,12 +82,10 @@ object RemoveCHIRRTL extends Transform {
def set_enable(vec: Seq[MPort], en: String) = vec map (r =>
Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero)
)
- def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap {r =>
+ def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap { r =>
val tmask = createMask(sx.tpe)
- IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), data, tdata)) +:
- (create_exps(SubField(SubField(Reference(sx.name, ut), r.name, ut), mask, tmask))
- map (Connect(sx.info, _, zero))
- )
+ val portRef = SubField(Reference(sx.name, ut), r.name, ut)
+ Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask)))
}
val rds = (mports getOrElse (sx.name, EMPs)).readers
val wrs = (mports getOrElse (sx.name, EMPs)).writers
@@ -109,12 +107,15 @@ object RemoveCHIRRTL extends Transform {
val addrs = ArrayBuffer[String]()
val clks = ArrayBuffer[String]()
val ens = ArrayBuffer[String]()
+ val masks = ArrayBuffer.empty[Expression]
+ val portRef = SubField(Reference(sx.mem, ut), sx.name, ut)
sx.direction match {
case MReadWrite =>
- refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "rdata", "wdata", "wmask", rdwrite = true)
+ refs(sx.name) = DataRef(portRef, "rdata", "wdata", "wmask", rdwrite = true)
addrs += "addr"
clks += "clk"
ens += "en"
+ masks ++= create_exps(SubField(portRef, "wmask", createMask(sx.tpe)))
renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata")
renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata")
val es = create_all_exps(WRef(sx.name, sx.tpe))
@@ -124,10 +125,11 @@ object RemoveCHIRRTL extends Transform {
case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize))
}
case MWrite =>
- refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false)
+ refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false)
addrs += "addr"
clks += "clk"
ens += "en"
+ masks ++= create_exps(SubField(portRef, "mask", createMask(sx.tpe)))
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
val es = create_all_exps(WRef(sx.name, sx.tpe))
val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe))
@@ -135,12 +137,12 @@ object RemoveCHIRRTL extends Transform {
case (e, w) => renames.rename(e.serialize, w.serialize)
}
case MRead =>
- refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false)
+ refs(sx.name) = DataRef(portRef, "data", "data", "blah", rdwrite = false)
addrs += "addr"
clks += "clk"
sx.exps.head match {
case e: Reference if smems(sx.mem) =>
- raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut)
+ raddrs(e.name) = SubField(portRef, "en", ut)
case _ => ens += "en"
}
renames.rename(sx.name, s"${sx.mem}.${sx.name}.data")
@@ -152,9 +154,11 @@ object RemoveCHIRRTL extends Transform {
case MInfer => // do nothing if it's not being used
}
Block(
- (addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++
- (clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++
- (ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one))))
+ (addrs map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++
+ (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++
+ (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++
+ masks.map(lhs => Connect(sx.info, lhs, zero))
+ )
case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames)
}
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index 661d6df4..2d1d7f6b 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -89,11 +89,12 @@ object InferReadWritePass extends Pass {
val readwriters = collection.mutable.ArrayBuffer[String]()
val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters)
for (w <- mem.writers ; r <- mem.readers) {
- val wp = getProductTerms(connects)(memPortField(mem, w, "en"))
- val rp = getProductTerms(connects)(memPortField(mem, r, "en"))
+ val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en"))
+ val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en"))
+ val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b)))
val wclk = getOrigin(connects)(memPortField(mem, w, "clk"))
val rclk = getOrigin(connects)(memPortField(mem, r, "clk"))
- if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) {
+ if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) {
val rw = namespace newName "rw"
val rwExp = WSubField(WRef(mem.name), rw)
readwriters += rw
@@ -104,10 +105,11 @@ object InferReadWritePass extends Pass {
repl(memPortField(mem, r, "addr")) = EmptyExpression
repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata")
repl(memPortField(mem, w, "clk")) = EmptyExpression
- repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode")
+ repl(memPortField(mem, w, "en")) = EmptyExpression
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask")
+ stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get)
stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk)
stmts += Connect(NoInfo, WSubField(rwExp, "en"),
DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
index 0424b1dd..e254dcc9 100644
--- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
+++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala
@@ -65,7 +65,8 @@ object AnalysisUtils {
if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e
case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
getOrigin(connects)(args.head)
- case ValidIf(cond, value, ClockType) => getOrigin(connects)(value)
+ // It is a correct optimization to treat ValidIf as a connection
+ case ValidIf(cond, value, _) => getOrigin(connects)(value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
connects get e.serialize match {
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
index 34e228be..bffb1b51 100644
--- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -7,6 +7,7 @@ import firrtl.ir._
import firrtl.passes._
import firrtl.Mappers._
import annotations._
+import FirrtlCheckers._
class InferReadWriteSpec extends SimpleTransformSpec {
class InferReadWriteCheckException extends PassException(
@@ -138,4 +139,37 @@ circuit sram6t :
compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
}
}
+
+ "wmode" should "be simplified" in {
+ val input = """
+circuit sram6t :
+ module sram6t :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}
+
+ io is invalid
+ smem mem : UInt<32> [2048]
+ node wen = and(io.valid, io.write)
+ node ren = and(io.valid, not(io.write))
+ when wen :
+ write mport _T_14 = mem[io.addr], clock
+ _T_14 <= io.dataIn
+ node _T_16 = eq(wen, UInt<1>("h0"))
+ when _T_16 :
+ wire _T_18 : UInt
+ _T_18 is invalid
+ when ren :
+ _T_18 <= io.addr
+ node _T_20 = or(_T_18, UInt<11>("h0"))
+ node _T_21 = bits(_T_20, 10, 0)
+ read mport _T_22 = mem[_T_21], clock
+ io.dataOut <= _T_22
+""".stripMargin
+
+ val annos = Seq(memlib.InferReadWriteAnnotation)
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ // Check correctness of firrtl
+ res should containLine (s"mem.rw.wmode <= wen")
+ }
}
diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
index dcc23235..2eae3580 100644
--- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala
+++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala
@@ -8,6 +8,7 @@ import firrtl.passes._
import firrtl.transforms._
import firrtl.passes.memlib._
import annotations._
+import FirrtlCheckers._
class ReplSeqMemSpec extends SimpleTransformSpec {
def emitter = new LowFirrtlEmitter
@@ -66,7 +67,6 @@ circuit Top :
val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc))
val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
// Check correctness of firrtl
- println(res.annotations)
parse(res.getEmittedCircuit.value)
(new java.io.File(confLoc)).delete()
}
@@ -181,7 +181,8 @@ circuit Top :
"asClock(a)" -> "a",
"a" -> "a",
"or(a, b)" -> "or(a, b)",
- "bits(a, 0, 0)" -> "a"
+ "bits(a, 0, 0)" -> "a",
+ "validif(a, b)" -> "b"
)
tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) }
@@ -296,9 +297,88 @@ circuit CustomMemory :
require(numExtMods == 1)
(new java.io.File(confLoc)).delete()
}
+
+ "ReplSeqMem" should "should not have a mask if there is none" in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2] }
+
+ smem mem : UInt<8>[2][1024]
+ read mport r = mem[io.raddr], clock
+ io.out <= r
+
+ when io.en :
+ write mport w = mem[io.waddr], clock
+ w <= io.wdata
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ res.getEmittedCircuit.value shouldNot include ("mask")
+ (new java.io.File(confLoc)).delete()
+ }
+
+ "ReplSeqMem" should "should not conjoin enable signal with mask condition" in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
+
+ smem mem : UInt<8>[2][1024]
+ read mport r = mem[io.raddr], clock
+ io.out <= r
+
+ when io.en :
+ write mport w = mem[io.waddr], clock
+ when io.mask[0] :
+ w[0] <= io.wdata[0]
+ when io.mask[1] :
+ w[1] <= io.wdata[1]
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc))
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
+ res should containLine ("mem.W0_mask_0 <= validif(io_en, io_mask_0)")
+ res should containLine ("mem.W0_mask_1 <= validif(io_en, io_mask_1)")
+ (new java.io.File(confLoc)).delete()
+ }
+
+ "ReplSeqMem" should "should not conjoin enable signal with wmask condition (RW Port)" in {
+ val input = """
+circuit CustomMemory :
+ module CustomMemory :
+ input clock : Clock
+ output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] }
+
+ io.out is invalid
+
+ smem mem : UInt<8>[2][1024]
+
+ when io.en :
+ write mport w = mem[io.waddr], clock
+ when io.mask[0] :
+ w[0] <= io.wdata[0]
+ when io.mask[1] :
+ w[1] <= io.wdata[1]
+ when not(io.en) :
+ read mport r = mem[io.raddr], clock
+ io.out <= r
+
+"""
+ val confLoc = "ReplSeqMemTests.confTEMP"
+ val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc),
+ InferReadWriteAnnotation)
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos))
+ // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask
+ res should containLine ("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)")
+ res should containLine ("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)")
+ (new java.io.File(confLoc)).delete()
+ }
}
// TODO: make more checks
-// readwrite vs. no readwrite
-// mask + no mask
// conf