From 90e691d517d29c22ead141266c5659a1535fdbf5 Mon Sep 17 00:00:00 2001 From: Donggyu Kim Date: Wed, 27 Jul 2016 14:45:17 -0700 Subject: fix read port enables in RemoveCHIRRTL read ports are declared outside when clauses and used multiple times, so their enables should be inserted when being replaced --- src/main/scala/firrtl/passes/Passes.scala | 74 +++++++++++++++++++++---------- 1 file changed, 50 insertions(+), 24 deletions(-) (limited to 'src') diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 44de3542..120c81a9 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -1227,6 +1227,7 @@ object RemoveCHIRRTL extends Pass { def remove_chirrtl_m (m:Module) : Module = { val hash = LinkedHashMap[String,MPorts]() val repl = LinkedHashMap[String,DataRef]() + val raddrs = HashMap[String, Expression]() val ut = UnknownType val mport_types = LinkedHashMap[String,Type]() def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]()) @@ -1310,11 +1311,15 @@ object RemoveCHIRRTL extends Pass { ens += "en" masks += "mask" } - case _ => { + case MRead => { repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false) addrs += "addr" clks += "clk" - ens += "en" + s.exps(0) match { + case e: Reference => + raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut) + case _=> + } } } val stmts = ArrayBuffer[Statement]() @@ -1334,23 +1339,25 @@ object RemoveCHIRRTL extends Pass { } def remove_chirrtl_s (s:Statement) : Statement = { var has_write_mport = false + var has_read_mport: Option[Expression] = None var has_readwrite_mport:Option[Expression] = None def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = { - (e) match { - case (e:Reference) => { - if (repl.contains(e.name)) { - val vt = repl(e.name) - g match { - case MALE => SubField(vt.exp,vt.male,e.tpe) - case FEMALE => { - has_write_mport = true - if (vt.rdwrite == true) - has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1)))) - SubField(vt.exp,vt.female,e.tpe) - } + (e) match { + case (e:Reference) if repl contains e.name => + val vt = repl(e.name) + g match { + case MALE => SubField(vt.exp,vt.male,e.tpe) + case FEMALE => { + has_write_mport = true + if (vt.rdwrite) + has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1)))) + SubField(vt.exp,vt.female,e.tpe) } - } else e - } + } + case (e:Reference) if g == FEMALE && (raddrs contains e.name) => + has_read_mport = Some(raddrs(e.name)) + e + case (e:Reference) => e case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe) case (e) => e map (remove_chirrtl_e(g)) } @@ -1367,20 +1374,35 @@ object RemoveCHIRRTL extends Pass { case (e) => e } } - (s) match { + (s) match { + case (s:DefNode) => { + val stmts = ArrayBuffer[Statement]() + val valuex = remove_chirrtl_e(MALE)(s.value) + stmts += DefNode(s.info,s.name,valuex) + has_read_mport match { + case None => + case Some(en) => stmts += Connect(s.info,en,one) + } + if (stmts.size > 1) Block(stmts) + else stmts(0) + } case (s:Connect) => { val stmts = ArrayBuffer[Statement]() val rocx = remove_chirrtl_e(MALE)(s.expr) val locx = remove_chirrtl_e(FEMALE)(s.loc) stmts += Connect(s.info,locx,rocx) + has_read_mport match { + case None => + case Some(en) => stmts += Connect(s.info,en,one) + } if (has_write_mport) { val e = get_mask(s.loc) for (x <- create_exps(e) ) { stmts += Connect(s.info,x,one) } - if (has_readwrite_mport != None) { - val wmode = has_readwrite_mport.get - stmts += Connect(s.info,wmode,one) + has_readwrite_mport match { + case None => + case Some(wmode) => stmts += Connect(s.info,wmode,one) } } if (stmts.size > 1) Block(stmts) @@ -1391,16 +1413,20 @@ object RemoveCHIRRTL extends Pass { val locx = remove_chirrtl_e(FEMALE)(s.loc) val rocx = remove_chirrtl_e(MALE)(s.expr) stmts += PartialConnect(s.info,locx,rocx) - if (has_write_mport != false) { + has_read_mport match { + case None => + case Some(en) => stmts += Connect(s.info,en,one) + } + if (has_write_mport) { val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) val locs = create_exps(get_mask(s.loc)) for (x <- ls ) { val locx = locs(x._1) stmts += Connect(s.info,locx,one) } - if (has_readwrite_mport != None) { - val wmode = has_readwrite_mport.get - stmts += Connect(s.info,wmode,one) + has_readwrite_mport match { + case None => + case Some(wmode) => stmts += Connect(s.info,wmode,one) } } if (stmts.size > 1) Block(stmts) -- cgit v1.2.3 From 22350029c9a91c30abd849c17108f8bc24054a78 Mon Sep 17 00:00:00 2001 From: Donggyu Kim Date: Wed, 27 Jul 2016 14:47:47 -0700 Subject: infer readwrite ports for backward compatibility --- src/main/scala/firrtl/LoweringCompilers.scala | 1 + src/main/scala/firrtl/passes/InferReadWrite.scala | 154 ++++++++++++++++++++++ 2 files changed, 155 insertions(+) create mode 100644 src/main/scala/firrtl/passes/InferReadWrite.scala (limited to 'src') diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 8beaf7f9..c27ffce7 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -107,6 +107,7 @@ class HighFirrtlToMiddleFirrtl () extends Transform with SimpleRun { passes.ExpandConnects, passes.RemoveAccesses, passes.ExpandWhens, + passes.InferReadWrite, passes.CheckInitialization, passes.ConstProp, passes.ResolveKinds, diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala new file mode 100644 index 00000000..dd62a240 --- /dev/null +++ b/src/main/scala/firrtl/passes/InferReadWrite.scala @@ -0,0 +1,154 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl.passes + +import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap} + +import firrtl._ +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.PrimOps._ + +object InferReadWrite extends Pass { + def name = "Infer ReadWrite Ports" + + def inferReadWrite(m: Module) = { + import WrappedExpression.we + val connects = HashMap[String, Expression]() + val repl = HashMap[String, Expression]() + val stmts = ArrayBuffer[Statement]() + val zero = we(UIntLiteral(0, IntWidth(1))) + val one = we(UIntLiteral(1, IntWidth(1))) + + def analyze(s: Statement): Unit = s match { + case s: Connect => + connects(s.loc.serialize) = s.expr + case s: PartialConnect => + connects(s.loc.serialize) = s.expr + case s: DefNode => + connects(s.name) = s.value + case s: Block => + s.stmts foreach analyze + case _ => + } + + def getProductTermsFromExp(e: Expression): Seq[Expression] = + e match { + // No ConstProp yet... + case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero => + cond +: getProductTerms(cond.serialize) + case DoPrim(op, args, consts, tpe) if op == And => + e +: (args flatMap getProductTermsFromExp) + case _: WRef | _: SubField | _: SubIndex | _: SubAccess => + e +: getProductTerms(e.serialize) + case _ => + List(e) + } + + def getProductTerms(node: String): Seq[Expression] = + if (connects contains node) getProductTermsFromExp(connects(node)) else Nil + + def checkComplement(a: Expression, b: Expression) = (a, b) match { + case (_, DoPrim(op, args, _, _)) if op == Not => + args.head.serialize == a.serialize + case (DoPrim(op, args, _, _), _) if op == Not => + args.head.serialize == b.serialize + case (_, DoPrim(op, args, _, _)) if op == Eq => + args(0).serialize == a.serialize && we(args(1)) == zero || + args(1).serialize == a.serialize && we(args(0)) == zero + case (DoPrim(op, args, _, _), _) if op == Eq => + args(0).serialize == b.serialize && we(args(1)) == zero || + args(1).serialize == b.serialize && we(args(0)) == zero + case _ => false + } + + def inferReadWrite(s: Statement): Statement = s map inferReadWrite match { + case mem: DefMemory if mem.readLatency > 0 => + var idx = 0 + val bt = UIntType(IntWidth(1)) + val ut = UnknownType + val ug = UNKNOWNGENDER + val readers = HashSet[String]() + val writers = HashSet[String]() + val readwriters = ArrayBuffer[String]() + for (w <- mem.writers ; r <- mem.readers) { + val wp = getProductTerms(s"${mem.name}.$w.en") + val rp = getProductTerms(s"${mem.name}.$r.en") + if (wp exists (a => rp exists (b => checkComplement(a, b)))) { + val rw = s"rw_$idx" + val rw_exp = WSubField(WRef(mem.name, ut, NodeKind(), ug), rw, ut, ug) + readwriters += rw + readers += r + writers += w + repl(s"${mem.name}.$r.en") = EmptyExpression + repl(s"${mem.name}.$r.clk") = EmptyExpression + repl(s"${mem.name}.$r.addr") = EmptyExpression + repl(s"${mem.name}.$r.data") = WSubField(rw_exp, "rdata", mem.dataType, MALE) + repl(s"${mem.name}.$w.en") = WSubField(rw_exp, "wmode", bt, FEMALE) + repl(s"${mem.name}.$w.clk") = EmptyExpression + repl(s"${mem.name}.$w.addr") = EmptyExpression + repl(s"${mem.name}.$w.data") = WSubField(rw_exp, "data", mem.dataType, FEMALE) + repl(s"${mem.name}.$w.mask") = WSubField(rw_exp, "mask", ut, FEMALE) + stmts += Connect(NoInfo, WSubField(rw_exp, "clk", ClockType, FEMALE), + WRef("clk", ClockType, NodeKind(), MALE)) + stmts += Connect(NoInfo, WSubField(rw_exp, "en", bt, FEMALE), + DoPrim(Or, List(connects(s"${mem.name}.$r.en"), connects(s"${mem.name}.$w.en")), Nil, bt)) + stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE), + Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"), + connects(s"${mem.name}.$r.addr"), ut)) + idx += 1 + } + } + if (readwriters.isEmpty) mem else DefMemory(mem.info, + mem.name, mem.dataType, mem.depth, mem.writeLatency, mem.readLatency, + mem.readers filterNot readers, mem.writers filterNot writers, + mem.readwriters ++ readwriters) + case s => s + } + + def replaceExp(e: Expression): Expression = + e map replaceExp match { + case e: WSubField => repl getOrElse (e.serialize, e) + case e => e + } + + def replaceStmt(s: Statement): Statement = + s map replaceStmt map replaceExp match { + case Connect(info, loc, exp) if loc == EmptyExpression => EmptyStmt + case s => s + } + + analyze(m.body) + Module(m.info, m.name, m.ports, Block((m.body map inferReadWrite map replaceStmt) +: stmts.toSeq)) + } + + def run (c:Circuit) = Circuit(c.info, c.modules map { + case m: Module => inferReadWrite(m) + case m: ExtModule => m + }, c.main) +} -- cgit v1.2.3 From c951e7453303f7aaf0c281f88a76ae2ba017ed38 Mon Sep 17 00:00:00 2001 From: Donggyu Kim Date: Wed, 27 Jul 2016 14:56:11 -0700 Subject: make infer readwrite ports optional turned on with '--inferRW ' --- src/main/scala/firrtl/Driver.scala | 7 +- src/main/scala/firrtl/LoweringCompilers.scala | 3 +- src/main/scala/firrtl/passes/InferReadWrite.scala | 56 ++++++++++- .../scala/firrtlTests/InferReadWriteSpec.scala | 104 +++++++++++++++++++++ 4 files changed, 164 insertions(+), 6 deletions(-) create mode 100644 src/test/scala/firrtlTests/InferReadWriteSpec.scala (limited to 'src') diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 59a2bb87..bd7210f4 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -44,6 +44,7 @@ Options: Currently supported: high low verilog --info-mode Specify Info Mode Supported modes: ignore, use, gen, append + --inferRW Enable readwrite port inference for the target circuit """ // Compiles circuit. First parses a circuit from an input file, @@ -87,11 +88,15 @@ Options: case _ => throw new Exception(s"Bad inline instance/module name: $value") } + def handleInferRWOption(value: String) = + passes.InferReadWriteAnnotation(value, TransID(-1)) + run(args: Array[String], Map( "high" -> new HighFirrtlCompiler(), "low" -> new LowFirrtlCompiler(), "verilog" -> new VerilogCompiler()), - Map("--inline" -> handleInlineOption _), + Map("--inline" -> handleInlineOption _, + "--inferRW" -> handleInferRWOption _), usage ) } diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index c27ffce7..4d7ddfe0 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -107,7 +107,6 @@ class HighFirrtlToMiddleFirrtl () extends Transform with SimpleRun { passes.ExpandConnects, passes.RemoveAccesses, passes.ExpandWhens, - passes.InferReadWrite, passes.CheckInitialization, passes.ConstProp, passes.ResolveKinds, @@ -190,6 +189,7 @@ class LowFirrtlCompiler extends Compiler { new passes.InlineInstances(TransID(0)), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), + new passes.InferReadWrite(TransID(-1)), new MiddleFirrtlToLowFirrtl(), new EmitFirrtl(writer) ) @@ -202,6 +202,7 @@ class VerilogCompiler extends Compiler { new IRToWorkingIR(), new ResolveAndCheck(), new HighFirrtlToMiddleFirrtl(), + new passes.InferReadWrite(TransID(-1)), new MiddleFirrtlToLowFirrtl(), new passes.InlineInstances(TransID(0)), new EmitVerilogFromLowFirrtl(writer) diff --git a/src/main/scala/firrtl/passes/InferReadWrite.scala b/src/main/scala/firrtl/passes/InferReadWrite.scala index dd62a240..2378216d 100644 --- a/src/main/scala/firrtl/passes/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/InferReadWrite.scala @@ -28,13 +28,26 @@ MODIFICATIONS. package firrtl.passes import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap} +import com.typesafe.scalalogging.LazyLogging import firrtl._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps._ +import Annotations._ -object InferReadWrite extends Pass { +case class InferReadWriteAnnotation(t: String, tID: TransID) + extends Annotation with Loose with Unstable { + val target = CircuitName(t) + def duplicate(n: Named) = this.copy(t=n.name) +} + +// This pass examine the enable signals of the read & write ports of memories +// whose readLatency is greater than 1 (usually SeqMem in Chisel). +// If any product term of the enable signal of the read port is the complement +// of any product term of the enable signal of the write port, then the readwrite +// port is inferred. +object InferReadWritePass extends Pass { def name = "Infer ReadWrite Ports" def inferReadWrite(m: Module) = { @@ -45,6 +58,7 @@ object InferReadWrite extends Pass { val zero = we(UIntLiteral(0, IntWidth(1))) val one = we(UIntLiteral(1, IntWidth(1))) + // find all wire connections def analyze(s: Statement): Unit = s match { case s: Connect => connects(s.loc.serialize) = s.expr @@ -62,10 +76,13 @@ object InferReadWrite extends Pass { // No ConstProp yet... case Mux(cond, tval, fval, _) if we(tval) == one && we(fval) == zero => cond +: getProductTerms(cond.serialize) + // Visit each term of AND operation case DoPrim(op, args, consts, tpe) if op == And => e +: (args flatMap getProductTermsFromExp) + // Visit connected nodes to references case _: WRef | _: SubField | _: SubIndex | _: SubAccess => e +: getProductTerms(e.serialize) + // Otherwise just return itselt case _ => List(e) } @@ -74,13 +91,17 @@ object InferReadWrite extends Pass { if (connects contains node) getProductTermsFromExp(connects(node)) else Nil def checkComplement(a: Expression, b: Expression) = (a, b) match { + // b ?= Not(a) case (_, DoPrim(op, args, _, _)) if op == Not => args.head.serialize == a.serialize + // a ?= Not(b) case (DoPrim(op, args, _, _), _) if op == Not => args.head.serialize == b.serialize + // b ?= Eq(a, 0) or b ?= Eq(0, a) case (_, DoPrim(op, args, _, _)) if op == Eq => args(0).serialize == a.serialize && we(args(1)) == zero || args(1).serialize == a.serialize && we(args(0)) == zero + // a ?= Eq(b, 0) or b ?= Eq(0, a) case (DoPrim(op, args, _, _), _) if op == Eq => args(0).serialize == b.serialize && we(args(1)) == zero || args(1).serialize == b.serialize && we(args(0)) == zero @@ -88,8 +109,8 @@ object InferReadWrite extends Pass { } def inferReadWrite(s: Statement): Statement = s map inferReadWrite match { + // infer readwrite ports only for non combinational memories case mem: DefMemory if mem.readLatency > 0 => - var idx = 0 val bt = UIntType(IntWidth(1)) val ut = UnknownType val ug = UNKNOWNGENDER @@ -100,7 +121,13 @@ object InferReadWrite extends Pass { val wp = getProductTerms(s"${mem.name}.$w.en") val rp = getProductTerms(s"${mem.name}.$r.en") if (wp exists (a => rp exists (b => checkComplement(a, b)))) { - val rw = s"rw_$idx" + val allPorts = (mem.readers ++ mem.writers ++ mem.readwriters ++ readwriters).toSet + // Uniquify names by examining all ports of the memory + var rw = (for { + idx <- Stream from 0 + newName = s"rw_$idx" + if !allPorts(newName) + } yield newName).head val rw_exp = WSubField(WRef(mem.name, ut, NodeKind(), ug), rw, ut, ug) readwriters += rw readers += r @@ -121,7 +148,6 @@ object InferReadWrite extends Pass { stmts += Connect(NoInfo, WSubField(rw_exp, "addr", ut, FEMALE), Mux(connects(s"${mem.name}.$w.en"), connects(s"${mem.name}.$w.addr"), connects(s"${mem.name}.$r.addr"), ut)) - idx += 1 } } if (readwriters.isEmpty) mem else DefMemory(mem.info, @@ -152,3 +178,25 @@ object InferReadWrite extends Pass { case m: ExtModule => m }, c.main) } + +// Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" +// To use this transform, circuit name should be annotated with its TransId. +class InferReadWrite(transID: TransID) extends Transform with LazyLogging { + def execute(circuit:Circuit, map: AnnotationMap) = + map get transID match { + case Some(p) => p get CircuitName(circuit.main) match { + case Some(InferReadWriteAnnotation(_, _)) => TransformResult((Seq( + InferReadWritePass, + CheckInitialization, + ResolveKinds, + InferTypes, + ResolveGenders) foldLeft circuit){ (c, pass) => + val x = Utils.time(pass.name)(pass run c) + logger debug x.serialize + x + }, None, Some(map)) + case _ => TransformResult(circuit, None, Some(map)) + } + case _ => TransformResult(circuit, None, Some(map)) + } +} diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala new file mode 100644 index 00000000..93f73741 --- /dev/null +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -0,0 +1,104 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import firrtl.passes._ +import firrtl.Mappers._ +import Annotations._ + +class InferReadWriteSpec extends SimpleTransformSpec { + object InferReadWriteCheckPass extends Pass { + val name = "Check Infer ReadWrite Ports" + var foundReadWrite = false + def findReadWrite(s: Statement): Unit = s match { + case s: DefMemory if s.readLatency > 0 => + foundReadWrite = s.name == "mem" && s.readwriters.size == 1 + case s: Block => + s.stmts foreach findReadWrite + case _ => + } + + def run (c: Circuit) = { + val errors = new Errors + c.modules foreach { + case m: Module => findReadWrite(m.body) + case m: ExtModule => m + } + if (!foundReadWrite) { + errors append new PassException("Readwrite ports are not found!") + } + errors.trigger + c + } + } + + object InferReadWriteCheck extends Transform with SimpleRun { + def execute (c: Circuit, map: AnnotationMap) = + run(c, Seq(InferReadWriteCheckPass)) + } + + def transforms (writer: java.io.Writer) = Seq( + new Chisel3ToHighFirrtl(), + new IRToWorkingIR(), + new ResolveAndCheck(), + new HighFirrtlToMiddleFirrtl(), + new InferReadWrite(TransID(-1)), + InferReadWriteCheck, + new EmitFirrtl(writer) + ) + + "Infer ReadWrite Ports" should "infer readwrite ports" in { + val input = """ +circuit sram6t : + module sram6t : + input clk : Clock + input reset : UInt<1> + output io : {flip en : UInt<1>, flip wen : UInt<1>, flip waddr : UInt<8>, flip wdata : UInt<32>, flip raddr : UInt<8>, rdata : UInt<32>} + + io is invalid + smem mem : UInt<32>[128] + node T_0 = eq(io.wen, UInt<1>("h00")) + node T_1 = and(io.en, T_0) + wire T_2 : UInt + T_2 is invalid + when T_1 : + T_2 <= io.raddr + read mport T_3 = mem[T_2], clk + io.rdata <= T_3 + node T_4 = and(io.en, io.wen) + when T_4 : + write mport T_5 = mem[io.waddr], clk + T_5 <= io.wdata +""".stripMargin + + val annotaitonMap = AnnotationMap(Seq(InferReadWriteAnnotation("sram6t", TransID(-1)))) + compile(parse(input), annotaitonMap, new java.io.StringWriter) + } +} -- cgit v1.2.3