diff options
| author | Albert Magyar | 2019-10-21 12:10:51 -0700 |
|---|---|---|
| committer | GitHub | 2019-10-21 12:10:51 -0700 |
| commit | b43288d588d04775230456ca85fa231a8cf397fe (patch) | |
| tree | 0933b15baca7520faf5aae0f9e1fc60bb36390d4 | |
| parent | fd981848c7d2a800a15f9acfbf33b57dd1c6225b (diff) | |
| parent | 24f7d90b032f7058ae379ff3592c9d29c7f987e7 (diff) | |
Merge pull request #1202 from freechipsproject/fix-verilog-mem-delay-en
Fix handling of read enables for write-first (default) memories in VerilogMemDelays
7 files changed, 589 insertions, 170 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 328e6caa..2dc73db3 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -8,195 +8,163 @@ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.traversals.Foreachers._ -import firrtl.PrimOps._ + import MemPortUtils._ +import WrappedExpression._ import collection.mutable -object DelayPipe { - private case class PipeState(ref: Expression, decl: Statement = EmptyStmt, connect: Statement = EmptyStmt, idx: Int = 0) +object MemDelayAndReadwriteTransformer { + // Representation of a group of signals and associated valid signals + case class WithValid(valid: Expression, payload: Seq[Expression]) + + // Grouped statements that are split into declarations and connects to ease ordering + case class SplitStatements(decls: Seq[Statement], conns: Seq[Connect]) + + // Utilities for generating hardware + def NOT(e: Expression) = DoPrim(PrimOps.Not, Seq(e), Nil, BoolType) + def AND(e1: Expression, e2: Expression) = DoPrim(PrimOps.And, Seq(e1, e2), Nil, BoolType) + def connect(l: Expression, r: Expression): Connect = Connect(NoInfo, l, r) + def condConnect(c: Expression)(l: Expression, r: Expression): Connect = connect(l, Mux(c, r, l, l.tpe)) + + // Utilities for working with WithValid groups + def connect(l: WithValid, r: WithValid): Seq[Connect] = { + val paired = (l.valid +: l.payload) zip (r.valid +: r.payload) + paired.map { case (le, re) => connect(le, re) } + } + + def condConnect(l: WithValid, r: WithValid): Seq[Connect] = { + connect(l.valid, r.valid) +: (l.payload zip r.payload).map { case (le, re) => condConnect(r.valid)(le, re) } + } + + // Internal representation of a pipeline stage with an associated valid signal + private case class PipeStageWithValid(idx: Int, ref: WithValid, stmts: SplitStatements = SplitStatements(Nil, Nil)) + + // Utilities for creating legal names for registers + private val metaChars = raw"[\[\]\.]".r + private def flatName(e: Expression) = metaChars.replaceAllIn(e.serialize, "_") - def apply(ns: Namespace)(e: Expression, delay: Int, clock: Expression): (Expression, Seq[Statement]) = { - def addStage(prev: PipeState): PipeState = { - val idx = prev.idx + 1 - val name = ns.newName(s"${e.serialize}_r${idx}".replace('.', '_')) - val regRef = WRef(name, e.tpe, RegKind) - val regDecl = DefRegister(NoInfo, name, e.tpe, clock, zero, regRef) - PipeState(regRef, regDecl, Connect(NoInfo, regRef, prev.ref), idx) + // Pipeline a group of signals with an associated valid signal. Gate registers when possible. + def pipelineWithValid(ns: Namespace)( + clock: Expression, + depth: Int, + src: WithValid, + nameTemplate: Option[WithValid] = None): (WithValid, Seq[Statement], Seq[Connect]) = { + + def asReg(e: Expression) = DefRegister(NoInfo, e.serialize, e.tpe, clock, zero, e) + val template = nameTemplate.getOrElse(src) + + val stages = Seq.iterate(PipeStageWithValid(0, src), depth + 1) { case prev => + def pipeRegRef(e: Expression) = WRef(ns.newName(s"${flatName(e)}_pipe_${prev.idx}"), e.tpe, RegKind) + val ref = WithValid(pipeRegRef(template.valid), template.payload.map(pipeRegRef)) + val regs = (ref.valid +: ref.payload).map(asReg) + PipeStageWithValid(prev.idx + 1, ref, SplitStatements(regs, condConnect(ref, prev.ref))) } - val pipeline = Seq.iterate(PipeState(e), delay+1)(addStage) - (pipeline.last.ref, pipeline.map(_.decl) ++ pipeline.map(_.connect)) + (stages.last.ref, stages.flatMap(_.stmts.decls), stages.flatMap(_.stmts.conns)) } } -/** This pass generates delay reigsters for memories for verilog */ -object VerilogMemDelays extends Pass { - val ug = UnknownFlow - type Netlist = collection.mutable.HashMap[String, Expression] - implicit def expToString(e: Expression): String = e.serialize - private def NOT(e: Expression) = DoPrim(Not, Seq(e), Nil, BoolType) - private def AND(e1: Expression, e2: Expression) = DoPrim(And, Seq(e1, e2), Nil, BoolType) - - def buildNetlist(netlist: Netlist)(s: Statement): Unit = s match { - case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(loc) = expr - case _ => - s.foreach(buildNetlist(netlist)) +/** + * This class performs the primary work of the transform: splitting readwrite ports into separate + * read and write ports while simultaneously compiling memory latencies to combinational-read + * memories with delay pipelines. It is represented as a class that takes a module as a constructor + * argument, as it encapsulates the mutable state required to analyze and transform one module. + * + * @note The final transformed module is found in the (sole public) field [[transformed]] + */ +class MemDelayAndReadwriteTransformer(m: DefModule) { + import MemDelayAndReadwriteTransformer._ + + private val ns = Namespace(m) + private val netlist = new collection.mutable.HashMap[WrappedExpression, Expression] + private val exprReplacements = new collection.mutable.HashMap[WrappedExpression, Expression] + private val newConns = new mutable.ArrayBuffer[Connect] + + private def findMemConns(s: Statement): Unit = s match { + case Connect(_, loc, expr) if (kind(loc) == MemKind) => netlist(we(loc)) = expr + case _ => s.foreach(findMemConns) } - def memDelayStmt( - netlist: Netlist, - namespace: Namespace, - repl: Netlist, - stmts: mutable.ArrayBuffer[Statement]) - (s: Statement): Statement = s.map(memDelayStmt(netlist, namespace, repl, stmts)) match { - case sx: DefMemory => - val ports = (sx.readers ++ sx.writers).toSet - def newPortName(rw: String, p: String) = (for { - idx <- Stream from 0 - newName = s"${rw}_${p}_$idx" - if !ports(newName) - } yield newName).head - val rwMap = (sx.readwriters map (rw => - rw ->( (newPortName(rw, "r"), newPortName(rw, "w")) ))).toMap - // 1. readwrite ports are split into read & write ports - // 2. memories are transformed into combinational - // because latency pipes are added for longer latencies - val mem = sx copy ( - readers = sx.readers ++ (sx.readwriters map (rw => rwMap(rw)._1)), - writers = sx.writers ++ (sx.readwriters map (rw => rwMap(rw)._2)), - readwriters = Nil, readLatency = 0, writeLatency = 1) - def prependPipe(e: Expression, // Expression to be piped - n: Int, // pipe depth - clk: Expression, // clock expression - cond: Expression // condition for pipes - ): (Expression, Seq[Statement]) = { - // returns - // 1) reference to the last pipe register - // 2) pipe registers and connects - val node = DefNode(NoInfo, namespace.newTemp, netlist(e)) - val wref = WRef(node.name, e.tpe, NodeKind, SourceFlow) - ((0 until n) foldLeft( (wref, Seq[Statement](node)) )){case ((ex, stmts), i) => - val name = namespace newName s"${LowerTypes.loweredName(e)}_pipe_$i" - val exx = WRef(name, e.tpe, RegKind, ug) - (exx, stmts ++ Seq(DefRegister(NoInfo, name, e.tpe, clk, zero, exx)) ++ - (if (i < n - 1 && WrappedExpression.weq(cond, one)) Seq(Connect(NoInfo, exx, ex)) else { - val condn = namespace newName s"${LowerTypes.loweredName(e)}_en" - val condx = WRef(condn, BoolType, NodeKind, SinkFlow) - Seq(DefNode(NoInfo, condn, cond), - Connect(NoInfo, exx, Mux(condx, ex, exx, e.tpe))) - }) - ) - } + private def swapMemRefs(e: Expression): Expression = e map swapMemRefs match { + case sf: WSubField => exprReplacements.getOrElse(we(sf), sf) + case ex => ex + } + + private def transform(s: Statement): Statement = s.map(transform) match { + case mem: DefMemory => + // Per-memory bookkeeping + val portNS = Namespace(mem.readers ++ mem.writers) + val rMap = mem.readwriters.map(rw => (rw -> portNS.newName(s"${rw}_r"))).toMap + val wMap = mem.readwriters.map(rw => (rw -> portNS.newName(s"${rw}_w"))).toMap + val rCmdDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) 0 else mem.readLatency + val rRespDelay = if (mem.readUnderWrite == ReadUnderWrite.Old) mem.readLatency else 0 + val wCmdDelay = mem.writeLatency - 1 + + val readStmts = (mem.readers ++ mem.readwriters).map { case r => + def oldDriver(f: String) = netlist(we(memPortField(mem, r, f))) + def newField(f: String) = memPortField(mem, rMap.getOrElse(r, r), f) + val clk = oldDriver("clk") + + // Pack sources of read command inputs into WithValid object -> different for readwriter + val enSrc = if (rMap.contains(r)) AND(oldDriver("en"), NOT(oldDriver("wmode"))) else oldDriver("en") + val cmdSrc = WithValid(enSrc, Seq(oldDriver("addr"))) + val cmdSink = WithValid(newField("en"), Seq(newField("addr"))) + val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, rCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Pipeline read response using *last* command pipe stage enable as the valid signal + val resp = WithValid(cmdPiped.valid, Seq(newField("data"))) + val respPipeNameTemplate = Some(resp.copy(valid = cmdSink.valid)) // base pipeline register names off field names + val (respPiped, respDecls, respConns) = pipelineWithValid(ns)(clk, rRespDelay, resp, nameTemplate = respPipeNameTemplate) + + // Make sure references to the read data get appropriately substituted + val oldRDataName = if (rMap.contains(r)) "rdata" else "data" + exprReplacements(we(memPortField(mem, r, oldRDataName))) = respPiped.payload.head + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls ++ respDecls, cmdConns ++ cmdPortConns ++ respConns) } - def readPortConnects(reader: String, - clk: Expression, - en: Expression, - addr: Expression) = Seq( - Connect(NoInfo, memPortField(mem, reader, "clk"), clk), - // connect latency pipes to read ports - Connect(NoInfo, memPortField(mem, reader, "en"), en), - Connect(NoInfo, memPortField(mem, reader, "addr"), addr) - ) - def writePortConnects(writer: String, - clk: Expression, - en: Expression, - mask: Expression, - addr: Expression, - data: Expression) = Seq( - Connect(NoInfo, memPortField(mem, writer, "clk"), clk), - // connect latency pipes to write ports - Connect(NoInfo, memPortField(mem, writer, "en"), en), - Connect(NoInfo, memPortField(mem, writer, "mask"), mask), - Connect(NoInfo, memPortField(mem, writer, "addr"), addr), - Connect(NoInfo, memPortField(mem, writer, "data"), data) - ) - - stmts ++= ((sx.readers flatMap {reader => - val clk = netlist(memPortField(sx, reader, "clk")) - if (sx.readUnderWrite == ReadUnderWrite.Old) { - // For a read-first ("old") mem, read data gets delayed, so don't delay read address/en - val rdata = memPortField(sx, reader, "data") - val enDriver = netlist(memPortField(sx, reader, "en")) - val addrDriver = netlist(memPortField(sx, reader, "addr")) - readPortConnects(reader, clk, enDriver, addrDriver) - } else { - // For a write-first ("new") or undefined mem, delay read control inputs - val (en, ss1) = prependPipe(memPortField(sx, reader, "en"), sx.readLatency - 1, clk, one) - val (addr, ss2) = prependPipe(memPortField(sx, reader, "addr"), sx.readLatency, clk, en) - ss1 ++ ss2 ++ readPortConnects(reader, clk, en, addr) - } - }) ++ (sx.writers flatMap {writer => - // generate latency pipes for write ports (enable, mask, addr, data) - val clk = netlist(memPortField(sx, writer, "clk")) - val (en, ss1) = prependPipe(memPortField(sx, writer, "en"), sx.writeLatency - 1, clk, one) - val (mask, ss2) = prependPipe(memPortField(sx, writer, "mask"), sx.writeLatency - 1, clk, one) - val (addr, ss3) = prependPipe(memPortField(sx, writer, "addr"), sx.writeLatency - 1, clk, one) - val (data, ss4) = prependPipe(memPortField(sx, writer, "data"), sx.writeLatency - 1, clk, one) - ss1 ++ ss2 ++ ss3 ++ ss4 ++ writePortConnects(writer, clk, en, mask, addr, data) - }) ++ (sx.readwriters flatMap {readwriter => - val (reader, writer) = rwMap(readwriter) - val clk = netlist(memPortField(sx, readwriter, "clk")) - // generate latency pipes for readwrite ports (enable, addr, wmode, wmask, wdata) - val (en, ss1) = prependPipe(memPortField(sx, readwriter, "en"), sx.readLatency - 1, clk, one) - val (wmode, ss2) = prependPipe(memPortField(sx, readwriter, "wmode"), sx.writeLatency - 1, clk, one) - val (wmask, ss3) = prependPipe(memPortField(sx, readwriter, "wmask"), sx.writeLatency - 1, clk, one) - val (wdata, ss4) = prependPipe(memPortField(sx, readwriter, "wdata"), sx.writeLatency - 1, clk, one) - val (waddr, ss5) = prependPipe(memPortField(sx, readwriter, "addr"), sx.writeLatency - 1, clk, one) - val stmts = ss1 ++ ss2 ++ ss3 ++ ss4 ++ ss5 ++ writePortConnects(writer, clk, AND(en, wmode), wmask, waddr, wdata) - if (sx.readUnderWrite == ReadUnderWrite.Old) { - // For a read-first ("old") mem, read data gets delayed, so don't delay read address/en - val enDriver = netlist(memPortField(sx, readwriter, "en")) - val addrDriver = netlist(memPortField(sx, readwriter, "addr")) - val wmodeDriver = netlist(memPortField(sx, readwriter, "wmode")) - stmts ++ readPortConnects(reader, clk, AND(enDriver, NOT(wmodeDriver)), addrDriver) + + val writeStmts = (mem.writers ++ mem.readwriters).map { case w => + def oldDriver(f: String) = netlist(we(memPortField(mem, w, f))) + def newField(f: String) = memPortField(mem, wMap.getOrElse(w, w), f) + val clk = oldDriver("clk") + + // Pack sources of write command inputs into WithValid object -> different for readwriter + val cmdSrc = if (wMap.contains(w)) { + val en = AND(oldDriver("en"), oldDriver("wmode")) + WithValid(en, Seq(oldDriver("addr"), oldDriver("wmask"), oldDriver("wdata"))) } else { - // For a write-first ("new") or undefined mem, delay read control inputs - val (raddr, raddrPipeStmts) = prependPipe(memPortField(sx, readwriter, "addr"), sx.readLatency, clk, AND(en, NOT(wmode))) - repl(memPortField(sx, readwriter, "rdata")) = memPortField(mem, reader, "data") - stmts ++ raddrPipeStmts ++ readPortConnects(reader, clk, en, raddr) + WithValid(oldDriver("en"), Seq(oldDriver("addr"), oldDriver("mask"), oldDriver("data"))) } - })) - - def pipeReadData(p: String): Seq[Statement] = { - val newName = rwMap.get(p).map(_._1).getOrElse(p) // Name of final read port, whether renamed (rw port) or not - val rdataNew = memPortField(mem, newName, "data") - val rdataOld = rwMap.get(p).map(rw => memPortField(sx, p, "rdata")).getOrElse(rdataNew) - val clk = netlist(rdataOld.copy(name = "clk")) - val (rdataPipe, rdataPipeStmts) = DelayPipe(namespace)(rdataNew, sx.readLatency, clk) // TODO: use enable - repl(rdataOld) = rdataPipe - rdataPipeStmts - } - // We actually pipe the read data here; this groups it with the mem declaration to keep declarations early - if (sx.readUnderWrite == ReadUnderWrite.Old) { - Block(mem +: (sx.readers ++ sx.readwriters).flatMap(pipeReadData(_))) - } else { - mem + // Pipeline write command, connect to memory + val cmdSink = WithValid(newField("en"), Seq(newField("addr"), newField("mask"), newField("data"))) + val (cmdPiped, cmdDecls, cmdConns) = pipelineWithValid(ns)(clk, wCmdDelay, cmdSrc, nameTemplate = Some(cmdSink)) + val cmdPortConns = connect(cmdSink, cmdPiped) :+ connect(newField("clk"), clk) + + // Return all statements; they're separated so connects can go after all declarations + SplitStatements(cmdDecls, cmdConns ++ cmdPortConns) } - case sx: Connect if kind(sx.loc) == MemKind => EmptyStmt - case sx => sx map replaceExp(repl) - } - def replaceExp(repl: Netlist)(e: Expression): Expression = e match { - case ex: WSubField => repl get ex match { - case Some(exx) => exx - case None => ex - } - case ex => ex map replaceExp(repl) + newConns ++= (readStmts ++ writeStmts).flatMap(_.conns) + val newReaders = mem.readers ++ mem.readwriters.map(rMap(_)) + val newWriters = mem.writers ++ mem.readwriters.map(wMap(_)) + val newMem = DefMemory(mem.info, mem.name, mem.dataType, mem.depth, 1, 0, newReaders, newWriters, Nil) + Block(newMem +: (readStmts ++ writeStmts).flatMap(_.decls)) + case sx: Connect if kind(sx.loc) == MemKind => EmptyStmt // Filter old mem connections + case sx => sx.map(swapMemRefs) } - def appendStmts(sx: Seq[Statement])(s: Statement): Statement = Block(s +: sx) - - def memDelayMod(m: DefModule): DefModule = { - val netlist = new Netlist - val namespace = Namespace(m) - val repl = new Netlist - val extraStmts = mutable.ArrayBuffer.empty[Statement] - m.foreach(buildNetlist(netlist)) - m.map(memDelayStmt(netlist, namespace, repl, extraStmts)) - .map(appendStmts(extraStmts)) + val transformed = m match { + case mod: Module => + findMemConns(mod.body) + mod.copy(body = Block(transform(mod.body) +: newConns.toSeq)) + case mod => mod } +} - def run(c: Circuit): Circuit = - c copy (modules = c.modules map memDelayMod) +object VerilogMemDelays extends Pass { + def transform(m: DefModule): DefModule = (new MemDelayAndReadwriteTransformer(m)).transformed + def run(c: Circuit): Circuit = c.copy(modules = c.modules.map(transform)) } diff --git a/src/test/scala/firrtlTests/MemEnFeedbackSpec.scala b/src/test/scala/firrtlTests/MemEnFeedbackSpec.scala new file mode 100644 index 00000000..d94d199a --- /dev/null +++ b/src/test/scala/firrtlTests/MemEnFeedbackSpec.scala @@ -0,0 +1,41 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ + +// Tests long-standing bug from #1179, VerilogMemDelays producing combinational loops in corner case +abstract class MemEnFeedbackSpec extends FirrtlFlatSpec { + val ruw: String + def input: String = + s"""circuit loop : + | module loop : + | input clk : Clock + | input raddr : UInt<5> + | mem m : + | data-type => UInt<1> + | depth => 32 + | reader => r + | read-latency => 1 + | write-latency => 1 + | read-under-write => ${ruw} + | m.r.clk <= clk + | m.r.addr <= raddr + | m.r.en <= m.r.data + |""".stripMargin + def compileInput(): Unit = (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) +} + +class WriteFirstMemEnFeedbackSpec extends MemEnFeedbackSpec { + val ruw = "new" + "A write-first sync-read mem with feedback from data to enable" should "compile without errors" in { + compileInput() + } +} + +class ReadFirstMemEnFeedbackSpec extends MemEnFeedbackSpec { + val ruw = "old" + "A read-first sync-read mem with feedback from data to enable" should "compile without errors" in { + compileInput() + } +} diff --git a/src/test/scala/firrtlTests/MemLatencySpec.scala b/src/test/scala/firrtlTests/MemLatencySpec.scala new file mode 100644 index 00000000..79986cc2 --- /dev/null +++ b/src/test/scala/firrtlTests/MemLatencySpec.scala @@ -0,0 +1,130 @@ +package firrtlTests + +import firrtlTests.execution._ + +object MemLatencySpec { + case class Write(addr: Int, data: Int, mask: Option[Boolean] = None) + case class Read(addr: Int, expectedValue: Int) + case class MemAccess(w: Option[Write], r: Option[Read]) + def writeOnly(addr: Int, data: Int) = MemAccess(Some(Write(addr, data)), None) + def readOnly(addr: Int, expectedValue: Int) = MemAccess(None, Some(Read(addr, expectedValue))) +} + +abstract class MemLatencySpec(rLatency: Int, wLatency: Int, ruw: String) + extends SimpleExecutionTest + with VerilogExecution { + + import MemLatencySpec._ + + require(rLatency >= 0, s"Illegal read-latency ${rLatency} supplied to MemLatencySpec") + require(wLatency > 0, s"Illegal write-latency ${wLatency} supplied to MemLatencySpec") + + val body = + s"""mem m : + | data-type => UInt<32> + | depth => 256 + | reader => r + | writer => w + | read-latency => ${rLatency} + | write-latency => ${wLatency} + | read-under-write => ${ruw} + |m.r.clk <= clock + |m.w.clk <= clock + |""".stripMargin + + val memAccesses: Seq[MemAccess] + + def mask2Poke(m: Option[Boolean]) = m match { + case Some(false) => Poke("m.w.mask", 0) + case _ => Poke("m.w.mask", 1) + } + + def wPokes = memAccesses.map { + case MemAccess(Some(Write(a, d, m)), _) => + Seq(Poke("m.w.en", 1), Poke("m.w.addr", a), Poke("m.w.data", d), mask2Poke(m)) + case _ => Seq(Poke("m.w.en", 0), Invalidate("m.w.addr"), Invalidate("m.w.data")) + } + + def rPokes = memAccesses.map { + case MemAccess(_, Some(Read(a, _))) => Seq(Poke("m.r.en", 1), Poke("m.r.addr", a)) + case _ => Seq(Poke("m.r.en", 0), Invalidate("m.r.addr")) + } + + // Need to idle for <rLatency> cycles at the end + val idle = Seq(Poke("m.w.en", 0), Poke("m.r.en", 0)) + def pokes = (wPokes zip rPokes).map { case (wp, rp) => wp ++ rp } ++ Seq.fill(rLatency)(idle) + + // Need to delay read value expects by <rLatency> + def expects = Seq.fill(rLatency)(Seq(Step(1))) ++ memAccesses.map { + case MemAccess(_, Some(Read(_, expected))) => Seq(Expect("m.r.data", expected), Step(1)) + case _ => Seq(Step(1)) + } + + def commands: Seq[SimpleTestCommand] = (pokes zip expects).flatMap { case (p, e) => p ++ e } +} + +trait ToggleMaskAndEnable { + import MemLatencySpec._ + /** + * A canonical sequence of memory accesses for sanity checking memories of different latencies. + * The shortest true "RAW" hazard is reading address 14 two accesses after writing it. Since this + * access assumed the new value of 87, this means that the access pattern is only valid for + * certain combinations of read- and write-latencies that vary between read- and write-first + * memories. + * + * @note Read-first mems should return expected values for (write-latency <= 2) + * @note Write-first mems should return expected values for (write-latency <= read-latency + 2) + */ + val memAccesses: Seq[MemAccess] = Seq( + MemAccess(Some(Write(6, 32)), None), + MemAccess(Some(Write(14, 87)), None), + MemAccess(None, None), + MemAccess(Some(Write(19, 63)), Some(Read(14, 87))), + MemAccess(Some(Write(22, 49)), None), + MemAccess(Some(Write(11, 99)), Some(Read(6, 32))), + MemAccess(Some(Write(42, 42)), None), + MemAccess(Some(Write(77, 81)), None), + MemAccess(Some(Write(6, 7)), Some(Read(19, 63))), + MemAccess(Some(Write(39, 5)), Some(Read(42, 42))), + MemAccess(Some(Write(39, 6, Some(false))), Some(Read(77, 81))), // set mask to zero, should not write + MemAccess(None, Some(Read(6, 7))), // also read a twice-written address + MemAccess(None, Some(Read(39, 5))) // ensure masked writes didn't happen + ) +} + +/* + * This framework is for execution tests, so these tests all focus on + * *legal* configurations. Illegal memory parameters that should + * result in errors should be tested in MemSpec. + */ + +// These two are the same in practice, but the two tests could help expose bugs in VerilogMemDelays +class CombMemSpecNewRUW extends MemLatencySpec(rLatency = 0, wLatency = 1, ruw = "new") with ToggleMaskAndEnable +class CombMemSpecOldRUW extends MemLatencySpec(rLatency = 0, wLatency = 1, ruw = "old") with ToggleMaskAndEnable + +// Odd combination: combinational read with 2-cycle write latency +class CombMemWL2SpecNewRUW extends MemLatencySpec(rLatency = 0, wLatency = 2, ruw = "new") with ToggleMaskAndEnable +class CombMemWL2SpecOldRUW extends MemLatencySpec(rLatency = 0, wLatency = 2, ruw = "old") with ToggleMaskAndEnable + +// Standard sync read mem +class WriteFirstMemToggleSpec extends MemLatencySpec(rLatency = 1, wLatency = 1, ruw = "new") with ToggleMaskAndEnable +class ReadFirstMemToggleSpec extends MemLatencySpec(rLatency = 1, wLatency = 1, ruw = "old") with ToggleMaskAndEnable + +// Read latency 2 +class WriteFirstMemToggleSpecRL2 extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "new") with ToggleMaskAndEnable +class ReadFirstMemToggleSpecRL2 extends MemLatencySpec(rLatency = 2, wLatency = 1, ruw = "old") with ToggleMaskAndEnable + +// Write latency 2 +class WriteFirstMemToggleSpecWL2 extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "new") with ToggleMaskAndEnable +class ReadFirstMemToggleSpecWL2 extends MemLatencySpec(rLatency = 1, wLatency = 2, ruw = "old") with ToggleMaskAndEnable + +// Read latency 2, write latency 2 +class WriteFirstMemToggleSpecRL2WL2 extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "new") with ToggleMaskAndEnable +class ReadFirstMemToggleSpecRL2WL2 extends MemLatencySpec(rLatency = 2, wLatency = 2, ruw = "old") with ToggleMaskAndEnable + +// Read latency 3, write latency 2 +class WriteFirstMemToggleSpecRL3WL2 extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "new") with ToggleMaskAndEnable +class ReadFirstMemToggleSpecRL3WL2 extends MemLatencySpec(rLatency = 3, wLatency = 2, ruw = "old") with ToggleMaskAndEnable + +// Read latency 2, write latency 4 -> ToggleSpec pattern only valid for write-first at this combo +class WriteFirstMemToggleSpecRL2WL4 extends MemLatencySpec(rLatency = 2, wLatency = 4, ruw = "new") with ToggleMaskAndEnable diff --git a/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala b/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala new file mode 100644 index 00000000..7d250664 --- /dev/null +++ b/src/test/scala/firrtlTests/execution/ExecutionTestHelper.scala @@ -0,0 +1,112 @@ +package firrtlTests.execution + +import firrtl._ +import firrtl.ir._ + +object DUTRules { + val dutName = "dut" + val clock = Reference("clock", ClockType) + val reset = Reference("reset", Utils.BoolType) + val counter = Reference("step", UnknownType) + + // Need a flat name for the register that latches poke values + val illegal = raw"[\[\]\.]".r + val pokeRegSuffix = "_poke" + def pokeRegName(e: Expression) = illegal.replaceAllIn(e.serialize, "_") + pokeRegSuffix + + // Naming patterns are static, so DUT has to be checked for proper form + collisions + def hasNameConflicts(c: Circuit): Boolean = { + val top = c.modules.find(_.name == c.main).get + val names = Namespace(top).cloneUnderlying + names.contains(counter.name) || names.exists(_.contains(pokeRegSuffix)) + } +} + +object ExecutionTestHelper { + val counterType = UIntType(IntWidth(32)) + def apply(body: String): ExecutionTestHelper = { + // Parse input and check that it complies with test syntax rules + val c = ParseStatement.makeDUT(body) + require(!DUTRules.hasNameConflicts(c), "Avoid using 'step' or 'poke' in DUT component names") + + // Generate test step counter, create ExecutionTestHelper that represents initial test state + val cnt = DefRegister(NoInfo, DUTRules.counter.name, counterType, DUTRules.clock, DUTRules.reset, Utils.zero) + val inc = Connect(NoInfo, DUTRules.counter, DoPrim(PrimOps.Add, Seq(DUTRules.counter, UIntLiteral(1)), Nil, UnknownType)) + ExecutionTestHelper(c, Seq(cnt, inc), Map.empty[Expression, Expression], Nil, Nil) + } +} + +case class ExecutionTestHelper( + dut: Circuit, + setup: Seq[Statement], + pokeRegs: Map[Expression, Expression], + completedSteps: Seq[Conditionally], + activeStep: Seq[Statement] +) { + + def step(n: Int): ExecutionTestHelper = { + require(n > 0, "Step length must be positive") + (0 until n).foldLeft(this) { case (eth, int) => eth.next } + } + + def poke(expString: String, value: Literal): ExecutionTestHelper = { + val pokeExp = ParseExpression(expString) + val pokeable = ensurePokeable(pokeExp) + pokeable.addStatements( + Connect(NoInfo, pokeExp, value), + Connect(NoInfo, pokeable.pokeRegs(pokeExp), value)) + } + + def invalidate(expString: String): ExecutionTestHelper = { + addStatements(IsInvalid(NoInfo, ParseExpression(expString))) + } + + def expect(expString: String, value: Literal): ExecutionTestHelper = { + val peekExp = ParseExpression(expString) + val neq = DoPrim(PrimOps.Neq, Seq(peekExp, value), Nil, Utils.BoolType) + addStatements(Stop(NoInfo, 1, DUTRules.clock, neq)) + } + + def finish(): ExecutionTestHelper = { + addStatements(Stop(NoInfo, 0, DUTRules.clock, Utils.one)).next + } + + // Private helper methods + + private def t = completedSteps.length + + private def addStatements(stmts: Statement*) = copy(activeStep = activeStep ++ stmts) + + private def next: ExecutionTestHelper = { + val count = Reference(DUTRules.counter.name, DUTRules.counter.tpe) + val ifStep = DoPrim(PrimOps.Eq, Seq(count, UIntLiteral(t)), Nil, Utils.BoolType) + val onThisStep = Conditionally(NoInfo, ifStep, Block(activeStep), EmptyStmt) + copy(completedSteps = completedSteps :+ onThisStep, activeStep = Nil) + } + + private def top: Module = { + dut.modules.collectFirst({ case m: Module if m.name == dut.main => m }).get + } + + private[execution] def emit: Circuit = { + val finished = finish() + val modulesX = dut.modules.collect { + case m: Module if m.name == dut.main => + m.copy(body = Block(m.body +: (setup ++ finished.completedSteps))) + case m => m + } + dut.copy(modules = modulesX) + } + + private def ensurePokeable(pokeExp: Expression): ExecutionTestHelper = { + if (pokeRegs.contains(pokeExp)) { + this + } else { + val pName = DUTRules.pokeRegName(pokeExp) + val pRef = Reference(pName, UnknownType) + val pReg = DefRegister(NoInfo, pName, UIntType(UnknownWidth), DUTRules.clock, Utils.zero, pRef) + val defaultConn = Connect(NoInfo, pokeExp, pRef) + copy(setup = setup ++ Seq(pReg, defaultConn), pokeRegs = pokeRegs + (pokeExp -> pRef)) + } + } +} diff --git a/src/test/scala/firrtlTests/execution/ParserHelpers.scala b/src/test/scala/firrtlTests/execution/ParserHelpers.scala new file mode 100644 index 00000000..3472c19c --- /dev/null +++ b/src/test/scala/firrtlTests/execution/ParserHelpers.scala @@ -0,0 +1,52 @@ +package firrtlTests.execution + +import firrtl._ +import firrtl.ir._ + +class ParserHelperException(val pe: ParserException, input: String) + extends FirrtlUserException(s"Got error ${pe.toString} while parsing input:\n${input}") + +/** + * A utility class that parses a FIRRTL string representing a statement to a sub-AST + */ +object ParseStatement { + private def wrapStmtStr(stmtStr: String): String = { + val indent = " " + val indented = stmtStr.split("\n").mkString(indent, s"\n${indent}", "") + s"""circuit ${DUTRules.dutName} : + | module ${DUTRules.dutName} : + | input clock : Clock + | input reset : UInt<1> + |${indented}""".stripMargin + } + + private def parse(stmtStr: String): Circuit = { + try { + Parser.parseString(wrapStmtStr(stmtStr), Parser.IgnoreInfo) + } catch { + case e: ParserException => throw new ParserHelperException(e, stmtStr) + } + } + + def apply(stmtStr: String): Statement = { + val c = parse(stmtStr) + val stmt = c.modules.collectFirst { case Module(_, _, _, b: Block) => b.stmts.head } + stmt.get + } + + private[execution] def makeDUT(body: String): Circuit = parse(body) +} + +/** + * A utility class that parses a FIRRTL string representing an expression to a sub-AST + */ +object ParseExpression { + def apply(expStr: String): Expression = { + try { + val s = ParseStatement(s"${expStr} is invalid") + s.asInstanceOf[IsInvalid].expr + } catch { + case e: ParserHelperException => throw new ParserHelperException(e.pe, expStr) + } + } +} diff --git a/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala b/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala new file mode 100644 index 00000000..5abeb819 --- /dev/null +++ b/src/test/scala/firrtlTests/execution/SimpleExecutionTest.scala @@ -0,0 +1,84 @@ +package firrtlTests.execution + +import java.io.File + +import firrtl.ir._ +import firrtlTests._ + +sealed trait SimpleTestCommand +case class Step(n: Int) extends SimpleTestCommand +case class Invalidate(expStr: String) extends SimpleTestCommand +case class Poke(expStr: String, value: Int) extends SimpleTestCommand +case class Expect(expStr: String, value: Int) extends SimpleTestCommand + +/** + * This trait defines an interface to run a self-contained test circuit. + */ +trait TestExecution { + def runEmittedDUT(c: Circuit, testDir: File): Unit +} + +/** + * A class that makes it easier to write execution-driven tests. + * + * By combining a DUT body (supplied as a string without an enclosing + * module or circuit) with a sequence of test operations, an + * executable, self-contained Verilog testbench may be automatically + * created and checked. + * + * @note It is necessary to mix in a trait extending TestExecution + * @note The DUT has two implicit ports, "clock" and "reset" + * @note Execution of the command sequences begins after reset is deasserted + * + * @see [[firrtlTests.execution.TestExecution]] + * @see [[firrtlTests.execution.VerilogExecution]] + * + * @example {{{ + * class AndTester extends SimpleExecutionTest with VerilogExecution { + * val body = "reg r : UInt<32>, clock with: (reset => (reset, UInt<32>(0)))" + * val commands = Seq( + * Expect("r", 0), + * Poke("r", 3), + * Step(1), + * Expect("r", 3) + * ) + * } + * }}} + */ +abstract class SimpleExecutionTest extends FirrtlPropSpec { + this: TestExecution => + + /** + * Text representing the body of the DUT. This is useful for testing + * statement-level language features, and cuts out the overhead of + * writing a top-level DUT module and having peeks/pokes point at + * IOs. + */ + val body: String + + /** + * A sequence of commands (peeks, pokes, invalidates, steps) that + * represents how the testbench will progress. The semantics are + * inspired by chisel-testers. + */ + def commands: Seq[SimpleTestCommand] + + private def interpretCommand(eth: ExecutionTestHelper, cmd: SimpleTestCommand) = cmd match { + case Step(n) => eth.step(n) + case Invalidate(expStr) => eth.invalidate(expStr) + case Poke(expStr, value) => eth.poke(expStr, UIntLiteral(value)) + case Expect(expStr, value) => eth.expect(expStr, UIntLiteral(value)) + } + + private def runTest(): Unit = { + val initial = ExecutionTestHelper(body) + val test = commands.foldLeft(initial)(interpretCommand(_, _)) + val testName = this.getClass.getSimpleName + val testDir = createTestDirectory(s"${testName}-generated-src") + runEmittedDUT(test.emit, testDir) + } + + property("Execution of the compiled Verilog for ExecutionTestHelper should succeed") { + runTest() + } +} diff --git a/src/test/scala/firrtlTests/execution/VerilogExecution.scala b/src/test/scala/firrtlTests/execution/VerilogExecution.scala new file mode 100644 index 00000000..17eecc65 --- /dev/null +++ b/src/test/scala/firrtlTests/execution/VerilogExecution.scala @@ -0,0 +1,32 @@ +package firrtlTests.execution + +import java.io.File + +import firrtl._ +import firrtl.ir._ +import firrtlTests._ + +import firrtl.stage.{FirrtlCircuitAnnotation, FirrtlStage} +import firrtl.options.TargetDirAnnotation + +/** + * Mixing in this trait causes a SimpleExecutionTest to be run in Verilog simulation. + */ +trait VerilogExecution extends TestExecution { + this: SimpleExecutionTest => + def runEmittedDUT(c: Circuit, testDir: File): Unit = { + // Run FIRRTL, emit Verilog file + val cAnno = FirrtlCircuitAnnotation(c) + val tdAnno = TargetDirAnnotation(testDir.getAbsolutePath) + (new FirrtlStage).run(AnnotationSeq(Seq(cAnno, tdAnno))) + + // Copy harness resource to test directory + val harness = new File(testDir, s"top.cpp") + copyResourceToFile(cppHarnessResourceName, harness) + + // Make and run Verilog simulation + verilogToCpp(c.main, testDir, Nil, harness).! + cppToExe(c.main, testDir).! + assert(executeExpectingSuccess(c.main, testDir)) + } +} |
