aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu2016-09-21 16:05:22 -0700
committerAndrew Waterman2016-09-21 16:05:22 -0700
commit8b12dcbb76896a19f95dc4da19b3b8c74c1ddda3 (patch)
treef14267b79a901de6b0efbb87d819a763b86e6328 /src
parent5d515c93e2136bb8bb77c5c1f9c5b9f2eb640deb (diff)
Fix clock connections in InferReadWrite (#310)
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/AnnotateMemMacros.scala2
-rw-r--r--src/main/scala/firrtl/passes/InferReadWrite.scala8
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala27
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala43
4 files changed, 58 insertions, 22 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
index 7ced7a99..aa5acb12 100644
--- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
+++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala
@@ -72,6 +72,8 @@ object AnalysisUtils {
if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e
case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) =>
getOrigin(connects, args.head)
+ // Todo: It's not clear it's ok to call remove validifs before mem passes...
+ case ValidIf(cond, value, ClockType) => getOrigin(connects, value)
// note: this should stop on a reg, but will stack overflow for combinational loops (not allowed)
case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind =>
connects get e.serialize match {
diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala
index ec996fdb..34359c14 100644
--- a/src/main/scala/firrtl/passes/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/InferReadWrite.scala
@@ -33,7 +33,7 @@ import firrtl.Mappers._
import firrtl.PrimOps._
import firrtl.Utils.{one, zero, BoolType}
import MemPortUtils.memPortField
-import AnalysisUtils.{Connects, getConnects}
+import AnalysisUtils.{Connects, getConnects, getConnectOrigin}
import WrappedExpression.weq
import Annotations._
@@ -117,7 +117,9 @@ object InferReadWritePass extends Pass {
for (w <- mem.writers ; r <- mem.readers) {
val wp = getProductTerms(connects)(memPortField(mem, w, "en"))
val rp = getProductTerms(connects)(memPortField(mem, r, "en"))
- if (wp exists (a => rp exists (b => checkComplement(a, b)))) {
+ val wclk = getConnectOrigin(connects, memPortField(mem, w, "clk"))
+ val rclk = getConnectOrigin(connects, memPortField(mem, r, "clk"))
+ if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) {
val rw = namespace newName "rw"
val rwExp = createSubField(createRef(mem.name), rw)
readwriters += rw
@@ -132,7 +134,7 @@ object InferReadWritePass extends Pass {
repl(memPortField(mem, w, "addr")) = EmptyExpression
repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata")
repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask")
- stmts += Connect(NoInfo, createSubField(rwExp, "clk"), createRef("clk")) // TODO: fix it
+ stmts += Connect(NoInfo, createSubField(rwExp, "clk"), wclk)
stmts += Connect(NoInfo, createSubField(rwExp, "en"),
DoPrim(Or, Seq(connects(memPortField(mem, r, "en")),
connects(memPortField(mem, w, "en"))), Nil, BoolType))
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
index 485664b4..c1b0de1e 100644
--- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -74,9 +74,9 @@ object RemoveCHIRRTL extends Pass {
case (s:CDefMPort) =>
val p = mports getOrElse (s.mem, EMPs)
s.direction match {
- case MRead => p.readers += MPort(s.name,s.exps(1))
- case MWrite => p.writers += MPort(s.name,s.exps(1))
- case MReadWrite => p.readwriters += MPort(s.name,s.exps(1))
+ case MRead => p.readers += MPort(s.name, s.exps(1))
+ case MWrite => p.writers += MPort(s.name, s.exps(1))
+ case MReadWrite => p.readwriters += MPort(s.name, s.exps(1))
}
mports(s.mem) = p
case s =>
@@ -90,17 +90,14 @@ object RemoveCHIRRTL extends Pass {
types(s.name) = s.tpe
val taddr = UIntType(IntWidth(1 max ceilLog2(s.size)))
val tdata = s.tpe
- def set_poison(vec: Seq[MPort], addr: String) = vec flatMap (r => Seq(
- IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), addr, taddr)),
- IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", taddr))
+ def set_poison(vec: Seq[MPort]) = vec flatMap (r => Seq(
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "addr", taddr)),
+ IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", ClockType))
))
def set_enable(vec: Seq[MPort], en: String) = vec map (r =>
- Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, taddr), zero)
+ Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, BoolType), zero)
)
- def set_wmode (vec: Seq[MPort], wmode: String) = vec map (r =>
- Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero)
- )
- def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r =>
+ def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap {r =>
val tmask = createMask(s.tpe)
IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +:
(create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask))
@@ -110,13 +107,13 @@ object RemoveCHIRRTL extends Pass {
val rds = (mports getOrElse (s.name, EMPs)).readers
val wrs = (mports getOrElse (s.name, EMPs)).writers
val rws = (mports getOrElse (s.name, EMPs)).readwriters
- val stmts = set_poison(rds, "addr") ++
+ val stmts = set_poison(rds) ++
set_enable(rds, "en") ++
- set_poison(wrs, "addr") ++
+ set_poison(wrs) ++
set_enable(wrs, "en") ++
set_write(wrs, "data", "mask") ++
- set_poison(rws, "addr") ++
- set_wmode(rws, "wmode") ++
+ set_poison(rws) ++
+ set_enable(rws, "wmode") ++
set_enable(rws, "en") ++
set_write(rws, "wdata", "wmask")
val mem = DefMemory(s.info, s.name, s.tpe, s.size, 1, if (s.seq) 1 else 0,
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
index 3af018bd..7e1a0c7e 100644
--- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -34,6 +34,9 @@ import firrtl.Mappers._
import Annotations._
class InferReadWriteSpec extends SimpleTransformSpec {
+ class InferReadWriteCheckException extends PassException(
+ "Readwrite ports are not found!")
+
object InferReadWriteCheckPass extends Pass {
val name = "Check Infer ReadWrite Ports"
def findReadWrite(s: Statement): Boolean = s match {
@@ -51,7 +54,7 @@ class InferReadWriteSpec extends SimpleTransformSpec {
case m: ExtModule => false
}
if (!foundReadWrite) {
- errors append new PassException("Readwrite ports are not found!")
+ errors append new InferReadWriteCheckException
errors.trigger
}
c
@@ -73,7 +76,7 @@ class InferReadWriteSpec extends SimpleTransformSpec {
new EmitFirrtl(writer)
)
- "Infer ReadWrite Ports" should "infer readwrite ports" in {
+ "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in {
val input = """
circuit sram6t :
module sram6t :
@@ -97,10 +100,42 @@ circuit sram6t :
T_5 <= io.wdata
""".stripMargin
- val annotaitonMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1))))
+ val annotationMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1))))
val writer = new java.io.StringWriter
- compile(parse(input), annotaitonMap, writer)
+ compile(parse(input), annotationMap, writer)
// Check correctness of firrtl
parse(writer.toString)
}
+
+ "Infer ReadWrite Ports" should "not infer readwrite ports for the difference clocks" in {
+ val input = """
+circuit sram6t :
+ module sram6t :
+ input clk1 : Clock
+ input clk2 : Clock
+ input reset : UInt<1>
+ output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>}
+
+ io is invalid
+ smem mem : UInt<32>[128]
+ node T_0 = eq(io.wen, UInt<1>("h00"))
+ node T_1 = and(io.en, T_0)
+ wire T_2 : UInt
+ T_2 is invalid
+ when T_1 :
+ T_2 <= io.raddr
+ read mport T_3 = mem[T_2], clk1
+ io.rdata <= T_3
+ node T_4 = and(io.en, io.wen)
+ when T_4 :
+ write mport T_5 = mem[io.waddr], clk2
+ T_5 <= io.wdata
+""".stripMargin
+
+ val annotationMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1))))
+ val writer = new java.io.StringWriter
+ intercept[InferReadWriteCheckException] {
+ compile(parse(input), annotationMap, writer)
+ }
+ }
}