aboutsummaryrefslogtreecommitdiff
path: root/src/main
diff options
context:
space:
mode:
authorKevin Laeufer2021-03-08 15:39:15 -0800
committerGitHub2021-03-08 23:39:15 +0000
commit29d57a612df69ae4a6db4b3755fc292e5a539e11 (patch)
tree653723bacce2419c76c64dfcf847e47e81980905 /src/main
parentc93d6f5319efd7ba42147180c6e2b6f3796ef943 (diff)
SMT: memory port inout fields cannot be used as RHS expressions (#2105)
* SMT: memory port inout fields cannot be used as RHS expressions * smt: add end2end check for read enable modelling
Diffstat (limited to 'src/main')
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala146
1 files changed, 75 insertions, 71 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala
index 91e77433..5fd0e680 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/random/UndefinedMemoryBehaviorPass.scala
@@ -2,8 +2,7 @@
package firrtl.backends.experimental.smt.random
-import firrtl.Utils.{isLiteral, kind, BoolType}
-import firrtl.WrappedExpression.{we, weq}
+import firrtl.Utils.{isLiteral, BoolType}
import firrtl._
import firrtl.annotations.NoTargetAnnotation
import firrtl.backends.experimental.smt._
@@ -14,7 +13,7 @@ import firrtl.passes.memlib.AnalysisUtils.Connects
import firrtl.passes.memlib.InferReadWritePass.checkComplement
import firrtl.passes.memlib.{AnalysisUtils, InferReadWritePass, VerilogMemDelays}
import firrtl.stage.Forms
-import firrtl.transforms.{ConstantPropagation, RemoveWires}
+import firrtl.transforms.RemoveWires
import scala.collection.mutable
@@ -111,7 +110,10 @@ private class InstrumentMems(
private def onMem(m: DefMemory): Statement = {
// collect wire and random statement defines
- val declarations = mutable.ListBuffer[Statement]()
+ implicit val declarations: mutable.ListBuffer[Statement] = mutable.ListBuffer[Statement]()
+
+ // cache for the expressions of memory inputs
+ implicit val cache: mutable.HashMap[String, Expression] = mutable.HashMap[String, Expression]()
// only for non power of 2 memories do we have to worry about reading or writing out of bounds
val canBeOutOfBounds = !isPow2(m.depth)
@@ -132,42 +134,23 @@ private class InstrumentMems(
val maskRef = memPortField(m, write, "mask")
val prods = getProductTerms(enRef) ++ getProductTerms(maskRef)
+ val expr = Utils.and(readInput(m.info, enRef), readInput(m.info, maskRef))
- // if we can have write/write conflicts, we are going to change the mask and enable pins
- val expr = if (canHaveWriteWriteConflicts) {
- val maskIsOne = isTrue(connects(maskRef.serialize))
- // if the mask is connected to a constant true, we do not need to consider it, this is a common case
- if (maskIsOne) {
- val enWire = disconnectInput(m.info, enRef)
- declarations += enWire
- Reference(enWire)
- } else {
- val maskWire = disconnectInput(m.info, maskRef)
- val enWire = disconnectInput(m.info, enRef)
- // create a node for the conjunction
- val nodeName = namespace.newName(s"${m.name}_${write}_mask_and_en")
- val node = DefNode(m.info, nodeName, Utils.and(Reference(maskWire), Reference(enWire)))
- declarations ++= List(maskWire, enWire, node)
- Reference(node)
- }
- } else {
- Utils.and(enRef, maskRef)
- }
(expr, prods)
}
// implement the three undefined read behaviors
m.readers.foreach { read =>
// many memories have their read enable hard wired to true
- val canBeDisabled = !isTrue(memPortField(m, read, "en"))
- val readEn = if (canBeDisabled) memPortField(m, read, "en") else Utils.True()
- val addr = memPortField(m, read, "addr")
+ val canBeDisabled = !isTrue(readInput(m, read, "en"))
+ val readEn = if (canBeDisabled) readInput(m, read, "en") else Utils.True()
// collect signals that would lead to a randomization
var doRand = List[Expression]()
// randomize the read value when the address is out of bounds
if (canBeOutOfBounds && opt.randomizeOutOfBoundsRead) {
+ val addr = readInput(m, read, "addr")
val cond = Utils.and(readEn, Utils.not(isInBounds(m.depth, addr)))
val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_oob"), cond)
declarations += node
@@ -175,8 +158,7 @@ private class InstrumentMems(
}
if (readWriteUndefined && opt.randomizeReadWriteConflicts) {
- val (cond, d) = readWriteConflict(m, read, writeEn)
- declarations ++= d
+ val cond = readWriteConflict(m, read, writeEn)
val node = DefNode(m.info, namespace.newName(s"${m.name}_${read}_rwc"), cond)
declarations += node
doRand = Reference(node) +: doRand
@@ -203,7 +185,7 @@ private class InstrumentMems(
}
val doRandSignal = if (m.readLatency == 0) { doRandNode }
else {
- val clock = memPortField(m, read, "clk")
+ val clock = readInput(m, read, "clk")
val (signal, regDecls) = pipeline(m.info, clock, doRandName, doRandNode, m.readLatency)
declarations ++= regDecls
signal
@@ -217,7 +199,7 @@ private class InstrumentMems(
// create a source of randomness and connect the new wire either to the actual data port or to the random value
val randName = namespace.newName(s"${m.name}_${read}_rand_data")
- val random = DefRandom(m.info, randName, m.dataType, Some(memPortField(m, read, "clk")), doRandSignal)
+ val random = DefRandom(m.info, randName, m.dataType, Some(readInput(m, read, "clk")), doRandSignal)
declarations += random
val data = Utils.mux(doRandSignal, Reference(random), dataRef)
newStmts.append(Connect(m.info, Reference(dataWire), data))
@@ -226,16 +208,16 @@ private class InstrumentMems(
// write
if (opt.randomizeWriteWriteConflicts) {
- declarations ++= writeWriteConflicts(m, writeEn)
+ writeWriteConflicts(m, writeEn)
}
// add an assertion that if the write is taking place, then the address must be in range
if (canBeOutOfBounds && opt.assertNoOutOfBoundsWrites) {
m.writers.zip(writeEn).foreach {
case (write, (combinedEn, _)) =>
- val addr = memPortField(m, write, "addr")
+ val addr = readInput(m, write, "addr")
val cond = Utils.implies(combinedEn, isInBounds(m.depth, addr))
- val clk = memPortField(m, write, "clk")
+ val clk = readInput(m, write, "clk")
val a = Verification(Formal.Assert, m.info, clk, cond, Utils.True(), StringLit("out of bounds read"))
newStmts.append(a)
}
@@ -268,11 +250,13 @@ private class InstrumentMems(
m: DefMemory,
read: String,
writeEn: Seq[(Expression, ProdTerms)]
- ): (Expression, Seq[Statement]) = {
- if (m.writers.isEmpty) return (Utils.False(), List())
- val declarations = mutable.ListBuffer[Statement]()
+ )(
+ implicit cache: mutable.HashMap[String, Expression],
+ declarations: mutable.ListBuffer[Statement]
+ ): Expression = {
+ if (m.writers.isEmpty) return Utils.False()
- val readEn = memPortField(m, read, "en")
+ val readEn = readInput(m, read, "en")
val readProd = getProductTerms(readEn)
// create all conflict signals
@@ -283,7 +267,7 @@ private class InstrumentMems(
} else {
val name = namespace.newName(s"${m.name}_${read}_${write}_rwc")
val bothEn = Utils.and(readEn, writeEn)
- val sameAddr = Utils.eq(memPortField(m, read, "addr"), memPortField(m, write, "addr"))
+ val sameAddr = Utils.eq(readInput(m, read, "addr"), readInput(m, write, "addr"))
// we need a wire because this condition might be used in a random statement
val wire = DefWire(m.info, name, BoolType)
declarations += wire
@@ -292,13 +276,18 @@ private class InstrumentMems(
}
}
- (conflicts.reduce(Utils.or), declarations.toList)
+ conflicts.reduce(Utils.or)
}
private type ProdTerms = Seq[Expression]
- private def writeWriteConflicts(m: DefMemory, writeEn: Seq[(Expression, ProdTerms)]): Seq[Statement] = {
- if (m.writers.size < 2) return List()
- val declarations = mutable.ListBuffer[Statement]()
+ private def writeWriteConflicts(
+ m: DefMemory,
+ writeEn: Seq[(Expression, ProdTerms)]
+ )(
+ implicit cache: mutable.HashMap[String, Expression],
+ declarations: mutable.ListBuffer[Statement]
+ ): Unit = {
+ if (m.writers.size < 2) return
// we first create all conflict signals:
val conflict =
@@ -314,7 +303,7 @@ private class InstrumentMems(
} else {
val name = namespace.newName(s"${m.name}_${w1}_${w2}_wwc")
val bothEn = Utils.and(en1, en2)
- val sameAddr = Utils.eq(memPortField(m, w1, "addr"), memPortField(m, w2, "addr"))
+ val sameAddr = Utils.eq(readInput(m, w1, "addr"), readInput(m, w2, "addr"))
// we need a wire because this condition might be used in a random statement
val wire = DefWire(m.info, name, BoolType)
declarations += wire
@@ -355,25 +344,24 @@ private class InstrumentMems(
// create the source of randomness
val name = namespace.newName(s"${m.name}_${w1}_wwc_data")
- val random = DefRandom(m.info, name, m.dataType, Some(memPortField(m, w1, "clk")), hasConflict)
+ val random = DefRandom(m.info, name, m.dataType, Some(readInput(m, w1, "clk")), hasConflict)
declarations.append(random)
- // replace the old data input
- val dataWire = disconnectInput(m.info, memPortField(m, w1, "data"))
- declarations += dataWire
+
// generate new data input
- val data = Utils.mux(hasConflict, Reference(random), Reference(dataWire))
+ val data = Utils.mux(hasConflict, Reference(random), readInput(m, w1, "data"))
newStmts.append(Connect(m.info, memPortField(m, w1, "data"), data))
+ doDisconnect.add(memPortField(m, w1, "data").serialize)
}
// connect data enable signals
- val maskIsOne = isTrue(connects(memPortField(m, w1, "mask").serialize))
+ val maskIsOne = isTrue(readInput(m, w1, "mask"))
if (!maskIsOne) {
newStmts.append(Connect(m.info, memPortField(m, w1, "mask"), Utils.True()))
+ doDisconnect.add(memPortField(m, w1, "mask").serialize)
}
newStmts.append(Connect(m.info, memPortField(m, w1, "en"), en))
+ doDisconnect.add(memPortField(m, w1, "en").serialize)
}
-
- declarations.toList
}
/** check whether two signals can be proven to be mutually exclusive */
@@ -383,27 +371,43 @@ private class InstrumentMems(
proofOfMutualExclusion.nonEmpty
}
- /** replace a memory port with a wire */
- private def disconnectInput(info: Info, signal: RefLikeExpression): DefWire = {
- // disconnect the old value
- doDisconnect.add(signal.serialize)
-
- // if the old value is a literal, we just replace all references to it with this literal
- val oldValue = connects(signal.serialize)
- if (isLiteral(oldValue)) {
- println("TODO: better code for literal")
- }
+ /** memory inputs my not be read, only assigned to, thus we might need to add a wire to make them accessible */
+ private def readInput(
+ info: Info,
+ signal: RefLikeExpression
+ )(
+ implicit cache: mutable.HashMap[String, Expression],
+ declarations: mutable.ListBuffer[Statement]
+ ): Expression =
+ cache.getOrElseUpdate(
+ signal.serialize, {
+ // if it is a literal, we just return it
+ val value = connects(signal.serialize)
+ if (isLiteral(value)) {
+ value
+ } else {
+ // otherwise we make a wire that refelect the value
+ val wire = DefWire(info, copyName(signal), signal.tpe)
+ declarations += wire
- // create a new wire and replace all references to the original port with this wire
- val wire = DefWire(info, copyName(signal), signal.tpe)
- exprReplacements(signal.serialize) = Reference(wire)
- // connect the old expression to the new wire
- val con = Connect(info, Reference(wire), connects(signal.serialize))
- newStmts.append(con)
+ // connect the old expression to the new wire
+ val con = Connect(info, Reference(wire), value)
+ newStmts.append(con)
- // the wire definition should end up right after the memory definition
- wire
- }
+ // use a reference to this new wire
+ Reference(wire)
+ }
+ }
+ )
+ private def readInput(
+ m: DefMemory,
+ port: String,
+ field: String
+ )(
+ implicit cache: mutable.HashMap[String, Expression],
+ declarations: mutable.ListBuffer[Statement]
+ ): Expression =
+ readInput(m.info, memPortField(m, port, field))
private def copyName(ref: RefLikeExpression): String =
namespace.newName(ref.serialize.replace('.', '_'))