diff options
| author | Kevin Laeufer | 2021-03-04 11:23:51 -0800 |
|---|---|---|
| committer | GitHub | 2021-03-04 19:23:51 +0000 |
| commit | c93d6f5319efd7ba42147180c6e2b6f3796ef943 (patch) | |
| tree | cea5c960c6fd15c1680f43d78fa06a69dda7dc6e /src | |
| parent | e58ba0c12e5d650983c70a61a45542f0cd43fb88 (diff) | |
SMT Backend: move undefined memory behavior modelling to firrtl IR level (#2095)
With this PR the smt backend now supports memories
with more than two write ports and the conservative
memory modelling can be selectively turned off with
a new annotation.
Diffstat (limited to 'src')
12 files changed, 1060 insertions, 217 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 72884e25..3d0f19b8 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -900,6 +900,79 @@ object Utils extends LazyLogging { def maskBigInt(value: BigInt, width: Int): BigInt = { value & ((BigInt(1) << width) - 1) } + + /** Returns true iff the expression is a Literal or a Literal cast to a different type. */ + def isLiteral(e: Expression): Boolean = e match { + case _: Literal => true + case DoPrim(op, args, _, _) if isCast(op) => args.exists(isLiteral) + case _ => false + } + + /** Applies the firrtl And primop. Automatically constant propagates when one of the expressions is True or False. */ + def and(e1: Expression, e2: Expression): Expression = { + assert(e1.tpe == e2.tpe) + (e1, e2) match { + case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width) + case (True(), b) => b + case (a, True()) => a + case (False(), _) => False() + case (_, False()) => False() + case (a, b) if a == b => a + case (a, b) => DoPrim(PrimOps.And, Seq(a, b), Nil, BoolType) + } + } + + /** Applies the firrtl Eq primop. */ + def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, BoolType) + + /** Applies the firrtl Or primop. Automatically constant propagates when one of the expressions is True or False. */ + def or(e1: Expression, e2: Expression): Expression = { + assert(e1.tpe == e2.tpe) + (e1, e2) match { + case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width) + case (True(), _) => True() + case (_, True()) => True() + case (False(), b) => b + case (a, False()) => a + case (a, b) if a == b => a + case (a, b) => DoPrim(PrimOps.Or, Seq(a, b), Nil, BoolType) + } + } + + /** Applies the firrtl Not primop. Automatically constant propagates when the expressions is True or False. */ + def not(e: Expression): Expression = e match { + case True() => False() + case False() => True() + case a => DoPrim(PrimOps.Not, Seq(a), Nil, BoolType) + } + + /** implies(e1, e2) = or(not(e1), e2). Automatically constant propagates when one of the expressions is True or False. */ + def implies(e1: Expression, e2: Expression): Expression = or(not(e1), e2) + + /** Builds a Mux expression with the correct type. */ + def mux(cond: Expression, tval: Expression, fval: Expression): Expression = { + require(tval.tpe == fval.tpe) + Mux(cond, tval, fval, tval.tpe) + } + + object True { + private val _True = UIntLiteral(1, IntWidth(1)) + + /** Matches `UInt<1>(1)` */ + def unapply(e: UIntLiteral): Boolean = e.value == 1 && e.width == _True.width + + /** Returns `UInt<1>(1)` */ + def apply(): UIntLiteral = _True + } + object False { + private val _False = UIntLiteral(0, IntWidth(1)) + + /** Matches `UInt<1>(0)` */ + def unapply(e: UIntLiteral): Boolean = e.value == 0 && e.width == _False.width + + /** Returns `UInt<1>(0)` */ + def apply(): UIntLiteral = _False + } } object MemoizedHash { diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 78536a36..e9dd95bc 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -13,6 +13,7 @@ trait Kind case object WireKind extends Kind case object PoisonKind extends Kind case object RegKind extends Kind +case object RandomKind extends Kind case object InstanceKind extends Kind case object PortKind extends Kind case object NodeKind extends Kind diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index c1857667..d3a1ed68 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -5,9 +5,12 @@ package firrtl.backends.experimental.smt import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} import FirrtlExpressionSemantics.getWidth +import firrtl.backends.experimental.smt.random._ import firrtl.graph.MutableDiGraph import firrtl.options.Dependency +import firrtl.passes.MemPortUtils.memPortField import firrtl.passes.PassException +import firrtl.passes.memlib.VerilogMemDelays import firrtl.stage.Forms import firrtl.stage.TransformManager.TransformDependency import firrtl.transforms.{DeadCodeElimination, PropagatePresetAnnotations} @@ -58,7 +61,8 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration { // TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops // TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare // precisely and PadWidths emits ill-typed firrtl. - override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm + override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++ + Seq(Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays)) override def invalidates(a: Transform): Boolean = false // since this pass only runs on the main module, inlining needs to happen before override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances]) @@ -169,8 +173,7 @@ private class ModuleToTransitionSystem extends LazyLogging { onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs) } // turn memories into state - val memoryEncoding = new MemoryEncoding(makeRandom, scan.namespace) - val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name))) + val memoryStatesAndOutputs = scan.memories.map(m => onMemory(m, scan.connects, memInit.get(m.name))) // replace pseudo assigns for memory outputs val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap val signalsWithMem = signals.map { s => @@ -185,7 +188,7 @@ private class ModuleToTransitionSystem extends LazyLogging { case _ => true } ) - val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1) + val states = regStates.toArray ++ memoryStatesAndOutputs.map(_._1) // generate comments from infos val comments = mutable.HashMap[String, String]() @@ -247,233 +250,116 @@ private class ModuleToTransitionSystem extends LazyLogging { } } - private val InfoSeparator = ", " - private val InfoPrefix = "@ " - private def serializeInfo(info: ir.Info): Option[String] = info match { - case ir.NoInfo => None - case f: ir.FileInfo => Some(f.escaped) - case m: ir.MultiInfo => - val infos = m.flatten - if (infos.isEmpty) { None } - else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } - } - - private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() - private def makeRandom(baseName: String, width: Int): BVExpr = { - // TODO: actually ensure that there cannot be any name clashes with other identifiers - val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) - val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get - val sym = BVSymbol(name, width) - randoms(name) = sym - sym - } -} - -private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Namespace) extends LazyLogging { type Connects = Iterable[(String, BVExpr)] - def onMemory( - defMem: ir.DefMemory, - connects: Connects, - initValue: Option[MemoryInitValue] - ): (Iterable[State], Connects) = { - // we can only work on appropriately lowered memories - assert( - defMem.dataType.isInstanceOf[ir.GroundType], - s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!" - ) - assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + private def onMemory(m: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (State, Connects) = { + checkMem(m) - // collect all memory meta-data in a custom class - val m = new MemInfo(defMem) + // map of inputs to the memory + val inputs = connects.filter(_._1.startsWith(m.name)).toMap - // find all connections related to this memory - val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap + // derive the type of the memory from the dataType and depth + val dataWidth = getWidth(m.dataType) + val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) + val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth) // there could be a constant init - val init = initValue.map(getInit(m, _)) - - // parse and check read and write ports - val writers = defMem.writers.map(w => new WritePort(m, w, inputs)) - val readers = defMem.readers.map(r => new ReadPort(m, r, inputs)) - - // derive next state from all write ports - assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.") - val next: ArrayExpr = if (writers.isEmpty) { m.sym } - else { - if (writers.length > 2) { - throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})") - } - val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) } - if (writers.length == 1) { validData } - else { - assert(writers.length == 2) - val conflict = writers.head.doesConflict(writers.last) - val conflictData = writers.head.makeRandomData("_write_write_collision") - val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData) - ArrayIte(conflict, conflictStore, validData) - } - } - val state = State(m.sym, init, Some(next)) + val init = initValue.map(getInit(m, indexWidth, dataWidth, _)) + init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth)) - // derive data signals from all read ports - assert(defMem.readLatency >= 0) - if (defMem.readLatency > 1) { - throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})") - } - val readPortSignals = if (defMem.readLatency == 0) { - readers.map { r => - // combinatorial read - if (defMem.readUnderWrite != ir.ReadUnderWrite.New) { - logger.warn( - s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." + - s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored." - ) - } - // since we do a combinatorial read, the "old" data is the current data - val data = r.read() - r.data.name -> data - } - } else { Seq() } - val readPortSignalsAndStates = if (defMem.readLatency == 1) { - readers.map { r => - defMem.readUnderWrite match { - case ir.ReadUnderWrite.New => - // create a state to save the address and the enable signal - val enPrev = BVSymbol(namespace.newName(r.en.name + "_prev"), r.en.width) - val addrPrev = BVSymbol(namespace.newName(r.addr.name + "_prev"), r.addr.width) - val signal = r.data.name -> r.read(addr = addrPrev, en = enPrev) - val states = Seq(State(enPrev, None, next = Some(r.en)), State(addrPrev, None, next = Some(r.addr))) - (Seq(signal), states) - case ir.ReadUnderWrite.Undefined => - // check for potential read/write conflicts in which case we need to return an arbitrary value - val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r))) - val next = if (anyWriteToTheSameAddress == False) { r.read() } - else { - val readUnderWriteData = r.makeRandomData("_read_under_write_undefined") - BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.read()) - } - (Seq(), Seq(State(r.data, init = None, next = Some(next)))) - case ir.ReadUnderWrite.Old => - // we create a register for the read port data - (Seq(), Seq(State(r.data, init = None, next = Some(r.read())))) - } + // derive next state expression + val next = if (m.writers.isEmpty) { + memSymbol + } else { + m.writers.foldLeft[ArrayExpr](memSymbol) { + case (prev, write) => + // update + val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth) + val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth) + val update = ArrayStore(prev, index = addr, data = data) + + // update guard + val en = BVSymbol(memPortField(m, write, "en").serialize, 1) + val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1) + val alwaysEnabled = Seq(en, mask).forall(s => inputs(s.name) == True) + if (alwaysEnabled) { update } + else { + ArrayIte(and(en, mask), update, prev) + } } - } else { Seq() } + } - val allReadPortSignals = readPortSignals ++ readPortSignalsAndStates.flatMap(_._1) - val readPortStates = readPortSignalsAndStates.flatMap(_._2) + val state = State(memSymbol, init, Some(next)) - (state +: readPortStates, allReadPortSignals) - } + // derive read expressions + val readSignals = m.readers.map { read => + val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth) + memPortField(m, read, "data").serialize -> ArrayRead(memSymbol, addr) + } - private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match { - case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth) - case MemoryArrayInit(values) => - assert( - values.length == m.depth, - s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" - ) - // in order to get a more compact encoding try to find the most common values - val histogram = mutable.LinkedHashMap[BigInt, Int]() - values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) - val baseValue = histogram.maxBy(_._2)._1 - val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth) - values.zipWithIndex - .filterNot(_._1 == baseValue) - .foldLeft[ArrayExpr](base) { - case (array, (value, index)) => - ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth)) - } - case other => throw new RuntimeException(s"Unsupported memory init option: $other") + (state, readSignals) } - private class MemInfo(m: ir.DefMemory) { - val name = m.name - val depth = m.depth - // derrive the type of the memory from the dataType and depth - val dataWidth = getWidth(m.dataType) - val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) - val sym = ArraySymbol(m.name, indexWidth, dataWidth) - val prefix = m.name + "." - val fullAddressRange = (BigInt(1) << indexWidth) == m.depth - lazy val depthBV = BVLiteral(m.depth, indexWidth) - def isValidAddress(addr: BVExpr): BVExpr = { - if (fullAddressRange) { True } - else { - BVComparison(Compare.Greater, depthBV, addr, signed = false) - } - } - } - private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) { - val en: BVSymbol = makeField("en", 1) - val data: BVSymbol = makeField("data", memory.dataWidth) - val addr: BVSymbol = makeField("addr", memory.indexWidth) - protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width) - // make sure that all widths are correct - assert(inputs(en.name).width == en.width) - assert(inputs(addr.name).width == addr.width) - val enIsTrue: Boolean = inputs(en.name) == True - def makeRandomData(suffix: String): BVExpr = - makeRandom(memory.name + "_" + name + suffix, memory.dataWidth) - def read(addr: BVSymbol = addr, en: BVSymbol = en): BVExpr = { - val canBeOutOfRange = !memory.fullAddressRange - val canBeDisabled = !enIsTrue - val data = ArrayRead(memory.sym, addr) - val dataWithRangeCheck = if (canBeOutOfRange) { - val outOfRangeData = makeRandomData("_addr_out_of_range") - BVIte(memory.isValidAddress(addr), data, outOfRangeData) - } else { data } - val dataWithEnabledCheck = if (canBeDisabled) { - val disabledData = makeRandomData("_not_enabled") - BVIte(en, dataWithRangeCheck, disabledData) - } else { dataWithRangeCheck } - dataWithEnabledCheck - } - } - private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) { - assert(inputs(data.name).width == data.width) - val mask: BVSymbol = makeField("mask", 1) - assert(inputs(mask.name).width == mask.width) - val maskIsTrue: Boolean = inputs(mask.name) == True - val doWrite: BVExpr = (enIsTrue, maskIsTrue) match { - case (true, true) => True - case (true, false) => mask - case (false, true) => en - case (false, false) => and(en, mask) - } - def doesConflict(r: ReadPort): BVExpr = { - val sameAddress = BVEqual(r.addr, addr) - if (doWrite == True) { sameAddress } - else { and(doWrite, sameAddress) } - } - def doesConflict(w: WritePort): BVExpr = { - val bothWrite = and(doWrite, w.doWrite) - val sameAddress = BVEqual(addr, w.addr) - if (bothWrite == True) { sameAddress } - else { and(bothWrite, sameAddress) } - } - def writeTo(array: ArrayExpr): ArrayExpr = { - val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr)) - val update = ArrayStore(array, index = addr, data = data) - if (doUpdate == True) update else ArrayIte(doUpdate, update, array) + private def getInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr = + initValue match { + case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth) + case MemoryArrayInit(values) => + assert( + values.length == m.depth, + s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!" + ) + // in order to get a more compact encoding try to find the most common values + val histogram = mutable.LinkedHashMap[BigInt, Int]() + values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0)) + val baseValue = histogram.maxBy(_._2)._1 + val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth) + values.zipWithIndex + .filterNot(_._1 == baseValue) + .foldLeft[ArrayExpr](base) { + case (array, (value, index)) => + ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth)) + } + case other => throw new RuntimeException(s"Unsupported memory init option: $other") } - } - private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr) - extends MemPort(memory, name, inputs) {} - + // TODO: add to BV expression library private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match { case (True, True) => True case (True, x) => x case (x, True) => x case _ => BVOp(Op.And, a, b) } - private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b) + private val True = BVLiteral(1, 1) - private val False = BVLiteral(0, 1) - private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b)) - private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b)) + private def checkMem(m: ir.DefMemory): Unit = { + assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?") + assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?") + assert( + m.dataType.isInstanceOf[ir.GroundType], + s"Memory $m is of type ${m.dataType} which is not a ground type!" + ) + assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.") + } + + private val InfoSeparator = ", " + private val InfoPrefix = "@ " + private def serializeInfo(info: ir.Info): Option[String] = info match { + case ir.NoInfo => None + case f: ir.FileInfo => Some(f.escaped) + case m: ir.MultiInfo => + val infos = m.flatten + if (infos.isEmpty) { None } + else { Some(infos.map(_.escaped).mkString(InfoSeparator)) } + } + + private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]() + private def makeRandom(baseName: String, width: Int): BVExpr = { + // TODO: actually ensure that there cannot be any name clashes with other identifiers + val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii) + val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get + val sym = BVSymbol(name, width) + randoms(name) = sym + sym + } } // performas a first pass over the module collecting all connections, wires, registers, input and outputs @@ -526,6 +412,12 @@ private class ModuleScanner( } private[firrtl] def onStatement(s: ir.Statement): Unit = s match { + case DefRandom(info, name, tpe, _, _) => + namespace.newName(name) + assert(!isClock(tpe), "rand should never be a clock!") + // we model random sources as inputs and ignore the enable signal + infos.append(name -> info) + inputs.append(BVSymbol(name, getWidth(tpe))) case ir.DefWire(info, name, tpe) => namespace.newName(name) if (!isClock(tpe)) { diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala new file mode 100644 index 00000000..7381056e --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.random + +import firrtl.Utils +import firrtl.ir._ + +/** Named source of random values. If there is no clock expression, than it will be clocked by the global clock. */ +case class DefRandom( + info: Info, + name: String, + tpe: Type, + clock: Option[Expression], + en: Expression = Utils.True()) + extends Statement + with HasInfo + with IsDeclaration + with CanBeReferenced + with UseSerializer { + def mapStmt(f: Statement => Statement): Statement = this + def mapExpr(f: Expression => Expression): Statement = + DefRandom(info, name, tpe, clock.map(f), f(en)) + def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe)) + def mapString(f: String => String): Statement = this.copy(name = f(name)) + def mapInfo(f: Info => Info): Statement = this.copy(info = f(info)) + def foreachStmt(f: Statement => Unit): Unit = () + def foreachExpr(f: Expression => Unit): Unit = { clock.foreach(f); f(en) } + def foreachType(f: Type => Unit): Unit = f(tpe) + def foreachString(f: String => Unit): Unit = f(name) + def foreachInfo(f: Info => Unit): Unit = f(info) +} diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala new file mode 100644 index 00000000..91e77433 --- /dev/null +++ b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala @@ -0,0 +1,457 @@ +// SPDX-License-Identifier: Apache-2.0 + +package firrtl.backends.experimental.smt.random + +import firrtl.Utils.{isLiteral, kind, BoolType} +import firrtl.WrappedExpression.{we, weq} +import firrtl._ +import firrtl.annotations.NoTargetAnnotation +import firrtl.backends.experimental.smt._ +import firrtl.ir._ +import firrtl.options.Dependency +import firrtl.passes.MemPortUtils.memPortField +import firrtl.passes.memlib.AnalysisUtils.Connects +import firrtl.passes.memlib.InferReadWritePass.checkComplement +import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays} +import firrtl.stage.Forms +import firrtl.transforms.{ConstantPropagation, RemoveWires} + +import scala.collection.mutable + +/** Chooses which undefined memory behaviors should be instrumented. */ +case class UndefinedMemoryBehaviorOptions( + randomizeWriteWriteConflicts: Boolean = true, + assertNoOutOfBoundsWrites: Boolean = false, + randomizeOutOfBoundsRead: Boolean = true, + randomizeDisabledReads: Boolean = true, + randomizeReadWriteConflicts: Boolean = true) + extends NoTargetAnnotation + +/** Adds sources of randomness to model the various "undefined behaviors" of firrtl memory. + * - Write/Write conflict: leads to arbitrary value written to write address + * - Out-of-bounds write: assertion failure (disabled by default) + * - Out-Of-bounds read: leads to arbitrary value being read + * - Read w/ en=0: leads to arbitrary value being read + * - Read/Write conflict: leads to arbitrary value being read + */ +object UndefinedMemoryBehaviorPass extends Transform with DependencyAPIMigration { + override def prerequisites = Forms.LowForm + override def optionalPrerequisiteOf = Seq(Dependency(VerilogMemDelays)) + override def invalidates(a: Transform) = a match { + // this pass might destroy SSA form, as we add a wire for the data field of every read port + case _: RemoveWires => true + // TODO: should we add some optimization passes here? we could be generating some dead code. + case _ => false + } + + override protected def execute(state: CircuitState): CircuitState = { + val opts = state.annotations.collect { case o: UndefinedMemoryBehaviorOptions => o } + require(opts.size < 2, s"Multiple options: $opts") + val opt = opts.headOption.getOrElse(UndefinedMemoryBehaviorOptions()) + + val c = state.circuit.mapModule(onModule(_, opt)) + state.copy(circuit = c) + } + + private def onModule(m: DefModule, opt: UndefinedMemoryBehaviorOptions): DefModule = m match { + case mod: Module => + val mems = findMems(mod) + if (mems.isEmpty) { mod } + else { + val namespace = Namespace(mod) + val connects = AnalysisUtils.getConnects(mod) + new InstrumentMems(opt, mems, connects, namespace).run(mod) + } + case other => other + } + + /** finds all memory instantiations in a circuit */ + private def findMems(m: Module): List[DefMemory] = { + val mems = mutable.ListBuffer[DefMemory]() + m.foreachStmt(findMems(_, mems)) + mems.toList + } + private def findMems(s: Statement, mems: mutable.ListBuffer[DefMemory]): Unit = s match { + case mem: DefMemory => mems.append(mem) + case other => other.foreachStmt(findMems(_, mems)) + } +} + +private class InstrumentMems( + opt: UndefinedMemoryBehaviorOptions, + mems: List[DefMemory], + connects: Connects, + namespace: Namespace) { + def run(m: Module): DefModule = { + // ensure that all memories are the kind we can support + mems.foreach(checkSupported(m.name, _)) + + // transform circuit + val body = m.body.mapStmt(transform) + m.copy(body = Block(body +: newStmts.toList)) + } + + // used to replace memory signals like `m.r.data` in RHS expressions + private val exprReplacements = mutable.HashMap[String, Expression]() + // add new statements at the end of the circuit + private val newStmts = mutable.ListBuffer[Statement]() + // disconnect references so that they can be reassigned + private val doDisconnect = mutable.HashSet[String]() + + // generates new expression replacements and immediately uses them + private def transform(s: Statement): Statement = s.mapStmt(transform) match { + case mem: DefMemory => onMem(mem) + case sx: Connect if doDisconnect.contains(sx.loc.serialize) => EmptyStmt // Filter old mem connections + case sx => sx.mapExpr(swapMemRefs) + } + private def swapMemRefs(e: Expression): Expression = e.mapExpr(swapMemRefs) match { + case sf: RefLikeExpression => exprReplacements.getOrElse(sf.serialize, sf) + case ex => ex + } + + private def onMem(m: DefMemory): Statement = { + // collect wire and random statement defines + val declarations = mutable.ListBuffer[Statement]() + + // only for non power of 2 memories do we have to worry about reading or writing out of bounds + val canBeOutOfBounds = !isPow2(m.depth) + + // only if we have at least two write ports, can there be conflicts + val canHaveWriteWriteConflicts = m.writers.size > 1 + + // only certain memory types exhibit undefined read/write conflicts + val readWriteUndefined = (m.readLatency == m.writeLatency) && (m.readUnderWrite == ReadUnderWrite.Undefined) + assert( + m.readLatency == 0 || m.readLatency == m.writeLatency, + "TODO: what happens if a sync read mem has asymmetrical latencies?" + ) + + // a write port is enabled iff mask & en + val writeEn = m.writers.map { write => + val enRef = memPortField(m, write, "en") + val maskRef = memPortField(m, write, "mask") + + val prods = getProductTerms(enRef) ++ getProductTerms(maskRef) + + // if we can have write/write conflicts, we are going to change the mask and enable pins + val expr = if (canHaveWriteWriteConflicts) { + val maskIsOne = isTrue(connects(maskRef.serialize)) + // if the mask is connected to a constant true, we do not need to consider it, this is a common case + if (maskIsOne) { + val enWire = disconnectInput(m.info, enRef) + declarations += enWire + Reference(enWire) + } else { + val maskWire = disconnectInput(m.info, maskRef) + val enWire = disconnectInput(m.info, enRef) + // create a node for the conjunction + val nodeName = namespace.newName(s"${m.name}_${write}_mask_and_en") + val node = DefNode(m.info, nodeName, Utils.and(Reference(maskWire), Reference(enWire))) + declarations ++= List(maskWire, enWire, node) + Reference(node) + } + } else { + Utils.and(enRef, maskRef) + } + (expr, prods) + } + + // implement the three undefined read behaviors + m.readers.foreach { read => + // many memories have their read enable hard wired to true + val canBeDisabled = !isTrue(memPortField(m, read, "en")) + val readEn = if (canBeDisabled) memPortField(m, read, "en") else Utils.True() + val addr = memPortField(m, read, "addr") + + // collect signals that would lead to a randomization + var doRand = List[Expression]() + + // randomize the read value when the address is out of bounds + if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) { + val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr))) + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond) + declarations += node + doRand = Reference(node) +: doRand + } + + if (readWriteUndefined && opt.randomizeReadWriteConflicts) { + val (cond, d) = readWriteConflict(m, read, writeEn) + declarations ++= d + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond) + declarations += node + doRand = Reference(node) +: doRand + } + + // randomize the read value when the read is disabled + if (canBeDisabled && opt.randomizeDisabledReads) { + val cond = Utils.not(readEn) + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_disabled"), cond) + declarations += node + doRand = Reference(node) +: doRand + } + + // if there are no signals that would require a randomization, there is nothing to do + if (doRand.isEmpty) { + // nothing to do + } else { + val doRandName = s"${m.name}_${read}_do_rand" + val doRandNode = if (doRand.size == 1) { doRand.head } + else { + val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_do_rand"), doRand.reduce(Utils.or)) + declarations += node + Reference(node) + } + val doRandSignal = if (m.readLatency == 0) { doRandNode } + else { + val clock = memPortField(m, read, "clk") + val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency) + declarations ++= regDecls + signal + } + + // all old rhs references to m.r.data need to replace with m_r_data which might be random + val dataRef = memPortField(m, read, "data") + val dataWire = DefWire(m.info, namespace.newName(s"${m.name}_${read}_data"), m.dataType) + declarations += dataWire + exprReplacements(dataRef.serialize) = Reference(dataWire) + + // create a source of randomness and connect the new wire either to the actual data port or to the random value + val randName = namespace.newName(s"${m.name}_${read}_rand_data") + val random = DefRandom(m.info, randName, m.dataType, Some(memPortField(m, read, "clk")), doRandSignal) + declarations += random + val data = Utils.mux(doRandSignal, Reference(random), dataRef) + newStmts.append(Connect(m.info, Reference(dataWire), data)) + } + } + + // write + if (opt.randomizeWriteWriteConflicts) { + declarations ++= writeWriteConflicts(m, writeEn) + } + + // add an assertion that if the write is taking place, then the address must be in range + if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) { + m.writers.zip(writeEn).foreach { + case (write, (combinedEn, _)) => + val addr = memPortField(m, write, "addr") + val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr)) + val clk = memPortField(m, write, "clk") + val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read")) + newStmts.append(a) + } + } + + Block(m +: declarations.toList) + } + + private def pipeline( + info: Info, + clk: Expression, + prefix: String, + e: Expression, + latency: Int + ): (Expression, Seq[Statement]) = { + require(latency > 0) + val regs = (1 to latency).map { i => + val name = namespace.newName(prefix + s"_r$i") + DefRegister(info, name, e.tpe, clk, Utils.False(), Reference(name, e.tpe, RegKind, UnknownFlow)) + } + val expr = regs.foldLeft(e) { + case (prev, reg) => + newStmts.append(Connect(info, Reference(reg), prev)) + Reference(reg) + } + (expr, regs) + } + + private def readWriteConflict( + m: DefMemory, + read: String, + writeEn: Seq[(Expression, ProdTerms)] + ): (Expression, Seq[Statement]) = { + if (m.writers.isEmpty) return (Utils.False(), List()) + val declarations = mutable.ListBuffer[Statement]() + + val readEn = memPortField(m, read, "en") + val readProd = getProductTerms(readEn) + + // create all conflict signals + val conflicts = m.writers.zip(writeEn).map { + case (write, (writeEn, writeProd)) => + if (isMutuallyExclusive(readProd, writeProd)) { + Utils.False() + } else { + val name = namespace.newName(s"${m.name}_${read}_${write}_rwc") + val bothEn = Utils.and(readEn, writeEn) + val sameAddr = Utils.eq(memPortField(m, read, "addr"), memPortField(m, write, "addr")) + // we need a wire because this condition might be used in a random statement + val wire = DefWire(m.info, name, BoolType) + declarations += wire + newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr))) + Reference(wire) + } + } + + (conflicts.reduce(Utils.or), declarations.toList) + } + + private type ProdTerms = Seq[Expression] + private def writeWriteConflicts(m: DefMemory, writeEn: Seq[(Expression, ProdTerms)]): Seq[Statement] = { + if (m.writers.size < 2) return List() + val declarations = mutable.ListBuffer[Statement]() + + // we first create all conflict signals: + val conflict = + m.writers + .zip(writeEn) + .zipWithIndex + .flatMap { + case ((w1, (en1, en1Prod)), i1) => + m.writers.zip(writeEn).drop(i1 + 1).map { + case (w2, (en2, en2Prod)) => + if (isMutuallyExclusive(en1Prod, en2Prod)) { + (w1, w2) -> Utils.False() + } else { + val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc") + val bothEn = Utils.and(en1, en2) + val sameAddr = Utils.eq(memPortField(m, w1, "addr"), memPortField(m, w2, "addr")) + // we need a wire because this condition might be used in a random statement + val wire = DefWire(m.info, name, BoolType) + declarations += wire + newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr))) + (w1, w2) -> Reference(wire) + } + } + } + .toMap + + // now we calculate the new enable and data signals + m.writers.zip(writeEn).zipWithIndex.foreach { + case ((w1, (en1, _)), i1) => + val prev = m.writers.take(i1) + val next = m.writers.drop(i1 + 1) + + // the write is enabled if the original enable is true and there are no prior conflicts + val en = if (prev.isEmpty) { + en1 + } else { + val prevConflicts = prev.map(o => conflict(o, w1)).reduce(Utils.or) + Utils.and(en1, Utils.not(prevConflicts)) + } + + // we write random data if there is a conflict with any of the next ports + if (next.isEmpty) { + // nothing to do, leave data as is + } else { + val nextConflicts = next.map(n => conflict(w1, n)).reduce(Utils.or) + // if the conflict expression is more complex, create a node for the signal + val hasConflict = nextConflicts match { + case _: DoPrim | _: Mux => + val node = DefNode(m.info, namespace.newName(s"${m.name}_${w1}_wwc_active"), nextConflicts) + declarations += node + Reference(node) + case _ => nextConflicts + } + + // create the source of randomness + val name = namespace.newName(s"${m.name}_${w1}_wwc_data") + val random = DefRandom(m.info, name, m.dataType, Some(memPortField(m, w1, "clk")), hasConflict) + declarations.append(random) + // replace the old data input + val dataWire = disconnectInput(m.info, memPortField(m, w1, "data")) + declarations += dataWire + // generate new data input + val data = Utils.mux(hasConflict, Reference(random), Reference(dataWire)) + newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data)) + } + + // connect data enable signals + val maskIsOne = isTrue(connects(memPortField(m, w1, "mask").serialize)) + if (!maskIsOne) { + newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True())) + } + newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en)) + } + + declarations.toList + } + + /** check whether two signals can be proven to be mutually exclusive */ + private def isMutuallyExclusive(prodA: ProdTerms, prodB: ProdTerms): Boolean = { + // this uses the same approach as the InferReadWrite pass + val proofOfMutualExclusion = prodA.find(a => prodB.exists(b => checkComplement(a, b))) + proofOfMutualExclusion.nonEmpty + } + + /** replace a memory port with a wire */ + private def disconnectInput(info: Info, signal: RefLikeExpression): DefWire = { + // disconnect the old value + doDisconnect.add(signal.serialize) + + // if the old value is a literal, we just replace all references to it with this literal + val oldValue = connects(signal.serialize) + if (isLiteral(oldValue)) { + println("TODO: better code for literal") + } + + // create a new wire and replace all references to the original port with this wire + val wire = DefWire(info, copyName(signal), signal.tpe) + exprReplacements(signal.serialize) = Reference(wire) + // connect the old expression to the new wire + val con = Connect(info, Reference(wire), connects(signal.serialize)) + newStmts.append(con) + + // the wire definition should end up right after the memory definition + wire + } + + private def copyName(ref: RefLikeExpression): String = + namespace.newName(ref.serialize.replace('.', '_')) + + private def isInBounds(depth: BigInt, addr: Expression): Expression = { + val width = getWidth(addr) + // depth >= addr + DoPrim(PrimOps.Geq, List(UIntLiteral(depth, width), addr), List(), BoolType) + } + + private def isPow2(v: BigInt): Boolean = ((v - 1) & v) == 0 + + private def checkSupported(modName: String, m: DefMemory): Unit = { + assert(m.readwriters.isEmpty, s"[$modName] Combined read/write ports are currently not supported!") + if (m.writeLatency != 1) { + throw new UnsupportedFeatureException(s"[$modName] memories with write latency > 1 (${m.name})") + } + if (m.readLatency > 1) { + throw new UnsupportedFeatureException(s"[$modName] memories with read latency > 1 (${m.name})") + } + } + + private def getProductTerms(e: Expression): ProdTerms = + InferReadWritePass.getProductTerms(connects)(e) + + /** tries to expand the expression based on the connects we collected */ + private def expandExpr(e: Expression, fuel: Int): Expression = { + e match { + case m @ Mux(cond, tval, fval, _) => + m.copy(cond = expandExpr(cond, fuel), tval = expandExpr(tval, fuel), fval = expandExpr(fval, fuel)) + case p @ DoPrim(_, args, _, _) => + p.copy(args = args.map(expandExpr(_, fuel))) + case r: RefLikeExpression => + if (fuel > 0) { + connects.get(r.serialize) match { + case None => r + case Some(expr) => expandExpr(expr, fuel - 1) + } + } else { + r + } + case other => other + } + } + + private def isTrue(e: Expression): Boolean = simplifyExpr(expandExpr(e, fuel = 2)) == Utils.True() + + private def simplifyExpr(e: Expression): Expression = { + e // TODO: better simplification could improve the resulting circuit size + } +} diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 9c3d6186..13ba3d46 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -4,6 +4,7 @@ package firrtl package ir import Utils.{dec2string, trim} +import firrtl.backends.experimental.smt.random.DefRandom import dataclass.{data, since} import firrtl.constraint.{Constraint, IsKnown, IsVar} import org.apache.commons.text.translate.{AggregateTranslator, JavaUnicodeEscaper, LookupTranslator} @@ -242,6 +243,9 @@ object Reference { /** Creates a Reference from a Register */ def apply(reg: DefRegister): Reference = Reference(reg.name, reg.tpe, RegKind, UnknownFlow) + /** Creates a Reference from a Random Source */ + def apply(rnd: DefRandom): Reference = Reference(rnd.name, rnd.tpe, RandomKind, UnknownFlow) + /** Creates a Reference from a Node */ def apply(node: DefNode): Reference = Reference(node.name, node.value.tpe, NodeKind, SourceFlow) diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala index caea0a9c..dca902fe 100644 --- a/src/main/scala/firrtl/ir/Serializer.scala +++ b/src/main/scala/firrtl/ir/Serializer.scala @@ -2,6 +2,8 @@ package firrtl.ir +import firrtl.Utils +import firrtl.backends.experimental.smt.random.DefRandom import firrtl.constraint.Constraint object Serializer { @@ -114,6 +116,11 @@ object Serializer { case DefRegister(info, name, tpe, clock, reset, init) => b ++= "reg "; b ++= name; b ++= " : "; s(tpe); b ++= ", "; s(clock); b ++= " with :"; newLineAndIndent(1) b ++= "reset => ("; s(reset); b ++= ", "; s(init); b += ')'; s(info) + case DefRandom(info, name, tpe, clock, en) => + b ++= "rand "; b ++= name; b ++= " : "; s(tpe); + if (clock.isDefined) { b ++= ", "; s(clock.get); } + en match { case Utils.True() => case _ => b ++= " when "; s(en) } + s(info) case DefInstance(info, name, module, _) => b ++= "inst "; b ++= name; b ++= " of "; b ++= module; s(info) case DefMemory( info, diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala index e3218467..745be1e2 100644 --- a/src/main/scala/firrtl/passes/ResolveKinds.scala +++ b/src/main/scala/firrtl/passes/ResolveKinds.scala @@ -5,6 +5,7 @@ package firrtl.passes import firrtl._ import firrtl.ir._ import firrtl.Mappers._ +import firrtl.backends.experimental.smt.random.DefRandom import firrtl.traversals.Foreachers._ object ResolveKinds extends Pass { @@ -31,6 +32,7 @@ object ResolveKinds extends Pass { case sx: DefRegister => kinds(sx.name) = RegKind case sx: WDefInstance => kinds(sx.name) = InstanceKind case sx: DefMemory => kinds(sx.name) = MemKind + case sx: DefRandom => kinds(sx.name) = RandomKind case _ => } s.map(resolve_stmt(kinds)) diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala index 8fb2dc88..143b925a 100644 --- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala +++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala @@ -177,7 +177,8 @@ class MemDelayAndReadwriteTransformer(m: DefModule) { object VerilogMemDelays extends Pass { - override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf) + override def prerequisites = firrtl.stage.Forms.LowForm + override val optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf)) override val optionalPrerequisiteOf = Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter]) diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala index 1d9bfd0e..f72585d1 100644 --- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala @@ -11,6 +11,7 @@ import firrtl.analyses.InstanceKeyGraph import firrtl.Mappers._ import firrtl.Utils.{kind, throwInternalError} import firrtl.MemoizedHash._ +import firrtl.backends.experimental.smt.random.DefRandom import firrtl.options.{Dependency, RegisteredTransform, ShellOption} import collection.mutable @@ -126,6 +127,11 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend val node = LogicNode(mod.name, name) depGraph.addVertex(node) Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref)) + case DefRandom(_, name, _, clock, en) => + val node = LogicNode(mod.name, name) + depGraph.addVertex(node) + val inputs = clock ++: en +: Nil + inputs.flatMap(getDeps).foreach(ref => depGraph.addPairWithEdge(node, ref)) case DefNode(_, name, value) => val node = LogicNode(mod.name, name) depGraph.addVertex(node) @@ -225,6 +231,7 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend val tpe = decl match { case _: DefNode => "node" case _: DefRegister => "reg" + case _: DefRandom => "rand" case _: DefWire => "wire" case _: Port => "port" case _: DefMemory => "mem" diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index f2907db2..7500b386 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -11,6 +11,7 @@ import firrtl.WrappedExpression._ import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency import firrtl.Utils.getGroundZero +import firrtl.backends.experimental.smt.random.DefRandom import scala.collection.mutable import scala.util.{Failure, Success, Try} @@ -41,13 +42,13 @@ class RemoveWires extends Transform with DependencyAPIMigration { case _ => false } - // Extract all expressions that are references to a Node, Wire, or Reg + // Extract all expressions that are references to a Node, Wire, Reg or Rand // Since we are operating on LowForm, they can only be WRefs private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = { val refs = mutable.ArrayBuffer.empty[WRef] def rec(e: Expression): Expression = { e match { - case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref + case ref @ WRef(_, _, WireKind | NodeKind | RegKind | RandomKind, _) => refs += ref case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec) case _ => // Do nothing } @@ -59,8 +60,9 @@ class RemoveWires extends Transform with DependencyAPIMigration { // Transform netlist into DefNodes private def getOrderedNodes( - netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], - regInfo: mutable.Map[WrappedExpression, DefRegister] + netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)], + regInfo: mutable.Map[WrappedExpression, DefRegister], + randInfo: mutable.Map[WrappedExpression, DefRandom] ): Try[Seq[Statement]] = { val digraph = new MutableDiGraph[WrappedExpression] for ((sink, (exprs, _)) <- netlist) { @@ -80,7 +82,8 @@ class RemoveWires extends Transform with DependencyAPIMigration { ordered.map { key => val WRef(name, _, kind, _) = key.e1 kind match { - case RegKind => regInfo(key) + case RegKind => regInfo(key) + case RandomKind => randInfo(key) case WireKind | NodeKind => val (Seq(rhs), info) = netlist(key) DefNode(info, name, rhs) @@ -100,6 +103,8 @@ class RemoveWires extends Transform with DependencyAPIMigration { val wireInfo = mutable.HashMap.empty[WrappedExpression, Info] // Additional info about registers val regInfo = mutable.HashMap.empty[WrappedExpression, DefRegister] + // Additional info about rand statements + val randInfo = mutable.HashMap.empty[WrappedExpression, DefRandom] def onStmt(stmt: Statement): Statement = { stmt match { @@ -115,6 +120,9 @@ class RemoveWires extends Transform with DependencyAPIMigration { val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself regInfo(we(WRef(reg))) = reg netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info) + case rand: DefRandom => + randInfo(we(Reference(rand))) = rand + netlist(we(Reference(rand))) = (rand.clock ++: rand.en +: List(), rand.info) case decl: CanBeReferenced => // Keep all declarations except for nodes and non-Analog wires and "other" statements. // Thus this is expected to match DefInstance and DefMemory which both do not connect to @@ -148,7 +156,7 @@ class RemoveWires extends Transform with DependencyAPIMigration { m match { case mod @ Module(info, name, ports, body) => onStmt(body) - getOrderedNodes(netlist, regInfo) match { + getOrderedNodes(netlist, regInfo, randInfo) match { case Success(logic) => Module(info, name, ports, Block(List() ++ decls ++ logic ++ otherStmts)) // If we hit a CyclicException, just abort removing wires diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala new file mode 100644 index 00000000..c6f89b93 --- /dev/null +++ b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala @@ -0,0 +1,360 @@ +package firrtl.backends.experimental.smt.random + +import firrtl.options.Dependency +import firrtl.testutils.LeanTransformSpec + +class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(UndefinedMemoryBehaviorPass))) { + behavior.of("UndefinedMemoryBehaviorPass") + + it should "model write-write conflicts between 2 ports" in { + + val circuit = compile(UBMSources.writeWriteConflict, List()).circuit + // println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // a random value should be declared for the data written on a write-write conflict + assert(result.contains("rand m_a_wwc_data : UInt<32>, m.a.clk when m_a_b_wwc")) + + // a write-write conflict occurs when both ports are enabled and the addresses match + assert(result.contains("m_a_b_wwc <= and(and(m_a_en, m_b_en), eq(m.a.addr, m.b.addr))")) + + // the data of read port a depends on whether there is a write-write conflict + assert(result.contains("m.a.data <= mux(m_a_b_wwc, m_a_wwc_data, m_a_data)")) + + // the enable of read port b depends on whether there is a write-write conflict + assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))")) + } + + it should "model write-write conflicts between 3 ports" in { + + val circuit = compile(UBMSources.writeWriteConflict3, List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // when there is more than one next write port, a "active" node is created + assert(result.contains("node m_a_wwc_active = or(m_a_b_wwc, m_a_c_wwc)")) + + // a random value should be declared for the data written on a write-write conflict + assert(result.contains("rand m_a_wwc_data : UInt<32>, m.a.clk when m_a_wwc_active")) + assert(result.contains("rand m_b_wwc_data : UInt<32>, m.b.clk when m_b_c_wwc")) + + // a write-write conflict occurs when both ports are enabled and the addresses match + Seq(("a", "b"), ("a", "c"), ("b", "c")).foreach { + case (w1, w2) => + assert( + result.contains(s"m_${w1}_${w2}_wwc <= and(and(m_${w1}_en, m_${w2}_en), eq(m.${w1}.addr, m.${w2}.addr))") + ) + } + + // the data of read port a depends on whether there is a write-write conflict + assert(result.contains("m.a.data <= mux(m_a_wwc_active, m_a_wwc_data, m_a_data)")) + + // the data of read port b depends on whether there is a write-write conflict + assert(result.contains("m.b.data <= mux(m_b_c_wwc, m_b_wwc_data, m_b_data)")) + + // the enable of read port b depends on whether there is a write-write conflict + assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))")) + + // the enable of read port c depends on whether there is a write-write conflict + // note that in this case we do not add an extra node since the disjunction is only used once + assert(result.contains("m.c.en <= and(m_c_en, not(or(m_a_c_wwc, m_b_c_wwc)))")) + } + + it should "model write-write conflicts more efficiently when ports are mutually exclusive" in { + + val circuit = compile(UBMSources.writeWriteConflict3Exclusive, List()).circuit + // println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // we should not compute the conflict between a and c since it is impossible + assert(!result.contains("node m_a_c_wwc = and(and(m_a_en, m_c_en), eq(m.a.addr, m.c.addr))")) + + // the enable of port b depends on whether there is a conflict with a + assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))")) + + // the data of port b depends on whether these is a conflict with c + assert(result.contains("m.b.data <= mux(m_b_c_wwc, m_b_wwc_data, m_b_data)")) + + // the enable of port c only depend on whether there is a conflict with b since c and a cannot conflict + assert(result.contains("m.c.en <= and(m_c_en, not(m_b_c_wwc))")) + + // the data of port a only depends on whether there is a conflict with b since a and c cannot conflict + assert(result.contains("m.a.data <= mux(m_a_b_wwc, m_a_wwc_data, m_a_data)")) + } + + it should "assert out-of-bounds writes when told to" in { + val anno = List(UndefinedMemoryBehaviorOptions(assertNoOutOfBoundsWrites = true)) + + val circuit = compile(UBMSources.readWrite(30, 0), anno).circuit + // println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + assert( + result.contains( + """assert(m.a.clk, or(not(and(m.a.en, m.a.mask)), geq(UInt<5>("h1e"), m.a.addr)), UInt<1>("h1"), "out of bounds read")""" + ) + ) + } + + it should "model out-of-bounds reads" in { + val circuit = compile(UBMSources.readWrite(30, 0), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // an out of bounds read happens if the depth is not greater or equal to the address + assert(result.contains("node m_r_oob = not(geq(UInt<5>(\"h1e\"), m.r.addr))")) + + // the source of randomness needs to be triggered when there is an out of bounds read + assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_oob")) + + // the data is random when there is an oob + assert(result.contains("m_r_data <= mux(m_r_oob, m_r_rand_data, m.r.data)")) + } + + it should "model un-enabled reads w/o out-of-bounds" in { + // without possible out-of-bounds + val circuit = compile(UBMSources.readEnable(32), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // the memory is disabled when it is not enabled + assert(result.contains("node m_r_disabled = not(m.r.en)")) + + // the source of randomness needs to be triggered when there is an read while the port is disabled + assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_disabled")) + + // the data is random when there is an un-enabled read + assert(result.contains("m_r_data <= mux(m_r_disabled, m_r_rand_data, m.r.data)")) + } + + it should "model un-enabled reads with out-of-bounds" in { + // with possible out-of-bounds + val circuit = compile(UBMSources.readEnable(30), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // the memory is disabled when it is not enabled + assert(result.contains("node m_r_disabled = not(m.r.en)")) + + // an out of bounds read happens if the depth is not greater or equal to the address and the memory is enabled + assert(result.contains("node m_r_oob = and(m.r.en, not(geq(UInt<5>(\"h1e\"), m.r.addr)))")) + + // the two possible issues are combined into a single signal + assert(result.contains("node m_r_do_rand = or(m_r_disabled, m_r_oob)")) + + // the source of randomness needs to be triggered when either issue occurs + assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand")) + + // the data is random when either issue occurs + assert(result.contains("m_r_data <= mux(m_r_do_rand, m_r_rand_data, m.r.data)")) + } + + it should "model un-enabled reads with out-of-bounds with read pipelining" in { + // with read latency one, we need to pipeline the `do_rand` signal + val circuit = compile(UBMSources.readEnable(30, 1), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // pipeline register + assert(result.contains("m_r_do_rand_r1 <= m_r_do_rand")) + + // the source of randomness needs to be triggered by the pipeline register + assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand_r1")) + + // the data is random when the pipeline register is 1 + assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)")) + } + + it should "model read/write conflicts when they are undefined" in { + val circuit = compile(UBMSources.readWrite(32, 1), List()).circuit + //println(circuit.serialize) + val result = circuit.serialize.split('\n').map(_.trim) + + // detect read/write conflicts + assert(result.contains("m_r_a_rwc <= and(and(m.r.en, and(m.a.en, m.a.mask)), eq(m.r.addr, m.a.addr))")) + + // delay the signal + assert(result.contains("m_r_do_rand_r1 <= m_r_rwc")) + + // randomize the data + assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand_r1")) + assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)")) + } +} + +private object UBMSources { + + val writeWriteConflict = + s""" + |circuit Test: + | module Test: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | input aEn : UInt<1> + | input bEn : UInt<1> + | + | mem m: + | data-type => UInt<32> + | depth => 32 + | reader => r + | writer => a, b + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | + | ; both read ports write to the same address and the same data + | m.a.clk <= c + | m.a.en <= aEn + | m.a.addr <= addr + | m.a.data <= data + | m.a.mask <= UInt(1) + | m.b.clk <= c + | m.b.en <= bEn + | m.b.addr <= addr + | m.b.data <= data + | m.b.mask <= UInt(1) + """.stripMargin + + val writeWriteConflict3 = + s""" + |circuit Test: + | module Test: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | input aEn : UInt<1> + | input bEn : UInt<1> + | input cEn : UInt<1> + | + | mem m: + | data-type => UInt<32> + | depth => 32 + | reader => r + | writer => a, b, c + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | + | ; both read ports write to the same address and the same data + | m.a.clk <= c + | m.a.en <= aEn + | m.a.addr <= addr + | m.a.data <= data + | m.a.mask <= UInt(1) + | m.b.clk <= c + | m.b.en <= bEn + | m.b.addr <= addr + | m.b.data <= data + | m.b.mask <= UInt(1) + | m.c.clk <= c + | m.c.en <= cEn + | m.c.addr <= addr + | m.c.data <= data + | m.c.mask <= UInt(1) + """.stripMargin + + val writeWriteConflict3Exclusive = + s""" + |circuit Test: + | module Test: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | input aEn : UInt<1> + | input bEn : UInt<1> + | + | mem m: + | data-type => UInt<32> + | depth => 32 + | reader => r + | writer => a, b, c + | read-latency => 0 + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | + | ; both read ports write to the same address and the same data + | m.a.clk <= c + | m.a.en <= aEn + | m.a.addr <= addr + | m.a.data <= data + | m.a.mask <= UInt(1) + | m.b.clk <= c + | m.b.en <= bEn + | m.b.addr <= addr + | m.b.data <= data + | m.b.mask <= UInt(1) + | m.c.clk <= c + | m.c.en <= not(aEn) + | m.c.addr <= addr + | m.c.data <= data + | m.c.mask <= UInt(1) + """.stripMargin + + def readWrite(depth: Int, readLatency: Int) = + s"""circuit CollisionTest: + | module CollisionTest: + | input c : Clock + | input preset: AsyncReset + | input addr : UInt<8> + | input data : UInt<32> + | output dataOut : UInt<32> + | + | mem m: + | data-type => UInt<32> + | depth => $depth + | reader => r + | writer => a + | read-latency => $readLatency + | write-latency => 1 + | read-under-write => undefined + | + | m.r.clk <= c + | m.r.en <= UInt(1) + | m.r.addr <= addr + | dataOut <= m.r.data + | + | m.a.clk <= c + | m.a.mask <= UInt(1) + | m.a.en <= UInt(1) + | m.a.addr <= addr + | m.a.data <= data + |""".stripMargin + + def readEnable(depth: Int, latency: Int = 0) = + s"""circuit Test: + | module Test: + | input c : Clock + | input addr : UInt<8> + | input en : UInt<1> + | output data : UInt<32> + | + | mem m: + | data-type => UInt<32> + | depth => $depth + | reader => r + | read-latency => $latency + | write-latency => 1 + | read-under-write => old + | + | m.r.clk <= c + | m.r.en <= en + | m.r.addr <= addr + | data <= m.r.data + |""".stripMargin +} |
