aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorKevin Laeufer2021-03-04 11:23:51 -0800
committerGitHub2021-03-04 19:23:51 +0000
commitc93d6f5319efd7ba42147180c6e2b6f3796ef943 (patch)
treecea5c960c6fd15c1680f43d78fa06a69dda7dc6e /src
parente58ba0c12e5d650983c70a61a45542f0cd43fb88 (diff)
SMT Backend: move undefined memory behavior modelling to firrtl IR level (#2095)
With this PR the smt backend now supports memories with more than two write ports and the conservative memory modelling can be selectively turned off with a new annotation.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Utils.scala73
-rw-r--r--src/main/scala/firrtl/WIR.scala1
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala312
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala31
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala457
-rw-r--r--src/main/scala/firrtl/ir/IR.scala4
-rw-r--r--src/main/scala/firrtl/ir/Serializer.scala7
-rw-r--r--src/main/scala/firrtl/passes/ResolveKinds.scala2
-rw-r--r--src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala3
-rw-r--r--src/main/scala/firrtl/transforms/DeadCodeElimination.scala7
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala20
-rw-r--r--src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala360
12 files changed, 1060 insertions, 217 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 72884e25..3d0f19b8 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -900,6 +900,79 @@ object Utils extends LazyLogging {
def maskBigInt(value: BigInt, width: Int): BigInt = {
value & ((BigInt(1) << width) - 1)
}
+
+ /** Returns true iff the expression is a Literal or a Literal cast to a different type. */
+ def isLiteral(e: Expression): Boolean = e match {
+ case _: Literal => true
+ case DoPrim(op, args, _, _) if isCast(op) => args.exists(isLiteral)
+ case _ => false
+ }
+
+ /** Applies the firrtl And primop. Automatically constant propagates when one of the expressions is True or False. */
+ def and(e1: Expression, e2: Expression): Expression = {
+ assert(e1.tpe == e2.tpe)
+ (e1, e2) match {
+ case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width)
+ case (True(), b) => b
+ case (a, True()) => a
+ case (False(), _) => False()
+ case (_, False()) => False()
+ case (a, b) if a == b => a
+ case (a, b) => DoPrim(PrimOps.And, Seq(a, b), Nil, BoolType)
+ }
+ }
+
+ /** Applies the firrtl Eq primop. */
+ def eq(e1: Expression, e2: Expression): Expression = DoPrim(PrimOps.Eq, Seq(e1, e2), Nil, BoolType)
+
+ /** Applies the firrtl Or primop. Automatically constant propagates when one of the expressions is True or False. */
+ def or(e1: Expression, e2: Expression): Expression = {
+ assert(e1.tpe == e2.tpe)
+ (e1, e2) match {
+ case (a: UIntLiteral, b: UIntLiteral) => UIntLiteral(a.value | b.value, a.width)
+ case (True(), _) => True()
+ case (_, True()) => True()
+ case (False(), b) => b
+ case (a, False()) => a
+ case (a, b) if a == b => a
+ case (a, b) => DoPrim(PrimOps.Or, Seq(a, b), Nil, BoolType)
+ }
+ }
+
+ /** Applies the firrtl Not primop. Automatically constant propagates when the expressions is True or False. */
+ def not(e: Expression): Expression = e match {
+ case True() => False()
+ case False() => True()
+ case a => DoPrim(PrimOps.Not, Seq(a), Nil, BoolType)
+ }
+
+ /** implies(e1, e2) = or(not(e1), e2). Automatically constant propagates when one of the expressions is True or False. */
+ def implies(e1: Expression, e2: Expression): Expression = or(not(e1), e2)
+
+ /** Builds a Mux expression with the correct type. */
+ def mux(cond: Expression, tval: Expression, fval: Expression): Expression = {
+ require(tval.tpe == fval.tpe)
+ Mux(cond, tval, fval, tval.tpe)
+ }
+
+ object True {
+ private val _True = UIntLiteral(1, IntWidth(1))
+
+ /** Matches `UInt<1>(1)` */
+ def unapply(e: UIntLiteral): Boolean = e.value == 1 && e.width == _True.width
+
+ /** Returns `UInt<1>(1)` */
+ def apply(): UIntLiteral = _True
+ }
+ object False {
+ private val _False = UIntLiteral(0, IntWidth(1))
+
+ /** Matches `UInt<1>(0)` */
+ def unapply(e: UIntLiteral): Boolean = e.value == 0 && e.width == _False.width
+
+ /** Returns `UInt<1>(0)` */
+ def apply(): UIntLiteral = _False
+ }
}
object MemoizedHash {
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 78536a36..e9dd95bc 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -13,6 +13,7 @@ trait Kind
case object WireKind extends Kind
case object PoisonKind extends Kind
case object RegKind extends Kind
+case object RandomKind extends Kind
case object InstanceKind extends Kind
case object PortKind extends Kind
case object NodeKind extends Kind
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index c1857667..d3a1ed68 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -5,9 +5,12 @@ package firrtl.backends.experimental.smt
import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation}
import FirrtlExpressionSemantics.getWidth
+import firrtl.backends.experimental.smt.random._
import firrtl.graph.MutableDiGraph
import firrtl.options.Dependency
+import firrtl.passes.MemPortUtils.memPortField
import firrtl.passes.PassException
+import firrtl.passes.memlib.VerilogMemDelays
import firrtl.stage.Forms
import firrtl.stage.TransformManager.TransformDependency
import firrtl.transforms.{DeadCodeElimination, PropagatePresetAnnotations}
@@ -58,7 +61,8 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
// TODO: We only really need [[Forms.MidForm]] + LowerTypes, but we also want to fail if there are CombLoops
// TODO: We also would like to run some optimization passes, but RemoveValidIf won't allow us to model DontCare
// precisely and PadWidths emits ill-typed firrtl.
- override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm
+ override def prerequisites: Seq[Dependency[Transform]] = Forms.LowForm ++
+ Seq(Dependency(UndefinedMemoryBehaviorPass), Dependency(VerilogMemDelays))
override def invalidates(a: Transform): Boolean = false
// since this pass only runs on the main module, inlining needs to happen before
override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
@@ -169,8 +173,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
}
// turn memories into state
- val memoryEncoding = new MemoryEncoding(makeRandom, scan.namespace)
- val memoryStatesAndOutputs = scan.memories.map(m => memoryEncoding.onMemory(m, scan.connects, memInit.get(m.name)))
+ val memoryStatesAndOutputs = scan.memories.map(m => onMemory(m, scan.connects, memInit.get(m.name)))
// replace pseudo assigns for memory outputs
val memOutputs = memoryStatesAndOutputs.flatMap(_._2).toMap
val signalsWithMem = signals.map { s =>
@@ -185,7 +188,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
case _ => true
}
)
- val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1)
+ val states = regStates.toArray ++ memoryStatesAndOutputs.map(_._1)
// generate comments from infos
val comments = mutable.HashMap[String, String]()
@@ -247,233 +250,116 @@ private class ModuleToTransitionSystem extends LazyLogging {
}
}
- private val InfoSeparator = ", "
- private val InfoPrefix = "@ "
- private def serializeInfo(info: ir.Info): Option[String] = info match {
- case ir.NoInfo => None
- case f: ir.FileInfo => Some(f.escaped)
- case m: ir.MultiInfo =>
- val infos = m.flatten
- if (infos.isEmpty) { None }
- else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
- }
-
- private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
- private def makeRandom(baseName: String, width: Int): BVExpr = {
- // TODO: actually ensure that there cannot be any name clashes with other identifiers
- val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
- val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
- val sym = BVSymbol(name, width)
- randoms(name) = sym
- sym
- }
-}
-
-private class MemoryEncoding(makeRandom: (String, Int) => BVExpr, namespace: Namespace) extends LazyLogging {
type Connects = Iterable[(String, BVExpr)]
- def onMemory(
- defMem: ir.DefMemory,
- connects: Connects,
- initValue: Option[MemoryInitValue]
- ): (Iterable[State], Connects) = {
- // we can only work on appropriately lowered memories
- assert(
- defMem.dataType.isInstanceOf[ir.GroundType],
- s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!"
- )
- assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
+ private def onMemory(m: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (State, Connects) = {
+ checkMem(m)
- // collect all memory meta-data in a custom class
- val m = new MemInfo(defMem)
+ // map of inputs to the memory
+ val inputs = connects.filter(_._1.startsWith(m.name)).toMap
- // find all connections related to this memory
- val inputs = connects.filter(_._1.startsWith(m.prefix)).toMap
+ // derive the type of the memory from the dataType and depth
+ val dataWidth = getWidth(m.dataType)
+ val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
+ val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth)
// there could be a constant init
- val init = initValue.map(getInit(m, _))
-
- // parse and check read and write ports
- val writers = defMem.writers.map(w => new WritePort(m, w, inputs))
- val readers = defMem.readers.map(r => new ReadPort(m, r, inputs))
-
- // derive next state from all write ports
- assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.")
- val next: ArrayExpr = if (writers.isEmpty) { m.sym }
- else {
- if (writers.length > 2) {
- throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})")
- }
- val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) }
- if (writers.length == 1) { validData }
- else {
- assert(writers.length == 2)
- val conflict = writers.head.doesConflict(writers.last)
- val conflictData = writers.head.makeRandomData("_write_write_collision")
- val conflictStore = ArrayStore(m.sym, writers.head.addr, conflictData)
- ArrayIte(conflict, conflictStore, validData)
- }
- }
- val state = State(m.sym, init, Some(next))
+ val init = initValue.map(getInit(m, indexWidth, dataWidth, _))
+ init.foreach(e => assert(e.dataWidth == memSymbol.dataWidth && e.indexWidth == memSymbol.indexWidth))
- // derive data signals from all read ports
- assert(defMem.readLatency >= 0)
- if (defMem.readLatency > 1) {
- throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})")
- }
- val readPortSignals = if (defMem.readLatency == 0) {
- readers.map { r =>
- // combinatorial read
- if (defMem.readUnderWrite != ir.ReadUnderWrite.New) {
- logger.warn(
- s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." +
- s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored."
- )
- }
- // since we do a combinatorial read, the "old" data is the current data
- val data = r.read()
- r.data.name -> data
- }
- } else { Seq() }
- val readPortSignalsAndStates = if (defMem.readLatency == 1) {
- readers.map { r =>
- defMem.readUnderWrite match {
- case ir.ReadUnderWrite.New =>
- // create a state to save the address and the enable signal
- val enPrev = BVSymbol(namespace.newName(r.en.name + "_prev"), r.en.width)
- val addrPrev = BVSymbol(namespace.newName(r.addr.name + "_prev"), r.addr.width)
- val signal = r.data.name -> r.read(addr = addrPrev, en = enPrev)
- val states = Seq(State(enPrev, None, next = Some(r.en)), State(addrPrev, None, next = Some(r.addr)))
- (Seq(signal), states)
- case ir.ReadUnderWrite.Undefined =>
- // check for potential read/write conflicts in which case we need to return an arbitrary value
- val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r)))
- val next = if (anyWriteToTheSameAddress == False) { r.read() }
- else {
- val readUnderWriteData = r.makeRandomData("_read_under_write_undefined")
- BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.read())
- }
- (Seq(), Seq(State(r.data, init = None, next = Some(next))))
- case ir.ReadUnderWrite.Old =>
- // we create a register for the read port data
- (Seq(), Seq(State(r.data, init = None, next = Some(r.read()))))
- }
+ // derive next state expression
+ val next = if (m.writers.isEmpty) {
+ memSymbol
+ } else {
+ m.writers.foldLeft[ArrayExpr](memSymbol) {
+ case (prev, write) =>
+ // update
+ val addr = BVSymbol(memPortField(m, write, "addr").serialize, indexWidth)
+ val data = BVSymbol(memPortField(m, write, "data").serialize, dataWidth)
+ val update = ArrayStore(prev, index = addr, data = data)
+
+ // update guard
+ val en = BVSymbol(memPortField(m, write, "en").serialize, 1)
+ val mask = BVSymbol(memPortField(m, write, "mask").serialize, 1)
+ val alwaysEnabled = Seq(en, mask).forall(s => inputs(s.name) == True)
+ if (alwaysEnabled) { update }
+ else {
+ ArrayIte(and(en, mask), update, prev)
+ }
}
- } else { Seq() }
+ }
- val allReadPortSignals = readPortSignals ++ readPortSignalsAndStates.flatMap(_._1)
- val readPortStates = readPortSignalsAndStates.flatMap(_._2)
+ val state = State(memSymbol, init, Some(next))
- (state +: readPortStates, allReadPortSignals)
- }
+ // derive read expressions
+ val readSignals = m.readers.map { read =>
+ val addr = BVSymbol(memPortField(m, read, "addr").serialize, indexWidth)
+ memPortField(m, read, "data").serialize -> ArrayRead(memSymbol, addr)
+ }
- private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match {
- case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth)
- case MemoryArrayInit(values) =>
- assert(
- values.length == m.depth,
- s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
- )
- // in order to get a more compact encoding try to find the most common values
- val histogram = mutable.LinkedHashMap[BigInt, Int]()
- values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
- val baseValue = histogram.maxBy(_._2)._1
- val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth)
- values.zipWithIndex
- .filterNot(_._1 == baseValue)
- .foldLeft[ArrayExpr](base) {
- case (array, (value, index)) =>
- ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth))
- }
- case other => throw new RuntimeException(s"Unsupported memory init option: $other")
+ (state, readSignals)
}
- private class MemInfo(m: ir.DefMemory) {
- val name = m.name
- val depth = m.depth
- // derrive the type of the memory from the dataType and depth
- val dataWidth = getWidth(m.dataType)
- val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
- val sym = ArraySymbol(m.name, indexWidth, dataWidth)
- val prefix = m.name + "."
- val fullAddressRange = (BigInt(1) << indexWidth) == m.depth
- lazy val depthBV = BVLiteral(m.depth, indexWidth)
- def isValidAddress(addr: BVExpr): BVExpr = {
- if (fullAddressRange) { True }
- else {
- BVComparison(Compare.Greater, depthBV, addr, signed = false)
- }
- }
- }
- private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) {
- val en: BVSymbol = makeField("en", 1)
- val data: BVSymbol = makeField("data", memory.dataWidth)
- val addr: BVSymbol = makeField("addr", memory.indexWidth)
- protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width)
- // make sure that all widths are correct
- assert(inputs(en.name).width == en.width)
- assert(inputs(addr.name).width == addr.width)
- val enIsTrue: Boolean = inputs(en.name) == True
- def makeRandomData(suffix: String): BVExpr =
- makeRandom(memory.name + "_" + name + suffix, memory.dataWidth)
- def read(addr: BVSymbol = addr, en: BVSymbol = en): BVExpr = {
- val canBeOutOfRange = !memory.fullAddressRange
- val canBeDisabled = !enIsTrue
- val data = ArrayRead(memory.sym, addr)
- val dataWithRangeCheck = if (canBeOutOfRange) {
- val outOfRangeData = makeRandomData("_addr_out_of_range")
- BVIte(memory.isValidAddress(addr), data, outOfRangeData)
- } else { data }
- val dataWithEnabledCheck = if (canBeDisabled) {
- val disabledData = makeRandomData("_not_enabled")
- BVIte(en, dataWithRangeCheck, disabledData)
- } else { dataWithRangeCheck }
- dataWithEnabledCheck
- }
- }
- private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {
- assert(inputs(data.name).width == data.width)
- val mask: BVSymbol = makeField("mask", 1)
- assert(inputs(mask.name).width == mask.width)
- val maskIsTrue: Boolean = inputs(mask.name) == True
- val doWrite: BVExpr = (enIsTrue, maskIsTrue) match {
- case (true, true) => True
- case (true, false) => mask
- case (false, true) => en
- case (false, false) => and(en, mask)
- }
- def doesConflict(r: ReadPort): BVExpr = {
- val sameAddress = BVEqual(r.addr, addr)
- if (doWrite == True) { sameAddress }
- else { and(doWrite, sameAddress) }
- }
- def doesConflict(w: WritePort): BVExpr = {
- val bothWrite = and(doWrite, w.doWrite)
- val sameAddress = BVEqual(addr, w.addr)
- if (bothWrite == True) { sameAddress }
- else { and(bothWrite, sameAddress) }
- }
- def writeTo(array: ArrayExpr): ArrayExpr = {
- val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr))
- val update = ArrayStore(array, index = addr, data = data)
- if (doUpdate == True) update else ArrayIte(doUpdate, update, array)
+ private def getInit(m: ir.DefMemory, indexWidth: Int, dataWidth: Int, initValue: MemoryInitValue): ArrayExpr =
+ initValue match {
+ case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, dataWidth), indexWidth)
+ case MemoryArrayInit(values) =>
+ assert(
+ values.length == m.depth,
+ s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
+ )
+ // in order to get a more compact encoding try to find the most common values
+ val histogram = mutable.LinkedHashMap[BigInt, Int]()
+ values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
+ val baseValue = histogram.maxBy(_._2)._1
+ val base = ArrayConstant(BVLiteral(baseValue, dataWidth), indexWidth)
+ values.zipWithIndex
+ .filterNot(_._1 == baseValue)
+ .foldLeft[ArrayExpr](base) {
+ case (array, (value, index)) =>
+ ArrayStore(array, BVLiteral(index, indexWidth), BVLiteral(value, dataWidth))
+ }
+ case other => throw new RuntimeException(s"Unsupported memory init option: $other")
}
- }
- private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {}
-
+ // TODO: add to BV expression library
private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match {
case (True, True) => True
case (True, x) => x
case (x, True) => x
case _ => BVOp(Op.And, a, b)
}
- private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b)
+
private val True = BVLiteral(1, 1)
- private val False = BVLiteral(0, 1)
- private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b))
- private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b))
+ private def checkMem(m: ir.DefMemory): Unit = {
+ assert(m.readLatency == 0, "Expected read latency to be 0. Did you run VerilogMemDelays?")
+ assert(m.writeLatency == 1, "Expected read latency to be 1. Did you run VerilogMemDelays?")
+ assert(
+ m.dataType.isInstanceOf[ir.GroundType],
+ s"Memory $m is of type ${m.dataType} which is not a ground type!"
+ )
+ assert(m.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
+ }
+
+ private val InfoSeparator = ", "
+ private val InfoPrefix = "@ "
+ private def serializeInfo(info: ir.Info): Option[String] = info match {
+ case ir.NoInfo => None
+ case f: ir.FileInfo => Some(f.escaped)
+ case m: ir.MultiInfo =>
+ val infos = m.flatten
+ if (infos.isEmpty) { None }
+ else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
+ }
+
+ private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
+ private def makeRandom(baseName: String, width: Int): BVExpr = {
+ // TODO: actually ensure that there cannot be any name clashes with other identifiers
+ val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
+ val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
+ val sym = BVSymbol(name, width)
+ randoms(name) = sym
+ sym
+ }
}
// performas a first pass over the module collecting all connections, wires, registers, input and outputs
@@ -526,6 +412,12 @@ private class ModuleScanner(
}
private[firrtl] def onStatement(s: ir.Statement): Unit = s match {
+ case DefRandom(info, name, tpe, _, _) =>
+ namespace.newName(name)
+ assert(!isClock(tpe), "rand should never be a clock!")
+ // we model random sources as inputs and ignore the enable signal
+ infos.append(name -> info)
+ inputs.append(BVSymbol(name, getWidth(tpe)))
case ir.DefWire(info, name, tpe) =>
namespace.newName(name)
if (!isClock(tpe)) {
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala
new file mode 100644
index 00000000..7381056e
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/random/DefRandom.scala
@@ -0,0 +1,31 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.backends.experimental.smt.random
+
+import firrtl.Utils
+import firrtl.ir._
+
+/** Named source of random values. If there is no clock expression, than it will be clocked by the global clock. */
+case class DefRandom(
+ info: Info,
+ name: String,
+ tpe: Type,
+ clock: Option[Expression],
+ en: Expression = Utils.True())
+ extends Statement
+ with HasInfo
+ with IsDeclaration
+ with CanBeReferenced
+ with UseSerializer {
+ def mapStmt(f: Statement => Statement): Statement = this
+ def mapExpr(f: Expression => Expression): Statement =
+ DefRandom(info, name, tpe, clock.map(f), f(en))
+ def mapType(f: Type => Type): Statement = this.copy(tpe = f(tpe))
+ def mapString(f: String => String): Statement = this.copy(name = f(name))
+ def mapInfo(f: Info => Info): Statement = this.copy(info = f(info))
+ def foreachStmt(f: Statement => Unit): Unit = ()
+ def foreachExpr(f: Expression => Unit): Unit = { clock.foreach(f); f(en) }
+ def foreachType(f: Type => Unit): Unit = f(tpe)
+ def foreachString(f: String => Unit): Unit = f(name)
+ def foreachInfo(f: Info => Unit): Unit = f(info)
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala
new file mode 100644
index 00000000..91e77433
--- /dev/null
+++ b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala
@@ -0,0 +1,457 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package firrtl.backends.experimental.smt.random
+
+import firrtl.Utils.{isLiteral, kind, BoolType}
+import firrtl.WrappedExpression.{we, weq}
+import firrtl._
+import firrtl.annotations.NoTargetAnnotation
+import firrtl.backends.experimental.smt._
+import firrtl.ir._
+import firrtl.options.Dependency
+import firrtl.passes.MemPortUtils.memPortField
+import firrtl.passes.memlib.AnalysisUtils.Connects
+import firrtl.passes.memlib.InferReadWritePass.checkComplement
+import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays}
+import firrtl.stage.Forms
+import firrtl.transforms.{ConstantPropagation, RemoveWires}
+
+import scala.collection.mutable
+
+/** Chooses which undefined memory behaviors should be instrumented. */
+case class UndefinedMemoryBehaviorOptions(
+ randomizeWriteWriteConflicts: Boolean = true,
+ assertNoOutOfBoundsWrites: Boolean = false,
+ randomizeOutOfBoundsRead: Boolean = true,
+ randomizeDisabledReads: Boolean = true,
+ randomizeReadWriteConflicts: Boolean = true)
+ extends NoTargetAnnotation
+
+/** Adds sources of randomness to model the various "undefined behaviors" of firrtl memory.
+ * - Write/Write conflict: leads to arbitrary value written to write address
+ * - Out-of-bounds write: assertion failure (disabled by default)
+ * - Out-Of-bounds read: leads to arbitrary value being read
+ * - Read w/ en=0: leads to arbitrary value being read
+ * - Read/Write conflict: leads to arbitrary value being read
+ */
+object UndefinedMemoryBehaviorPass extends Transform with DependencyAPIMigration {
+ override def prerequisites = Forms.LowForm
+ override def optionalPrerequisiteOf = Seq(Dependency(VerilogMemDelays))
+ override def invalidates(a: Transform) = a match {
+ // this pass might destroy SSA form, as we add a wire for the data field of every read port
+ case _: RemoveWires => true
+ // TODO: should we add some optimization passes here? we could be generating some dead code.
+ case _ => false
+ }
+
+ override protected def execute(state: CircuitState): CircuitState = {
+ val opts = state.annotations.collect { case o: UndefinedMemoryBehaviorOptions => o }
+ require(opts.size < 2, s"Multiple options: $opts")
+ val opt = opts.headOption.getOrElse(UndefinedMemoryBehaviorOptions())
+
+ val c = state.circuit.mapModule(onModule(_, opt))
+ state.copy(circuit = c)
+ }
+
+ private def onModule(m: DefModule, opt: UndefinedMemoryBehaviorOptions): DefModule = m match {
+ case mod: Module =>
+ val mems = findMems(mod)
+ if (mems.isEmpty) { mod }
+ else {
+ val namespace = Namespace(mod)
+ val connects = AnalysisUtils.getConnects(mod)
+ new InstrumentMems(opt, mems, connects, namespace).run(mod)
+ }
+ case other => other
+ }
+
+ /** finds all memory instantiations in a circuit */
+ private def findMems(m: Module): List[DefMemory] = {
+ val mems = mutable.ListBuffer[DefMemory]()
+ m.foreachStmt(findMems(_, mems))
+ mems.toList
+ }
+ private def findMems(s: Statement, mems: mutable.ListBuffer[DefMemory]): Unit = s match {
+ case mem: DefMemory => mems.append(mem)
+ case other => other.foreachStmt(findMems(_, mems))
+ }
+}
+
+private class InstrumentMems(
+ opt: UndefinedMemoryBehaviorOptions,
+ mems: List[DefMemory],
+ connects: Connects,
+ namespace: Namespace) {
+ def run(m: Module): DefModule = {
+ // ensure that all memories are the kind we can support
+ mems.foreach(checkSupported(m.name, _))
+
+ // transform circuit
+ val body = m.body.mapStmt(transform)
+ m.copy(body = Block(body +: newStmts.toList))
+ }
+
+ // used to replace memory signals like `m.r.data` in RHS expressions
+ private val exprReplacements = mutable.HashMap[String, Expression]()
+ // add new statements at the end of the circuit
+ private val newStmts = mutable.ListBuffer[Statement]()
+ // disconnect references so that they can be reassigned
+ private val doDisconnect = mutable.HashSet[String]()
+
+ // generates new expression replacements and immediately uses them
+ private def transform(s: Statement): Statement = s.mapStmt(transform) match {
+ case mem: DefMemory => onMem(mem)
+ case sx: Connect if doDisconnect.contains(sx.loc.serialize) => EmptyStmt // Filter old mem connections
+ case sx => sx.mapExpr(swapMemRefs)
+ }
+ private def swapMemRefs(e: Expression): Expression = e.mapExpr(swapMemRefs) match {
+ case sf: RefLikeExpression => exprReplacements.getOrElse(sf.serialize, sf)
+ case ex => ex
+ }
+
+ private def onMem(m: DefMemory): Statement = {
+ // collect wire and random statement defines
+ val declarations = mutable.ListBuffer[Statement]()
+
+ // only for non power of 2 memories do we have to worry about reading or writing out of bounds
+ val canBeOutOfBounds = !isPow2(m.depth)
+
+ // only if we have at least two write ports, can there be conflicts
+ val canHaveWriteWriteConflicts = m.writers.size > 1
+
+ // only certain memory types exhibit undefined read/write conflicts
+ val readWriteUndefined = (m.readLatency == m.writeLatency) && (m.readUnderWrite == ReadUnderWrite.Undefined)
+ assert(
+ m.readLatency == 0 || m.readLatency == m.writeLatency,
+ "TODO: what happens if a sync read mem has asymmetrical latencies?"
+ )
+
+ // a write port is enabled iff mask & en
+ val writeEn = m.writers.map { write =>
+ val enRef = memPortField(m, write, "en")
+ val maskRef = memPortField(m, write, "mask")
+
+ val prods = getProductTerms(enRef) ++ getProductTerms(maskRef)
+
+ // if we can have write/write conflicts, we are going to change the mask and enable pins
+ val expr = if (canHaveWriteWriteConflicts) {
+ val maskIsOne = isTrue(connects(maskRef.serialize))
+ // if the mask is connected to a constant true, we do not need to consider it, this is a common case
+ if (maskIsOne) {
+ val enWire = disconnectInput(m.info, enRef)
+ declarations += enWire
+ Reference(enWire)
+ } else {
+ val maskWire = disconnectInput(m.info, maskRef)
+ val enWire = disconnectInput(m.info, enRef)
+ // create a node for the conjunction
+ val nodeName = namespace.newName(s"${m.name}_${write}_mask_and_en")
+ val node = DefNode(m.info, nodeName, Utils.and(Reference(maskWire), Reference(enWire)))
+ declarations ++= List(maskWire, enWire, node)
+ Reference(node)
+ }
+ } else {
+ Utils.and(enRef, maskRef)
+ }
+ (expr, prods)
+ }
+
+ // implement the three undefined read behaviors
+ m.readers.foreach { read =>
+ // many memories have their read enable hard wired to true
+ val canBeDisabled = !isTrue(memPortField(m, read, "en"))
+ val readEn = if (canBeDisabled) memPortField(m, read, "en") else Utils.True()
+ val addr = memPortField(m, read, "addr")
+
+ // collect signals that would lead to a randomization
+ var doRand = List[Expression]()
+
+ // randomize the read value when the address is out of bounds
+ if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) {
+ val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr)))
+ val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond)
+ declarations += node
+ doRand = Reference(node) +: doRand
+ }
+
+ if (readWriteUndefined && opt.randomizeReadWriteConflicts) {
+ val (cond, d) = readWriteConflict(m, read, writeEn)
+ declarations ++= d
+ val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond)
+ declarations += node
+ doRand = Reference(node) +: doRand
+ }
+
+ // randomize the read value when the read is disabled
+ if (canBeDisabled && opt.randomizeDisabledReads) {
+ val cond = Utils.not(readEn)
+ val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_disabled"), cond)
+ declarations += node
+ doRand = Reference(node) +: doRand
+ }
+
+ // if there are no signals that would require a randomization, there is nothing to do
+ if (doRand.isEmpty) {
+ // nothing to do
+ } else {
+ val doRandName = s"${m.name}_${read}_do_rand"
+ val doRandNode = if (doRand.size == 1) { doRand.head }
+ else {
+ val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_do_rand"), doRand.reduce(Utils.or))
+ declarations += node
+ Reference(node)
+ }
+ val doRandSignal = if (m.readLatency == 0) { doRandNode }
+ else {
+ val clock = memPortField(m, read, "clk")
+ val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency)
+ declarations ++= regDecls
+ signal
+ }
+
+ // all old rhs references to m.r.data need to replace with m_r_data which might be random
+ val dataRef = memPortField(m, read, "data")
+ val dataWire = DefWire(m.info, namespace.newName(s"${m.name}_${read}_data"), m.dataType)
+ declarations += dataWire
+ exprReplacements(dataRef.serialize) = Reference(dataWire)
+
+ // create a source of randomness and connect the new wire either to the actual data port or to the random value
+ val randName = namespace.newName(s"${m.name}_${read}_rand_data")
+ val random = DefRandom(m.info, randName, m.dataType, Some(memPortField(m, read, "clk")), doRandSignal)
+ declarations += random
+ val data = Utils.mux(doRandSignal, Reference(random), dataRef)
+ newStmts.append(Connect(m.info, Reference(dataWire), data))
+ }
+ }
+
+ // write
+ if (opt.randomizeWriteWriteConflicts) {
+ declarations ++= writeWriteConflicts(m, writeEn)
+ }
+
+ // add an assertion that if the write is taking place, then the address must be in range
+ if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) {
+ m.writers.zip(writeEn).foreach {
+ case (write, (combinedEn, _)) =>
+ val addr = memPortField(m, write, "addr")
+ val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr))
+ val clk = memPortField(m, write, "clk")
+ val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read"))
+ newStmts.append(a)
+ }
+ }
+
+ Block(m +: declarations.toList)
+ }
+
+ private def pipeline(
+ info: Info,
+ clk: Expression,
+ prefix: String,
+ e: Expression,
+ latency: Int
+ ): (Expression, Seq[Statement]) = {
+ require(latency > 0)
+ val regs = (1 to latency).map { i =>
+ val name = namespace.newName(prefix + s"_r$i")
+ DefRegister(info, name, e.tpe, clk, Utils.False(), Reference(name, e.tpe, RegKind, UnknownFlow))
+ }
+ val expr = regs.foldLeft(e) {
+ case (prev, reg) =>
+ newStmts.append(Connect(info, Reference(reg), prev))
+ Reference(reg)
+ }
+ (expr, regs)
+ }
+
+ private def readWriteConflict(
+ m: DefMemory,
+ read: String,
+ writeEn: Seq[(Expression, ProdTerms)]
+ ): (Expression, Seq[Statement]) = {
+ if (m.writers.isEmpty) return (Utils.False(), List())
+ val declarations = mutable.ListBuffer[Statement]()
+
+ val readEn = memPortField(m, read, "en")
+ val readProd = getProductTerms(readEn)
+
+ // create all conflict signals
+ val conflicts = m.writers.zip(writeEn).map {
+ case (write, (writeEn, writeProd)) =>
+ if (isMutuallyExclusive(readProd, writeProd)) {
+ Utils.False()
+ } else {
+ val name = namespace.newName(s"${m.name}_${read}_${write}_rwc")
+ val bothEn = Utils.and(readEn, writeEn)
+ val sameAddr = Utils.eq(memPortField(m, read, "addr"), memPortField(m, write, "addr"))
+ // we need a wire because this condition might be used in a random statement
+ val wire = DefWire(m.info, name, BoolType)
+ declarations += wire
+ newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr)))
+ Reference(wire)
+ }
+ }
+
+ (conflicts.reduce(Utils.or), declarations.toList)
+ }
+
+ private type ProdTerms = Seq[Expression]
+ private def writeWriteConflicts(m: DefMemory, writeEn: Seq[(Expression, ProdTerms)]): Seq[Statement] = {
+ if (m.writers.size < 2) return List()
+ val declarations = mutable.ListBuffer[Statement]()
+
+ // we first create all conflict signals:
+ val conflict =
+ m.writers
+ .zip(writeEn)
+ .zipWithIndex
+ .flatMap {
+ case ((w1, (en1, en1Prod)), i1) =>
+ m.writers.zip(writeEn).drop(i1 + 1).map {
+ case (w2, (en2, en2Prod)) =>
+ if (isMutuallyExclusive(en1Prod, en2Prod)) {
+ (w1, w2) -> Utils.False()
+ } else {
+ val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc")
+ val bothEn = Utils.and(en1, en2)
+ val sameAddr = Utils.eq(memPortField(m, w1, "addr"), memPortField(m, w2, "addr"))
+ // we need a wire because this condition might be used in a random statement
+ val wire = DefWire(m.info, name, BoolType)
+ declarations += wire
+ newStmts.append(Connect(m.info, Reference(wire), Utils.and(bothEn, sameAddr)))
+ (w1, w2) -> Reference(wire)
+ }
+ }
+ }
+ .toMap
+
+ // now we calculate the new enable and data signals
+ m.writers.zip(writeEn).zipWithIndex.foreach {
+ case ((w1, (en1, _)), i1) =>
+ val prev = m.writers.take(i1)
+ val next = m.writers.drop(i1 + 1)
+
+ // the write is enabled if the original enable is true and there are no prior conflicts
+ val en = if (prev.isEmpty) {
+ en1
+ } else {
+ val prevConflicts = prev.map(o => conflict(o, w1)).reduce(Utils.or)
+ Utils.and(en1, Utils.not(prevConflicts))
+ }
+
+ // we write random data if there is a conflict with any of the next ports
+ if (next.isEmpty) {
+ // nothing to do, leave data as is
+ } else {
+ val nextConflicts = next.map(n => conflict(w1, n)).reduce(Utils.or)
+ // if the conflict expression is more complex, create a node for the signal
+ val hasConflict = nextConflicts match {
+ case _: DoPrim | _: Mux =>
+ val node = DefNode(m.info, namespace.newName(s"${m.name}_${w1}_wwc_active"), nextConflicts)
+ declarations += node
+ Reference(node)
+ case _ => nextConflicts
+ }
+
+ // create the source of randomness
+ val name = namespace.newName(s"${m.name}_${w1}_wwc_data")
+ val random = DefRandom(m.info, name, m.dataType, Some(memPortField(m, w1, "clk")), hasConflict)
+ declarations.append(random)
+ // replace the old data input
+ val dataWire = disconnectInput(m.info, memPortField(m, w1, "data"))
+ declarations += dataWire
+ // generate new data input
+ val data = Utils.mux(hasConflict, Reference(random), Reference(dataWire))
+ newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data))
+ }
+
+ // connect data enable signals
+ val maskIsOne = isTrue(connects(memPortField(m, w1, "mask").serialize))
+ if (!maskIsOne) {
+ newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True()))
+ }
+ newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en))
+ }
+
+ declarations.toList
+ }
+
+ /** check whether two signals can be proven to be mutually exclusive */
+ private def isMutuallyExclusive(prodA: ProdTerms, prodB: ProdTerms): Boolean = {
+ // this uses the same approach as the InferReadWrite pass
+ val proofOfMutualExclusion = prodA.find(a => prodB.exists(b => checkComplement(a, b)))
+ proofOfMutualExclusion.nonEmpty
+ }
+
+ /** replace a memory port with a wire */
+ private def disconnectInput(info: Info, signal: RefLikeExpression): DefWire = {
+ // disconnect the old value
+ doDisconnect.add(signal.serialize)
+
+ // if the old value is a literal, we just replace all references to it with this literal
+ val oldValue = connects(signal.serialize)
+ if (isLiteral(oldValue)) {
+ println("TODO: better code for literal")
+ }
+
+ // create a new wire and replace all references to the original port with this wire
+ val wire = DefWire(info, copyName(signal), signal.tpe)
+ exprReplacements(signal.serialize) = Reference(wire)
+ // connect the old expression to the new wire
+ val con = Connect(info, Reference(wire), connects(signal.serialize))
+ newStmts.append(con)
+
+ // the wire definition should end up right after the memory definition
+ wire
+ }
+
+ private def copyName(ref: RefLikeExpression): String =
+ namespace.newName(ref.serialize.replace('.', '_'))
+
+ private def isInBounds(depth: BigInt, addr: Expression): Expression = {
+ val width = getWidth(addr)
+ // depth >= addr
+ DoPrim(PrimOps.Geq, List(UIntLiteral(depth, width), addr), List(), BoolType)
+ }
+
+ private def isPow2(v: BigInt): Boolean = ((v - 1) & v) == 0
+
+ private def checkSupported(modName: String, m: DefMemory): Unit = {
+ assert(m.readwriters.isEmpty, s"[$modName] Combined read/write ports are currently not supported!")
+ if (m.writeLatency != 1) {
+ throw new UnsupportedFeatureException(s"[$modName] memories with write latency > 1 (${m.name})")
+ }
+ if (m.readLatency > 1) {
+ throw new UnsupportedFeatureException(s"[$modName] memories with read latency > 1 (${m.name})")
+ }
+ }
+
+ private def getProductTerms(e: Expression): ProdTerms =
+ InferReadWritePass.getProductTerms(connects)(e)
+
+ /** tries to expand the expression based on the connects we collected */
+ private def expandExpr(e: Expression, fuel: Int): Expression = {
+ e match {
+ case m @ Mux(cond, tval, fval, _) =>
+ m.copy(cond = expandExpr(cond, fuel), tval = expandExpr(tval, fuel), fval = expandExpr(fval, fuel))
+ case p @ DoPrim(_, args, _, _) =>
+ p.copy(args = args.map(expandExpr(_, fuel)))
+ case r: RefLikeExpression =>
+ if (fuel > 0) {
+ connects.get(r.serialize) match {
+ case None => r
+ case Some(expr) => expandExpr(expr, fuel - 1)
+ }
+ } else {
+ r
+ }
+ case other => other
+ }
+ }
+
+ private def isTrue(e: Expression): Boolean = simplifyExpr(expandExpr(e, fuel = 2)) == Utils.True()
+
+ private def simplifyExpr(e: Expression): Expression = {
+ e // TODO: better simplification could improve the resulting circuit size
+ }
+}
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index 9c3d6186..13ba3d46 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -4,6 +4,7 @@ package firrtl
package ir
import Utils.{dec2string, trim}
+import firrtl.backends.experimental.smt.random.DefRandom
import dataclass.{data, since}
import firrtl.constraint.{Constraint, IsKnown, IsVar}
import org.apache.commons.text.translate.{AggregateTranslator, JavaUnicodeEscaper, LookupTranslator}
@@ -242,6 +243,9 @@ object Reference {
/** Creates a Reference from a Register */
def apply(reg: DefRegister): Reference = Reference(reg.name, reg.tpe, RegKind, UnknownFlow)
+ /** Creates a Reference from a Random Source */
+ def apply(rnd: DefRandom): Reference = Reference(rnd.name, rnd.tpe, RandomKind, UnknownFlow)
+
/** Creates a Reference from a Node */
def apply(node: DefNode): Reference = Reference(node.name, node.value.tpe, NodeKind, SourceFlow)
diff --git a/src/main/scala/firrtl/ir/Serializer.scala b/src/main/scala/firrtl/ir/Serializer.scala
index caea0a9c..dca902fe 100644
--- a/src/main/scala/firrtl/ir/Serializer.scala
+++ b/src/main/scala/firrtl/ir/Serializer.scala
@@ -2,6 +2,8 @@
package firrtl.ir
+import firrtl.Utils
+import firrtl.backends.experimental.smt.random.DefRandom
import firrtl.constraint.Constraint
object Serializer {
@@ -114,6 +116,11 @@ object Serializer {
case DefRegister(info, name, tpe, clock, reset, init) =>
b ++= "reg "; b ++= name; b ++= " : "; s(tpe); b ++= ", "; s(clock); b ++= " with :"; newLineAndIndent(1)
b ++= "reset => ("; s(reset); b ++= ", "; s(init); b += ')'; s(info)
+ case DefRandom(info, name, tpe, clock, en) =>
+ b ++= "rand "; b ++= name; b ++= " : "; s(tpe);
+ if (clock.isDefined) { b ++= ", "; s(clock.get); }
+ en match { case Utils.True() => case _ => b ++= " when "; s(en) }
+ s(info)
case DefInstance(info, name, module, _) => b ++= "inst "; b ++= name; b ++= " of "; b ++= module; s(info)
case DefMemory(
info,
diff --git a/src/main/scala/firrtl/passes/ResolveKinds.scala b/src/main/scala/firrtl/passes/ResolveKinds.scala
index e3218467..745be1e2 100644
--- a/src/main/scala/firrtl/passes/ResolveKinds.scala
+++ b/src/main/scala/firrtl/passes/ResolveKinds.scala
@@ -5,6 +5,7 @@ package firrtl.passes
import firrtl._
import firrtl.ir._
import firrtl.Mappers._
+import firrtl.backends.experimental.smt.random.DefRandom
import firrtl.traversals.Foreachers._
object ResolveKinds extends Pass {
@@ -31,6 +32,7 @@ object ResolveKinds extends Pass {
case sx: DefRegister => kinds(sx.name) = RegKind
case sx: WDefInstance => kinds(sx.name) = InstanceKind
case sx: DefMemory => kinds(sx.name) = MemKind
+ case sx: DefRandom => kinds(sx.name) = RandomKind
case _ =>
}
s.map(resolve_stmt(kinds))
diff --git a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
index 8fb2dc88..143b925a 100644
--- a/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
+++ b/src/main/scala/firrtl/passes/memlib/VerilogMemDelays.scala
@@ -177,7 +177,8 @@ class MemDelayAndReadwriteTransformer(m: DefModule) {
object VerilogMemDelays extends Pass {
- override def prerequisites = firrtl.stage.Forms.LowForm :+ Dependency(firrtl.passes.RemoveValidIf)
+ override def prerequisites = firrtl.stage.Forms.LowForm
+ override val optionalPrerequisites = Seq(Dependency(firrtl.passes.RemoveValidIf))
override val optionalPrerequisiteOf =
Seq(Dependency[VerilogEmitter], Dependency[SystemVerilogEmitter])
diff --git a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
index 1d9bfd0e..f72585d1 100644
--- a/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
+++ b/src/main/scala/firrtl/transforms/DeadCodeElimination.scala
@@ -11,6 +11,7 @@ import firrtl.analyses.InstanceKeyGraph
import firrtl.Mappers._
import firrtl.Utils.{kind, throwInternalError}
import firrtl.MemoizedHash._
+import firrtl.backends.experimental.smt.random.DefRandom
import firrtl.options.{Dependency, RegisteredTransform, ShellOption}
import collection.mutable
@@ -126,6 +127,11 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend
val node = LogicNode(mod.name, name)
depGraph.addVertex(node)
Seq(clock, reset, init).flatMap(getDeps(_)).foreach(ref => depGraph.addPairWithEdge(node, ref))
+ case DefRandom(_, name, _, clock, en) =>
+ val node = LogicNode(mod.name, name)
+ depGraph.addVertex(node)
+ val inputs = clock ++: en +: Nil
+ inputs.flatMap(getDeps).foreach(ref => depGraph.addPairWithEdge(node, ref))
case DefNode(_, name, value) =>
val node = LogicNode(mod.name, name)
depGraph.addVertex(node)
@@ -225,6 +231,7 @@ class DeadCodeElimination extends Transform with RegisteredTransform with Depend
val tpe = decl match {
case _: DefNode => "node"
case _: DefRegister => "reg"
+ case _: DefRandom => "rand"
case _: DefWire => "wire"
case _: Port => "port"
case _: DefMemory => "mem"
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index f2907db2..7500b386 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -11,6 +11,7 @@ import firrtl.WrappedExpression._
import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
import firrtl.Utils.getGroundZero
+import firrtl.backends.experimental.smt.random.DefRandom
import scala.collection.mutable
import scala.util.{Failure, Success, Try}
@@ -41,13 +42,13 @@ class RemoveWires extends Transform with DependencyAPIMigration {
case _ => false
}
- // Extract all expressions that are references to a Node, Wire, or Reg
+ // Extract all expressions that are references to a Node, Wire, Reg or Rand
// Since we are operating on LowForm, they can only be WRefs
private def extractNodeWireRegRefs(expr: Expression): Seq[WRef] = {
val refs = mutable.ArrayBuffer.empty[WRef]
def rec(e: Expression): Expression = {
e match {
- case ref @ WRef(_, _, WireKind | NodeKind | RegKind, _) => refs += ref
+ case ref @ WRef(_, _, WireKind | NodeKind | RegKind | RandomKind, _) => refs += ref
case nested @ (_: Mux | _: DoPrim | _: ValidIf) => nested.foreach(rec)
case _ => // Do nothing
}
@@ -59,8 +60,9 @@ class RemoveWires extends Transform with DependencyAPIMigration {
// Transform netlist into DefNodes
private def getOrderedNodes(
- netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)],
- regInfo: mutable.Map[WrappedExpression, DefRegister]
+ netlist: mutable.LinkedHashMap[WrappedExpression, (Seq[Expression], Info)],
+ regInfo: mutable.Map[WrappedExpression, DefRegister],
+ randInfo: mutable.Map[WrappedExpression, DefRandom]
): Try[Seq[Statement]] = {
val digraph = new MutableDiGraph[WrappedExpression]
for ((sink, (exprs, _)) <- netlist) {
@@ -80,7 +82,8 @@ class RemoveWires extends Transform with DependencyAPIMigration {
ordered.map { key =>
val WRef(name, _, kind, _) = key.e1
kind match {
- case RegKind => regInfo(key)
+ case RegKind => regInfo(key)
+ case RandomKind => randInfo(key)
case WireKind | NodeKind =>
val (Seq(rhs), info) = netlist(key)
DefNode(info, name, rhs)
@@ -100,6 +103,8 @@ class RemoveWires extends Transform with DependencyAPIMigration {
val wireInfo = mutable.HashMap.empty[WrappedExpression, Info]
// Additional info about registers
val regInfo = mutable.HashMap.empty[WrappedExpression, DefRegister]
+ // Additional info about rand statements
+ val randInfo = mutable.HashMap.empty[WrappedExpression, DefRandom]
def onStmt(stmt: Statement): Statement = {
stmt match {
@@ -115,6 +120,9 @@ class RemoveWires extends Transform with DependencyAPIMigration {
val initDep = Some(reg.init).filter(we(WRef(reg)) != we(_)) // Dependency exists IF reg doesn't init itself
regInfo(we(WRef(reg))) = reg
netlist(we(WRef(reg))) = (Seq(reg.clock) ++ resetDep ++ initDep, reg.info)
+ case rand: DefRandom =>
+ randInfo(we(Reference(rand))) = rand
+ netlist(we(Reference(rand))) = (rand.clock ++: rand.en +: List(), rand.info)
case decl: CanBeReferenced =>
// Keep all declarations except for nodes and non-Analog wires and "other" statements.
// Thus this is expected to match DefInstance and DefMemory which both do not connect to
@@ -148,7 +156,7 @@ class RemoveWires extends Transform with DependencyAPIMigration {
m match {
case mod @ Module(info, name, ports, body) =>
onStmt(body)
- getOrderedNodes(netlist, regInfo) match {
+ getOrderedNodes(netlist, regInfo, randInfo) match {
case Success(logic) =>
Module(info, name, ports, Block(List() ++ decls ++ logic ++ otherStmts))
// If we hit a CyclicException, just abort removing wires
diff --git a/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala
new file mode 100644
index 00000000..c6f89b93
--- /dev/null
+++ b/src/test/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorSpec.scala
@@ -0,0 +1,360 @@
+package firrtl.backends.experimental.smt.random
+
+import firrtl.options.Dependency
+import firrtl.testutils.LeanTransformSpec
+
+class UndefinedMemoryBehaviorSpec extends LeanTransformSpec(Seq(Dependency(UndefinedMemoryBehaviorPass))) {
+ behavior.of("UndefinedMemoryBehaviorPass")
+
+ it should "model write-write conflicts between 2 ports" in {
+
+ val circuit = compile(UBMSources.writeWriteConflict, List()).circuit
+ // println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // a random value should be declared for the data written on a write-write conflict
+ assert(result.contains("rand m_a_wwc_data : UInt<32>, m.a.clk when m_a_b_wwc"))
+
+ // a write-write conflict occurs when both ports are enabled and the addresses match
+ assert(result.contains("m_a_b_wwc <= and(and(m_a_en, m_b_en), eq(m.a.addr, m.b.addr))"))
+
+ // the data of read port a depends on whether there is a write-write conflict
+ assert(result.contains("m.a.data <= mux(m_a_b_wwc, m_a_wwc_data, m_a_data)"))
+
+ // the enable of read port b depends on whether there is a write-write conflict
+ assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))"))
+ }
+
+ it should "model write-write conflicts between 3 ports" in {
+
+ val circuit = compile(UBMSources.writeWriteConflict3, List()).circuit
+ //println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // when there is more than one next write port, a "active" node is created
+ assert(result.contains("node m_a_wwc_active = or(m_a_b_wwc, m_a_c_wwc)"))
+
+ // a random value should be declared for the data written on a write-write conflict
+ assert(result.contains("rand m_a_wwc_data : UInt<32>, m.a.clk when m_a_wwc_active"))
+ assert(result.contains("rand m_b_wwc_data : UInt<32>, m.b.clk when m_b_c_wwc"))
+
+ // a write-write conflict occurs when both ports are enabled and the addresses match
+ Seq(("a", "b"), ("a", "c"), ("b", "c")).foreach {
+ case (w1, w2) =>
+ assert(
+ result.contains(s"m_${w1}_${w2}_wwc <= and(and(m_${w1}_en, m_${w2}_en), eq(m.${w1}.addr, m.${w2}.addr))")
+ )
+ }
+
+ // the data of read port a depends on whether there is a write-write conflict
+ assert(result.contains("m.a.data <= mux(m_a_wwc_active, m_a_wwc_data, m_a_data)"))
+
+ // the data of read port b depends on whether there is a write-write conflict
+ assert(result.contains("m.b.data <= mux(m_b_c_wwc, m_b_wwc_data, m_b_data)"))
+
+ // the enable of read port b depends on whether there is a write-write conflict
+ assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))"))
+
+ // the enable of read port c depends on whether there is a write-write conflict
+ // note that in this case we do not add an extra node since the disjunction is only used once
+ assert(result.contains("m.c.en <= and(m_c_en, not(or(m_a_c_wwc, m_b_c_wwc)))"))
+ }
+
+ it should "model write-write conflicts more efficiently when ports are mutually exclusive" in {
+
+ val circuit = compile(UBMSources.writeWriteConflict3Exclusive, List()).circuit
+ // println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // we should not compute the conflict between a and c since it is impossible
+ assert(!result.contains("node m_a_c_wwc = and(and(m_a_en, m_c_en), eq(m.a.addr, m.c.addr))"))
+
+ // the enable of port b depends on whether there is a conflict with a
+ assert(result.contains("m.b.en <= and(m_b_en, not(m_a_b_wwc))"))
+
+ // the data of port b depends on whether these is a conflict with c
+ assert(result.contains("m.b.data <= mux(m_b_c_wwc, m_b_wwc_data, m_b_data)"))
+
+ // the enable of port c only depend on whether there is a conflict with b since c and a cannot conflict
+ assert(result.contains("m.c.en <= and(m_c_en, not(m_b_c_wwc))"))
+
+ // the data of port a only depends on whether there is a conflict with b since a and c cannot conflict
+ assert(result.contains("m.a.data <= mux(m_a_b_wwc, m_a_wwc_data, m_a_data)"))
+ }
+
+ it should "assert out-of-bounds writes when told to" in {
+ val anno = List(UndefinedMemoryBehaviorOptions(assertNoOutOfBoundsWrites = true))
+
+ val circuit = compile(UBMSources.readWrite(30, 0), anno).circuit
+ // println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ assert(
+ result.contains(
+ """assert(m.a.clk, or(not(and(m.a.en, m.a.mask)), geq(UInt<5>("h1e"), m.a.addr)), UInt<1>("h1"), "out of bounds read")"""
+ )
+ )
+ }
+
+ it should "model out-of-bounds reads" in {
+ val circuit = compile(UBMSources.readWrite(30, 0), List()).circuit
+ //println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // an out of bounds read happens if the depth is not greater or equal to the address
+ assert(result.contains("node m_r_oob = not(geq(UInt<5>(\"h1e\"), m.r.addr))"))
+
+ // the source of randomness needs to be triggered when there is an out of bounds read
+ assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_oob"))
+
+ // the data is random when there is an oob
+ assert(result.contains("m_r_data <= mux(m_r_oob, m_r_rand_data, m.r.data)"))
+ }
+
+ it should "model un-enabled reads w/o out-of-bounds" in {
+ // without possible out-of-bounds
+ val circuit = compile(UBMSources.readEnable(32), List()).circuit
+ //println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // the memory is disabled when it is not enabled
+ assert(result.contains("node m_r_disabled = not(m.r.en)"))
+
+ // the source of randomness needs to be triggered when there is an read while the port is disabled
+ assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_disabled"))
+
+ // the data is random when there is an un-enabled read
+ assert(result.contains("m_r_data <= mux(m_r_disabled, m_r_rand_data, m.r.data)"))
+ }
+
+ it should "model un-enabled reads with out-of-bounds" in {
+ // with possible out-of-bounds
+ val circuit = compile(UBMSources.readEnable(30), List()).circuit
+ //println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // the memory is disabled when it is not enabled
+ assert(result.contains("node m_r_disabled = not(m.r.en)"))
+
+ // an out of bounds read happens if the depth is not greater or equal to the address and the memory is enabled
+ assert(result.contains("node m_r_oob = and(m.r.en, not(geq(UInt<5>(\"h1e\"), m.r.addr)))"))
+
+ // the two possible issues are combined into a single signal
+ assert(result.contains("node m_r_do_rand = or(m_r_disabled, m_r_oob)"))
+
+ // the source of randomness needs to be triggered when either issue occurs
+ assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand"))
+
+ // the data is random when either issue occurs
+ assert(result.contains("m_r_data <= mux(m_r_do_rand, m_r_rand_data, m.r.data)"))
+ }
+
+ it should "model un-enabled reads with out-of-bounds with read pipelining" in {
+ // with read latency one, we need to pipeline the `do_rand` signal
+ val circuit = compile(UBMSources.readEnable(30, 1), List()).circuit
+ //println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // pipeline register
+ assert(result.contains("m_r_do_rand_r1 <= m_r_do_rand"))
+
+ // the source of randomness needs to be triggered by the pipeline register
+ assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand_r1"))
+
+ // the data is random when the pipeline register is 1
+ assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)"))
+ }
+
+ it should "model read/write conflicts when they are undefined" in {
+ val circuit = compile(UBMSources.readWrite(32, 1), List()).circuit
+ //println(circuit.serialize)
+ val result = circuit.serialize.split('\n').map(_.trim)
+
+ // detect read/write conflicts
+ assert(result.contains("m_r_a_rwc <= and(and(m.r.en, and(m.a.en, m.a.mask)), eq(m.r.addr, m.a.addr))"))
+
+ // delay the signal
+ assert(result.contains("m_r_do_rand_r1 <= m_r_rwc"))
+
+ // randomize the data
+ assert(result.contains("rand m_r_rand_data : UInt<32>, m.r.clk when m_r_do_rand_r1"))
+ assert(result.contains("m_r_data <= mux(m_r_do_rand_r1, m_r_rand_data, m.r.data)"))
+ }
+}
+
+private object UBMSources {
+
+ val writeWriteConflict =
+ s"""
+ |circuit Test:
+ | module Test:
+ | input c : Clock
+ | input preset: AsyncReset
+ | input addr : UInt<8>
+ | input data : UInt<32>
+ | input aEn : UInt<1>
+ | input bEn : UInt<1>
+ |
+ | mem m:
+ | data-type => UInt<32>
+ | depth => 32
+ | reader => r
+ | writer => a, b
+ | read-latency => 0
+ | write-latency => 1
+ | read-under-write => undefined
+ |
+ | m.r.clk <= c
+ | m.r.en <= UInt(1)
+ | m.r.addr <= addr
+ |
+ | ; both read ports write to the same address and the same data
+ | m.a.clk <= c
+ | m.a.en <= aEn
+ | m.a.addr <= addr
+ | m.a.data <= data
+ | m.a.mask <= UInt(1)
+ | m.b.clk <= c
+ | m.b.en <= bEn
+ | m.b.addr <= addr
+ | m.b.data <= data
+ | m.b.mask <= UInt(1)
+ """.stripMargin
+
+ val writeWriteConflict3 =
+ s"""
+ |circuit Test:
+ | module Test:
+ | input c : Clock
+ | input preset: AsyncReset
+ | input addr : UInt<8>
+ | input data : UInt<32>
+ | input aEn : UInt<1>
+ | input bEn : UInt<1>
+ | input cEn : UInt<1>
+ |
+ | mem m:
+ | data-type => UInt<32>
+ | depth => 32
+ | reader => r
+ | writer => a, b, c
+ | read-latency => 0
+ | write-latency => 1
+ | read-under-write => undefined
+ |
+ | m.r.clk <= c
+ | m.r.en <= UInt(1)
+ | m.r.addr <= addr
+ |
+ | ; both read ports write to the same address and the same data
+ | m.a.clk <= c
+ | m.a.en <= aEn
+ | m.a.addr <= addr
+ | m.a.data <= data
+ | m.a.mask <= UInt(1)
+ | m.b.clk <= c
+ | m.b.en <= bEn
+ | m.b.addr <= addr
+ | m.b.data <= data
+ | m.b.mask <= UInt(1)
+ | m.c.clk <= c
+ | m.c.en <= cEn
+ | m.c.addr <= addr
+ | m.c.data <= data
+ | m.c.mask <= UInt(1)
+ """.stripMargin
+
+ val writeWriteConflict3Exclusive =
+ s"""
+ |circuit Test:
+ | module Test:
+ | input c : Clock
+ | input preset: AsyncReset
+ | input addr : UInt<8>
+ | input data : UInt<32>
+ | input aEn : UInt<1>
+ | input bEn : UInt<1>
+ |
+ | mem m:
+ | data-type => UInt<32>
+ | depth => 32
+ | reader => r
+ | writer => a, b, c
+ | read-latency => 0
+ | write-latency => 1
+ | read-under-write => undefined
+ |
+ | m.r.clk <= c
+ | m.r.en <= UInt(1)
+ | m.r.addr <= addr
+ |
+ | ; both read ports write to the same address and the same data
+ | m.a.clk <= c
+ | m.a.en <= aEn
+ | m.a.addr <= addr
+ | m.a.data <= data
+ | m.a.mask <= UInt(1)
+ | m.b.clk <= c
+ | m.b.en <= bEn
+ | m.b.addr <= addr
+ | m.b.data <= data
+ | m.b.mask <= UInt(1)
+ | m.c.clk <= c
+ | m.c.en <= not(aEn)
+ | m.c.addr <= addr
+ | m.c.data <= data
+ | m.c.mask <= UInt(1)
+ """.stripMargin
+
+ def readWrite(depth: Int, readLatency: Int) =
+ s"""circuit CollisionTest:
+ | module CollisionTest:
+ | input c : Clock
+ | input preset: AsyncReset
+ | input addr : UInt<8>
+ | input data : UInt<32>
+ | output dataOut : UInt<32>
+ |
+ | mem m:
+ | data-type => UInt<32>
+ | depth => $depth
+ | reader => r
+ | writer => a
+ | read-latency => $readLatency
+ | write-latency => 1
+ | read-under-write => undefined
+ |
+ | m.r.clk <= c
+ | m.r.en <= UInt(1)
+ | m.r.addr <= addr
+ | dataOut <= m.r.data
+ |
+ | m.a.clk <= c
+ | m.a.mask <= UInt(1)
+ | m.a.en <= UInt(1)
+ | m.a.addr <= addr
+ | m.a.data <= data
+ |""".stripMargin
+
+ def readEnable(depth: Int, latency: Int = 0) =
+ s"""circuit Test:
+ | module Test:
+ | input c : Clock
+ | input addr : UInt<8>
+ | input en : UInt<1>
+ | output data : UInt<32>
+ |
+ | mem m:
+ | data-type => UInt<32>
+ | depth => $depth
+ | reader => r
+ | read-latency => $latency
+ | write-latency => 1
+ | read-under-write => old
+ |
+ | m.r.clk <= c
+ | m.r.en <= en
+ | m.r.addr <= addr
+ | data <= m.r.data
+ |""".stripMargin
+}