diff options
Diffstat (limited to 'src')
5 files changed, 142 insertions, 21 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index 6b3508a6..cc69be6f 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -82,12 +82,10 @@ object RemoveCHIRRTL extends Transform { def set_enable(vec: Seq[MPort], en: String) = vec map (r => Connect(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), en, BoolType), zero) ) - def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap {r => + def set_write(vec: Seq[MPort], data: String, mask: String) = vec flatMap { r => val tmask = createMask(sx.tpe) - IsInvalid(sx.info, SubField(SubField(Reference(sx.name, ut), r.name, ut), data, tdata)) +: - (create_exps(SubField(SubField(Reference(sx.name, ut), r.name, ut), mask, tmask)) - map (Connect(sx.info, _, zero)) - ) + val portRef = SubField(Reference(sx.name, ut), r.name, ut) + Seq(IsInvalid(sx.info, SubField(portRef, data, tdata)), IsInvalid(sx.info, SubField(portRef, mask, tmask))) } val rds = (mports getOrElse (sx.name, EMPs)).readers val wrs = (mports getOrElse (sx.name, EMPs)).writers @@ -109,12 +107,15 @@ object RemoveCHIRRTL extends Transform { val addrs = ArrayBuffer[String]() val clks = ArrayBuffer[String]() val ens = ArrayBuffer[String]() + val masks = ArrayBuffer.empty[Expression] + val portRef = SubField(Reference(sx.mem, ut), sx.name, ut) sx.direction match { case MReadWrite => - refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "rdata", "wdata", "wmask", rdwrite = true) + refs(sx.name) = DataRef(portRef, "rdata", "wdata", "wmask", rdwrite = true) addrs += "addr" clks += "clk" ens += "en" + masks ++= create_exps(SubField(portRef, "wmask", createMask(sx.tpe))) renames.rename(sx.name, s"${sx.mem}.${sx.name}.rdata") renames.rename(sx.name, s"${sx.mem}.${sx.name}.wdata") val es = create_all_exps(WRef(sx.name, sx.tpe)) @@ -124,10 +125,11 @@ object RemoveCHIRRTL extends Transform { case ((e, r), w) => renames.rename(e.serialize, Seq(r.serialize, w.serialize)) } case MWrite => - refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "mask", rdwrite = false) + refs(sx.name) = DataRef(portRef, "data", "data", "mask", rdwrite = false) addrs += "addr" clks += "clk" ens += "en" + masks ++= create_exps(SubField(portRef, "mask", createMask(sx.tpe))) renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") val es = create_all_exps(WRef(sx.name, sx.tpe)) val ws = create_all_exps(WRef(s"${sx.mem}.${sx.name}.data", sx.tpe)) @@ -135,12 +137,12 @@ object RemoveCHIRRTL extends Transform { case (e, w) => renames.rename(e.serialize, w.serialize) } case MRead => - refs(sx.name) = DataRef(SubField(Reference(sx.mem, ut), sx.name, ut), "data", "data", "blah", rdwrite = false) + refs(sx.name) = DataRef(portRef, "data", "data", "blah", rdwrite = false) addrs += "addr" clks += "clk" sx.exps.head match { case e: Reference if smems(sx.mem) => - raddrs(e.name) = SubField(SubField(Reference(sx.mem, ut), sx.name, ut), "en", ut) + raddrs(e.name) = SubField(portRef, "en", ut) case _ => ens += "en" } renames.rename(sx.name, s"${sx.mem}.${sx.name}.data") @@ -152,9 +154,11 @@ object RemoveCHIRRTL extends Transform { case MInfer => // do nothing if it's not being used } Block( - (addrs map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps.head))) ++ - (clks map (x => Connect(sx.info, SubField(SubField(Reference(sx.mem, ut), sx.name, ut), x, ut), sx.exps(1)))) ++ - (ens map (x => Connect(sx.info,SubField(SubField(Reference(sx.mem,ut), sx.name, ut), x, ut), one)))) + (addrs map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps.head))) ++ + (clks map (x => Connect(sx.info, SubField(portRef, x, ut), sx.exps(1)))) ++ + (ens map (x => Connect(sx.info,SubField(portRef, x, ut), one))) ++ + masks.map(lhs => Connect(sx.info, lhs, zero)) + ) case sx => sx map collect_refs(mports, smems, types, refs, raddrs, renames) } diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 661d6df4..2d1d7f6b 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -89,11 +89,12 @@ object InferReadWritePass extends Pass { val readwriters = collection.mutable.ArrayBuffer[String]() val namespace = Namespace(mem.readers ++ mem.writers ++ mem.readwriters) for (w <- mem.writers ; r <- mem.readers) { - val wp = getProductTerms(connects)(memPortField(mem, w, "en")) - val rp = getProductTerms(connects)(memPortField(mem, r, "en")) + val wenProductTerms = getProductTerms(connects)(memPortField(mem, w, "en")) + val renProductTerms = getProductTerms(connects)(memPortField(mem, r, "en")) + val proofOfMutualExclusion = wenProductTerms.find(a => renProductTerms exists (b => checkComplement(a, b))) val wclk = getOrigin(connects)(memPortField(mem, w, "clk")) val rclk = getOrigin(connects)(memPortField(mem, r, "clk")) - if (weq(wclk, rclk) && (wp exists (a => rp exists (b => checkComplement(a, b))))) { + if (weq(wclk, rclk) && proofOfMutualExclusion.nonEmpty) { val rw = namespace newName "rw" val rwExp = WSubField(WRef(mem.name), rw) readwriters += rw @@ -104,10 +105,11 @@ object InferReadWritePass extends Pass { repl(memPortField(mem, r, "addr")) = EmptyExpression repl(memPortField(mem, r, "data")) = WSubField(rwExp, "rdata") repl(memPortField(mem, w, "clk")) = EmptyExpression - repl(memPortField(mem, w, "en")) = WSubField(rwExp, "wmode") + repl(memPortField(mem, w, "en")) = EmptyExpression repl(memPortField(mem, w, "addr")) = EmptyExpression repl(memPortField(mem, w, "data")) = WSubField(rwExp, "wdata") repl(memPortField(mem, w, "mask")) = WSubField(rwExp, "wmask") + stmts += Connect(NoInfo, WSubField(rwExp, "wmode"), proofOfMutualExclusion.get) stmts += Connect(NoInfo, WSubField(rwExp, "clk"), wclk) stmts += Connect(NoInfo, WSubField(rwExp, "en"), DoPrim(Or, Seq(connects(memPortField(mem, r, "en")), diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 0424b1dd..e254dcc9 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -65,7 +65,8 @@ object AnalysisUtils { if (nodeWidth == extractionWidth) getOrigin(connects)(args.head) else e case DoPrim((PrimOps.AsUInt | PrimOps.AsSInt | PrimOps.AsClock), args, _, _) => getOrigin(connects)(args.head) - case ValidIf(cond, value, ClockType) => getOrigin(connects)(value) + // It is a correct optimization to treat ValidIf as a connection + case ValidIf(cond, value, _) => getOrigin(connects)(value) // note: this should stop on a reg, but will stack overflow for combinational loops (not allowed) case _: WRef | _: WSubField | _: WSubIndex | _: WSubAccess if kind(e) != RegKind => connects get e.serialize match { diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index 34e228be..bffb1b51 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -7,6 +7,7 @@ import firrtl.ir._ import firrtl.passes._ import firrtl.Mappers._ import annotations._ +import FirrtlCheckers._ class InferReadWriteSpec extends SimpleTransformSpec { class InferReadWriteCheckException extends PassException( @@ -138,4 +139,37 @@ circuit sram6t : compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) } } + + "wmode" should "be simplified" in { + val input = """ +circuit sram6t : + module sram6t : + input clock : Clock + input reset : UInt<1> + output io : { flip addr : UInt<11>, flip valid : UInt<1>, flip write : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>} + + io is invalid + smem mem : UInt<32> [2048] + node wen = and(io.valid, io.write) + node ren = and(io.valid, not(io.write)) + when wen : + write mport _T_14 = mem[io.addr], clock + _T_14 <= io.dataIn + node _T_16 = eq(wen, UInt<1>("h0")) + when _T_16 : + wire _T_18 : UInt + _T_18 is invalid + when ren : + _T_18 <= io.addr + node _T_20 = or(_T_18, UInt<11>("h0")) + node _T_21 = bits(_T_20, 10, 0) + read mport _T_22 = mem[_T_21], clock + io.dataOut <= _T_22 +""".stripMargin + + val annos = Seq(memlib.InferReadWriteAnnotation) + val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) + // Check correctness of firrtl + res should containLine (s"mem.rw.wmode <= wen") + } } diff --git a/src/test/scala/firrtlTests/ReplSeqMemTests.scala b/src/test/scala/firrtlTests/ReplSeqMemTests.scala index dcc23235..2eae3580 100644 --- a/src/test/scala/firrtlTests/ReplSeqMemTests.scala +++ b/src/test/scala/firrtlTests/ReplSeqMemTests.scala @@ -8,6 +8,7 @@ import firrtl.passes._ import firrtl.transforms._ import firrtl.passes.memlib._ import annotations._ +import FirrtlCheckers._ class ReplSeqMemSpec extends SimpleTransformSpec { def emitter = new LowFirrtlEmitter @@ -66,7 +67,6 @@ circuit Top : val annos = Seq(ReplSeqMemAnnotation.parse("-c:Top:-o:"+confLoc)) val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) // Check correctness of firrtl - println(res.annotations) parse(res.getEmittedCircuit.value) (new java.io.File(confLoc)).delete() } @@ -181,7 +181,8 @@ circuit Top : "asClock(a)" -> "a", "a" -> "a", "or(a, b)" -> "or(a, b)", - "bits(a, 0, 0)" -> "a" + "bits(a, 0, 0)" -> "a", + "validif(a, b)" -> "b" ) tests foreach { case(hurdle, origin) => checkConnectOrigin(hurdle, origin) } @@ -296,9 +297,88 @@ circuit CustomMemory : require(numExtMods == 1) (new java.io.File(confLoc)).delete() } + + "ReplSeqMem" should "should not have a mask if there is none" in { + val input = """ +circuit CustomMemory : + module CustomMemory : + input clock : Clock + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2] } + + smem mem : UInt<8>[2][1024] + read mport r = mem[io.raddr], clock + io.out <= r + + when io.en : + write mport w = mem[io.waddr], clock + w <= io.wdata +""" + val confLoc = "ReplSeqMemTests.confTEMP" + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) + res.getEmittedCircuit.value shouldNot include ("mask") + (new java.io.File(confLoc)).delete() + } + + "ReplSeqMem" should "should not conjoin enable signal with mask condition" in { + val input = """ +circuit CustomMemory : + module CustomMemory : + input clock : Clock + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + + smem mem : UInt<8>[2][1024] + read mport r = mem[io.raddr], clock + io.out <= r + + when io.en : + write mport w = mem[io.waddr], clock + when io.mask[0] : + w[0] <= io.wdata[0] + when io.mask[1] : + w[1] <= io.wdata[1] +""" + val confLoc = "ReplSeqMemTests.confTEMP" + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc)) + val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) + // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask + res should containLine ("mem.W0_mask_0 <= validif(io_en, io_mask_0)") + res should containLine ("mem.W0_mask_1 <= validif(io_en, io_mask_1)") + (new java.io.File(confLoc)).delete() + } + + "ReplSeqMem" should "should not conjoin enable signal with wmask condition (RW Port)" in { + val input = """ +circuit CustomMemory : + module CustomMemory : + input clock : Clock + output io : { flip en : UInt<1>, out : UInt<8>[2], flip raddr : UInt<10>, flip waddr : UInt<10>, flip wdata : UInt<8>[2], flip mask : UInt<8>[2] } + + io.out is invalid + + smem mem : UInt<8>[2][1024] + + when io.en : + write mport w = mem[io.waddr], clock + when io.mask[0] : + w[0] <= io.wdata[0] + when io.mask[1] : + w[1] <= io.wdata[1] + when not(io.en) : + read mport r = mem[io.raddr], clock + io.out <= r + +""" + val confLoc = "ReplSeqMemTests.confTEMP" + val annos = Seq(ReplSeqMemAnnotation.parse("-c:CustomMemory:-o:"+confLoc), + InferReadWriteAnnotation) + val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) + // TODO Until RemoveCHIRRTL is removed, enable will still drive validif for mask + res should containLine ("mem.RW0_wmask_0 <= validif(io_en, io_mask_0)") + res should containLine ("mem.RW0_wmask_1 <= validif(io_en, io_mask_1)") + (new java.io.File(confLoc)).delete() + } } // TODO: make more checks -// readwrite vs. no readwrite -// mask + no mask // conf |
