diff options
| author | Donggyu | 2016-09-21 16:05:22 -0700 |
|---|---|---|
| committer | Andrew Waterman | 2016-09-21 16:05:22 -0700 |
| commit | 8b12dcbb76896a19f95dc4da19b3b8c74c1ddda3 (patch) | |
| tree | f14267b79a901de6b0efbb87d819a763b86e6328 /src | |
| parent | 5d515c93e2136bb8bb77c5c1f9c5b9f2eb640deb (diff) | |
Fix clock connections in InferReadWrite (#310)
Diffstat (limited to 'src')
4 files changed, 58 insertions, 22 deletions
diff --git a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala index 7ced7a99..aa5acb12 100644 --- a/src/main/scala/firrtl/passes/AnnotateMemMacros.scala +++ b/src/main/scala/firrtl/passes/AnnotateMemMacros.scala @@ -72,6 +72,8 @@ object AnalysisUtils { if (nodeWidth == extractionWidth) getOrigin(connects, args.head) else e case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => getOrigin(connects, args.head) + // Todo: It's not clear it's ok to call remove validifs before mem passes... + case ValidIf(cond, value, ClockType) => getOrigin(connects, value) // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => connects get e.serialize match { diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala index ec996fdb..34359c14 100644 --- a/src/main/scala/firrtl/passes/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/InferReadWrite.scala @@ -33,7 +33,7 @@ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} import MemPortUtils.memPortField -import AnalysisUtils.{Connects, getConnects} +import AnalysisUtils.{Connects, getConnects, getConnectOrigin} import WrappedExpression.weq import Annotations._ @@ -117,7 +117,9 @@ object InferReadWritePass extends Pass { for (w <- mem.writers ; r <- mem.readers) { val wp = getProductTerms(connects)(memPortField(mem, w, "en")) val rp = getProductTerms(connects)(memPortField(mem, r, "en")) - if (wp exists (a => rp exists (b => checkComplement(a, b)))) { + val wclk = getConnectOrigin(connects, memPortField(mem, w, "clk")) + val rclk = getConnectOrigin(connects, memPortField(mem, r, "clk")) + if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) { val rw = namespace newName "rw" val rwExp = createSubField(createRef(mem.name), rw) readwriters += rw @@ -132,7 +134,7 @@ object InferReadWritePass extends Pass { repl(memPortField(mem, w, "addr")) = EmptyExpression repl(memPortField(mem, w, "data")) = createSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = createSubField(rwExp, "wmask") - stmts += Connect(NoInfo, createSubField(rwExp, "clk"), createRef("clk")) // TODO: fix it + stmts += Connect(NoInfo, createSubField(rwExp, "clk"), wclk) stmts += Connect(NoInfo, createSubField(rwExp, "en"), DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), connects(memPortField(mem, w, "en"))), Nil, BoolType)) diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 485664b4..c1b0de1e 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -74,9 +74,9 @@ object RemoveCHIRRTL extends Pass { case (s:CDefMPort) => val p = mports getOrElse (s.mem, EMPs) s.direction match { - case MRead => p.readers += MPort(s.name,s.exps(1)) - case MWrite => p.writers += MPort(s.name,s.exps(1)) - case MReadWrite => p.readwriters += MPort(s.name,s.exps(1)) + case MRead => p.readers += MPort(s.name, s.exps(1)) + case MWrite => p.writers += MPort(s.name, s.exps(1)) + case MReadWrite => p.readwriters += MPort(s.name, s.exps(1)) } mports(s.mem) = p case s => @@ -90,17 +90,14 @@ object RemoveCHIRRTL extends Pass { types(s.name) = s.tpe val taddr = UIntType(IntWidth(1 max ceilLog2(s.size))) val tdata = s.tpe - def set_poison(vec: Seq[MPort], addr: String) = vec flatMap (r => Seq( - IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), addr, taddr)), - IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", taddr)) + def set_poison(vec: Seq[MPort]) = vec flatMap (r => Seq( + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "addr", taddr)), + IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), "clk", ClockType)) )) def set_enable(vec: Seq[MPort], en: String) = vec map (r => - Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, taddr), zero) + Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), en, BoolType), zero) ) - def set_wmode (vec: Seq[MPort], wmode: String) = vec map (r => - Connect(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), wmode, taddr), zero) - ) - def set_write (vec: Seq[MPort], data: String, mask: String) = vec flatMap {r => + def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap {r => val tmask = createMask(s.tpe) IsInvalid(s.info, SubField(SubField(Reference(s.name, ut), r.name, ut), data, tdata)) +: (create_exps(SubField(SubField(Reference(s.name, ut), r.name, ut), mask, tmask)) @@ -110,13 +107,13 @@ object RemoveCHIRRTL extends Pass { val rds = (mports getOrElse (s.name, EMPs)).readers val wrs = (mports getOrElse (s.name, EMPs)).writers val rws = (mports getOrElse (s.name, EMPs)).readwriters - val stmts = set_poison(rds, "addr") ++ + val stmts = set_poison(rds) ++ set_enable(rds, "en") ++ - set_poison(wrs, "addr") ++ + set_poison(wrs) ++ set_enable(wrs, "en") ++ set_write(wrs, "data", "mask") ++ - set_poison(rws, "addr") ++ - set_wmode(rws, "wmode") ++ + set_poison(rws) ++ + set_enable(rws, "wmode") ++ set_enable(rws, "en") ++ set_write(rws, "wdata", "wmask") val mem = DefMemory(s.info, s.name, s.tpe, s.size, 1, if (s.seq) 1 else 0, diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index 3af018bd..7e1a0c7e 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -34,6 +34,9 @@ import firrtl.Mappers._ import Annotations._ class InferReadWriteSpec extends SimpleTransformSpec { + class InferReadWriteCheckException extends PassException( + "Readwrite ports are not found!") + object InferReadWriteCheckPass extends Pass { val name = "Check Infer ReadWrite Ports" def findReadWrite(s: Statement): Boolean = s match { @@ -51,7 +54,7 @@ class InferReadWriteSpec extends SimpleTransformSpec { case m: ExtModule => false } if (!foundReadWrite) { - errors append new PassException("Readwrite ports are not found!") + errors append new InferReadWriteCheckException errors.trigger } c @@ -73,7 +76,7 @@ class InferReadWriteSpec extends SimpleTransformSpec { new EmitFirrtl(writer) ) - "Infer ReadWrite Ports" should "infer readwrite ports" in { + "Infer ReadWrite Ports" should "infer readwrite ports for the same clock" in { val input = """ circuit sram6t : module sram6t : @@ -97,10 +100,42 @@ circuit sram6t : T_5 <= io.wdata """.stripMargin - val annotaitonMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1)))) + val annotationMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1)))) val writer = new java.io.StringWriter - compile(parse(input), annotaitonMap, writer) + compile(parse(input), annotationMap, writer) // Check correctness of firrtl parse(writer.toString) } + + "Infer ReadWrite Ports" should "not infer readwrite ports for the difference clocks" in { + val input = """ +circuit sram6t : + module sram6t : + input clk1 : Clock + input clk2 : Clock + input reset : UInt<1> + output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>} + + io is invalid + smem mem : UInt<32>[128] + node T_0 = eq(io.wen, UInt<1>("h00")) + node T_1 = and(io.en, T_0) + wire T_2 : UInt + T_2 is invalid + when T_1 : + T_2 <= io.raddr + read mport T_3 = mem[T_2], clk1 + io.rdata <= T_3 + node T_4 = and(io.en, io.wen) + when T_4 : + write mport T_5 = mem[io.waddr], clk2 + T_5 <= io.wdata +""".stripMargin + + val annotationMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1)))) + val writer = new java.io.StringWriter + intercept[InferReadWriteCheckException] { + compile(parse(input), annotationMap, writer) + } + } } |
